`wandb.finish()` with `log_model="all"` is very slow when online

It takes 21 seconds to upload 8 files (32 kB in total).

MRE below; the slowdown scales with `N_EPOCHS`.

Not reproduced with `LOG_MODEL != "all"` or `OFFLINE = 1` (see the sketch after the log excerpt below).

Confirmed the time is spent uploading model artifacts, since the timer in

```
wandb: ⣟ uploading artifact model-vjl4ilqt (13s)
wandb: ⣟ uploading artifact model-vjl4ilqt (12s)
```

tops out at the observed total slowdown time.
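For reference, a minimal sketch of the two configurations under which the slowdown does not reproduce for me. It simply restates the `LOG_MODEL` / `OFFLINE` switches from the MRE, assuming Lightning's documented `log_model` semantics (`True` logs checkpoints only at the end of training, `"all"` logs them as they are created); the project/entity strings are placeholders:

```python
import os
from lightning.pytorch.loggers import WandbLogger

# Option A: log_model=True -- checkpoints are uploaded once at the end of
# training instead of continuously; no slowdown observed with this setting.
wandb_logger = WandbLogger(
    project="<your-project>",   # placeholder
    entity="<your-entity>",     # placeholder
    log_model=True,
)

# Option B: run fully offline -- nothing is uploaded, no slowdown observed.
os.environ["WANDB_MODE"] = "offline"
```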

```python
# -*- coding: utf-8 -*-
import os
import wandb
import torch
import lightning as L
import torch.nn as nn
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from time import time

# KEY CONFIGS ---------
LOG_MODEL = (False, True, "all")[2]  # "all"= slow
OFFLINE = 0  # (0 = slow)

# --- Configuration ---
# W&B
WANDB_PROJECT = "<your-project>"                  # placeholder
WANDB_ENTITY = "<your-entity>"                    # placeholder
os.environ['WANDB_API_KEY'] = "<your-api-key>"    # placeholder

# Data & Training
DATA_LEN = 8000
N_SAMPLES = 100
N_EPOCHS = 10

# Other
if OFFLINE:
    os.environ["WANDB_MODE"] = "offline"

# Helpers --------------------------------------------------------------------
# Dataset helpers
class DummyAudioDataset(Dataset):
    def __getitem__(self, idx):
        x = torch.randn(1, DATA_LEN)  # add channel dim
        y = torch.randint(0, 10, (1,)).squeeze()
        return x, y

    def __len__(self):
        return N_SAMPLES

# Model helpers
def build_cnn(input_channels=1):
    return nn.Sequential(
        nn.Conv1d(input_channels, 16, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool1d(1),
        nn.Flatten(),
        nn.Linear(16, 10)
    )

class AudioClassifierPL(L.LightningModule):
    def __init__(self, data_len=8000):
        super().__init__()
        self.data_len = data_len
        self.model = build_cnn()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)  # x: (batch, 1, data_len)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.criterion(pred, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Main logic -----------------------------------------------------------------
def main():
    # W&B
    wandb.login(key=os.environ['WANDB_API_KEY'])
    wandb_logger = WandbLogger(
        name="cnn_mre",
        settings=wandb.Settings(save_code=False),
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        log_model=LOG_MODEL
    )
    # Trigger logger init and get run details
    _ = wandb_logger.experiment

    # Data
    dummy_dataset = DummyAudioDataset()
    dummy_dataloader = DataLoader(
        dummy_dataset, batch_size=16, shuffle=True, num_workers=0)

    # PL module
    pl_model = AudioClassifierPL(data_len=DATA_LEN)

    # Trainer
    trainer = L.Trainer(max_epochs=N_EPOCHS, logger=[wandb_logger],
                        callbacks=[], enable_progress_bar=True)

    # Train
    trainer.fit(pl_model, train_dataloaders=dummy_dataloader)
    print("Training complete.")

    # Slowdown occurs here
    t1 = time()
    if wandb.run:
        wandb.finish()
    print('\n\nTime after "Training complete.": %.3g' % (time() - t1))


if __name__ == "__main__":
    t0 = time()
    main()
    print("Total time: %.3g" % (time() - t0))