PyTorch Lightning WandbLogger how to save top K checkpoints + last checkpoint to GCS?

Can someone help me set up the WandbLogger with PyTorch Lightning so that I can save the top K checkpoints and the last checkpoint to GCS? Currently, only the last checkpoint is saved with the example code below:

import os

import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#customize-checkpointing-behavior-intermediate
ENCODER = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
DECODER = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = LitAutoEncoder(ENCODER, DECODER)
# https://stackoverflow.com/questions/64092369/validation-dataset-in-pytorch-using-dataloaders
train_loader = utils.data.DataLoader(MNIST(os.getcwd(), download=True, transform=ToTensor()))
# calling this the val dataset for testing
val_loader = utils.data.DataLoader(MNIST(os.getcwd(), download=True, transform=ToTensor(), train=False))
wandb_logger = WandbLogger(
    project='myproject',
    name='myname',
    log_model='all', # TODO want both top checkpoint and last checkpoint to be saved (but not every epoch's)
    save_dir='gs://mybucket/mypath/',
)
trainer = L.Trainer(
    limit_train_batches=2,
    limit_test_batches=2,
    max_epochs=5,
    logger=wandb_logger,
)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
wandb_logger.finalize('success')

Ideally, I’d like to automatically add the ckpt files as GCS references in wandb, but I’m happy to add that myself by adding a hook or something if it’s not supported (or I could put up a PR on the wandb client to add it if you’d prefer).
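
For context, here's a rough sketch of the kind of reference logging I have in mind (the bucket path and artifact name are placeholders):

import wandb

run = wandb.init(project='myproject')
# Log a pointer to the checkpoint object in GCS instead of uploading the file itself.
artifact = wandb.Artifact(name='model-checkpoints', type='model')
artifact.add_reference('gs://mybucket/mypath/last.ckpt')
run.log_artifact(artifact)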

# pip list
...
wandb                       0.15.12
pytorch-lightning           2.1.2
torch                       2.1.1
torchmetrics                1.2.0
torchvision                 0.16.1
gcsfs                       2023.10.0
google-api-core             2.14.0
google-api-python-client    2.108.0
google-auth                 2.23.4
google-auth-httplib2        0.1.1
google-auth-oauthlib        1.1.0
google-cloud-bigquery       3.13.0
google-cloud-core           2.3.3
google-cloud-secret-manager 2.16.4
google-cloud-storage        2.13.0
google-crc32c               1.5.0
google-resumable-media      2.6.0
googleapis-common-protos    1.61.0
grpc-google-iam-v1          0.12.7

On an M2 MBP, using poetry for env management.

Happy to provide any other useful info!

Hey @nathanwilk7,

Thank you for writing in.

The WandbLogger itself can't currently be configured to save the top k models. However, you can use PyTorch Lightning's ModelCheckpoint callback to save the top k checkpoints to a preferred location: ModelCheckpoint.

For example, you could configure a callback:

from pytorch_lightning.callbacks import ModelCheckpoint

k = 2

checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',  # Metric to rank checkpoints by; swap in a validation metric if you log one
    mode='min',            # 'min' for metrics to minimize (e.g., loss), 'max' for ones to maximize (e.g., accuracy)
    save_top_k=k,          # Keep the top k checkpoints ranked by the monitored metric
    save_last=True,        # Also keep the most recent checkpoint at the end of training
    dirpath='gs://mybucket/mypath/',     # Directory where the checkpoints will be written
    filename='{epoch}-{train_loss:.2f}'  # Checkpoint file naming pattern
)

to be added to your Trainer:

trainer = L.Trainer(
    limit_train_batches=2,
    limit_test_batches=2,
    max_epochs=5,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)

This should save the best k models and the last one to the defined location. Additionally, with log_model='all' set on the wandb_logger, all of the checkpoints should be saved as Artifacts in the W&B project, with aliases such as latest and best attached to them.
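
If it helps, checkpoints logged this way can be fetched later through the public API (the entity/project path and run id below are placeholders; by default the WandbLogger names its checkpoint artifacts model-<run_id>):

import wandb

api = wandb.Api()
# Download the checkpoint version tagged 'best' for a given run.
artifact = api.artifact('myentity/myproject/model-abc123:best', type='model')
ckpt_dir = artifact.download()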

Please let me know if this helps and feel free to provide more information about your use case so we can look into raising a feature request for the WandbLogger to log the top k models only.

Thanks,
Francesco

Hi @nathanwilk7,

I wanted to follow up and check whether you've had a chance to review my previous suggestion, and whether it was helpful in your case.

Thanks!
Francesco

Thanks for following up. I did end up using some of the ideas you provided here; I appreciate it!

I ended up extending the WandbLogger class and overriding its after_save_checkpoint method, paired with a ModelCheckpoint configured to save best/last checkpoints to a GCS dirpath. The overridden after_save_checkpoint finds the best/last checkpoints saved to GCS, then deletes any registered wandb checkpoints whose GCS files have been overwritten or deleted.
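
In case it's useful to others, here is a rough sketch of the shape of that subclass (simplified from what I actually run; names are illustrative and the pruning step is only outlined in a comment):

import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

class GCSReferenceWandbLogger(WandbLogger):
    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        # Checkpoints the callback still considers live: the top-k set plus the last one.
        live_paths = set(checkpoint_callback.best_k_models) | {checkpoint_callback.last_model_path}
        live_paths.discard('')  # last_model_path is '' before the first save
        for path in live_paths:
            # Register a pointer to the GCS object rather than uploading the file.
            artifact = wandb.Artifact(name=f'model-{self.experiment.id}', type='model')
            artifact.add_reference(path)
            self.experiment.log_artifact(artifact)
        # Pruning: list this run's logged artifact versions via wandb.Api() and
        # delete any whose reference no longer matches a path in live_paths.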

You are welcome! I am glad to hear we were able to provide you with some ideas to achieve your goals.

If you don’t have further questions, I will go ahead and close this.
