It takes 21 seconds to upload 8 files totalling 32 kB.
MRE below. The slowdown scales with `N_EPOCHS`.
Not reproduced with `LOG_MODEL != "all"` or with `OFFLINE = 1`.
Confirmed it’s model uploading, since:
wandb: ⣟ uploading artifact model-vjl4ilqt (13s)
wandb: ⣟ uploading artifact model-vjl4ilqt (12s)
The elapsed time shown in these lines peaks at the observed total slowdown time.
# -*- coding: utf-8 -*-
import os
import wandb
import torch
import lightning as L
import torch.nn as nn
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from time import time
# KEY CONFIGS ---------
# The tuple index selects the value passed to WandbLogger(log_model=...);
# index 2 ("all") is the setting that reproduces the slowdown.
LOG_MODEL = (False, True, "all")[2]  # "all" = slow
OFFLINE = 0  # truthy -> WANDB_MODE=offline is set below (0 = slow)

# --- Configuration ---
# W&B
# NOTE(review): the three values below had an empty right-hand side in the
# original paste, which made the module a SyntaxError. The placeholders keep
# the file importable; replace them with your own values before running.
WANDB_PROJECT = "cnn_mre"  # placeholder project name -- TODO set your own
WANDB_ENTITY = None  # presumably None selects the default entity -- verify
WANDB_API_KEY = ""  # paste a key here, or export WANDB_API_KEY beforehand
if WANDB_API_KEY:
    # Only touch the environment when a key was actually supplied, so a key
    # already exported in the shell is not overwritten by the placeholder.
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY

# Data & Training
DATA_LEN = 8000  # length of each dummy sample (last tensor dimension)
N_SAMPLES = 100  # number of items the dummy dataset reports
N_EPOCHS = 10  # passed to the trainer as max_epochs

# Other
if OFFLINE:
    os.environ["WANDB_MODE"] = "offline"
# Helpers --------------------------------------------------------------------
# Dataset helpers
class DummyAudioDataset(Dataset):
    """Synthetic stand-in for an audio dataset.

    Nothing is stored: every lookup fabricates a fresh random sample and the
    requested index is ignored. Sizes come from the module-level
    ``N_SAMPLES`` and ``DATA_LEN`` constants.
    """

    def __len__(self):
        """Report the fixed dataset size, ``N_SAMPLES``."""
        return N_SAMPLES

    def __getitem__(self, idx):
        """Return a random ``(signal, label)`` pair; ``idx`` is unused."""
        signal = torch.randn(1, DATA_LEN)  # leading 1 is the channel dim
        label = torch.randint(0, 10, (1,))
        # squeeze() drops the size-1 axis so the label is a 0-dim tensor
        return signal, label.squeeze()
# Model helpers
def build_cnn(input_channels=1, hidden_channels=16, num_classes=10):
    """Build a minimal 1-D CNN classifier.

    The network is Conv1d -> adaptive average pool to length 1 -> flatten ->
    linear, so the input length does not need to be fixed in advance.

    Args:
        input_channels: Number of channels in the input signal.
        hidden_channels: Number of convolution filters (previously fixed
            at 16, which remains the default).
        num_classes: Width of the output layer (previously fixed at 10,
            which remains the default).

    Returns:
        An ``nn.Sequential`` mapping ``(batch, input_channels, length)`` to
        ``(batch, num_classes)``.
    """
    return nn.Sequential(
        nn.Conv1d(input_channels, hidden_channels, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool1d(1),  # collapse the length axis to size 1
        nn.Flatten(),
        nn.Linear(hidden_channels, num_classes),
    )
class AudioClassifierPL(L.LightningModule):
    """Lightning module wrapping the small CNN returned by ``build_cnn``.

    Pairs the network with a cross-entropy loss and an Adam optimizer.
    """

    def __init__(self, data_len=8000):
        """Create the network and loss; ``data_len`` is only recorded."""
        super().__init__()
        self.data_len = data_len
        self.model = build_cnn()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run ``x`` -- expected as (batch, 1, data_len) -- through the CNN."""
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        return self.criterion(self(inputs), targets)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# Main logic -----------------------------------------------------------------
def main():
    """Train the dummy model with W&B logging, then time the run shutdown.

    Logs in to W&B, fits the model for ``N_EPOCHS`` epochs on random data,
    and finally prints how long the ``wandb.finish()`` step took.
    """
    # W&B: authenticate with the key from the environment
    api_key = os.environ['WANDB_API_KEY']
    wandb.login(key=api_key)

    run_logger = WandbLogger(
        name="cnn_mre",
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        log_model=LOG_MODEL,
        settings=wandb.Settings(save_code=False),
    )
    # Accessing .experiment triggers logger init up front
    _ = run_logger.experiment

    # Data
    loader = DataLoader(
        DummyAudioDataset(),
        batch_size=16,
        shuffle=True,
        num_workers=0,
    )

    # PL module
    module = AudioClassifierPL(data_len=DATA_LEN)

    # Trainer + fit
    trainer = L.Trainer(
        max_epochs=N_EPOCHS,
        logger=[run_logger],
        callbacks=[],
        enable_progress_bar=True,
    )
    trainer.fit(module, train_dataloaders=loader)
    print("Training complete.")

    # Slowdown occurs here: measure only the finish step
    finish_started = time()
    if wandb.run:
        wandb.finish()
    finish_elapsed = time() - finish_started
    print('\n\nTime after "Training complete.": %.3g' % finish_elapsed)
if __name__ == "__main__":
    # Time the whole script run, including the W&B shutdown inside main()
    started_at = time()
    main()
    total_elapsed = time() - started_at
    print("Total time: %.3g" % total_elapsed)