Can someone help me to set up the WandbLogger with PyTorch Lightning such that I can save the top K checkpoints and the last checkpoint to GCS? The current behavior that I see is that only the last checkpoint is saved with the example code below:
import os
import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#customize-checkpointing-behavior-intermediate
ENCODER = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
DECODER = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
model = LitAutoEncoder(ENCODER, DECODER)
# https://stackoverflow.com/questions/64092369/validation-dataset-in-pytorch-using-dataloaders
train_loader = utils.data.DataLoader(MNIST(os.getcwd(), download=True, transform=ToTensor()))
# calling this the val dataset for testing
val_loader = utils.data.DataLoader(MNIST(os.getcwd(), download=True, transform=ToTensor(), train=False))
wandb_logger = WandbLogger(
project='myproject',
name='myname',
log_model='all', # TODO want both top checkpoint and last checkpoint to be saved (but not every epoch's)
save_dir='gs://mybucket/mypath/',
)
trainer = L.Trainer(
limit_train_batches=2,
limit_test_batches=2,
max_epochs=5,
logger=wandb_logger,
)
trainer.fit(
model=model,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
wandb_logger.finalize('success')
Ideally, I’d like to automatically add the ckpt files as GCS references in wandb, but I’m happy to add that myself by adding a hook or something if it’s not supported (or I could put up a PR on the wandb client to add it if you’d prefer).
# pip list
...
wandb 0.15.12
pytorch-lightning 2.1.2
torch 2.1.1
torchmetrics 1.2.0
torchvision 0.16.1
gcsfs 2023.10.0
google-api-core 2.14.0
google-api-python-client 2.108.0
google-auth 2.23.4
google-auth-httplib2 0.1.1
google-auth-oauthlib 1.1.0
google-cloud-bigquery 3.13.0
google-cloud-core 2.3.3
google-cloud-secret-manager 2.16.4
google-cloud-storage 2.13.0
google-crc32c 1.5.0
google-resumable-media 2.6.0
googleapis-common-protos 1.61.0
grpc-google-iam-v1 0.12.7
On an M2 MBP, using poetry for env management.
Happy to provide any other useful info!