How to keep only last checkpoint artifact?

How do I keep only the last checkpoint artifact in wandb?

I am using Lightning’s ModelCheckpoint to periodically save checkpoints as artifacts to wandb. However, these artifacts are very large, so keeping multiple checkpoint artifact versions on wandb quickly uses up a lot of storage.

I can’t just checkpoint at the end of training, though: my GPUs occasionally terminate, so I need to checkpoint periodically.

How do I make sure that only the last checkpoint artifact is kept on wandb?

Hey @turian,
You need to configure a checkpoint callback, which is straightforward:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# define the W&B logger; log_model="all" uploads every saved checkpoint
# to W&B as a new artifact version while training runs
wandb_logger = WandbLogger(log_model="all")

# define the PyTorch Lightning checkpoint callback (save once per epoch)
checkpoint_callback = ModelCheckpoint(every_n_epochs=1)

# define the trainer
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

In this example, a checkpoint is saved at the end of every epoch, but you can set every_n_epochs to any interval you want. To save checkpoints based on training steps or elapsed time instead, set every_n_train_steps or train_time_interval, respectively.
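Since the GPUs in the question can terminate mid-epoch, a step- or time-based trigger may suit that setup better than a per-epoch one. A minimal configuration sketch of both options follows; the 500-step and 30-minute intervals are arbitrary placeholders, not recommendations:

```python
from datetime import timedelta

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# upload every saved checkpoint to W&B as a new artifact version
wandb_logger = WandbLogger(log_model="all")

# save a checkpoint every 500 training steps (placeholder value);
# set only one of every_n_epochs / every_n_train_steps / train_time_interval
step_checkpoint = ModelCheckpoint(every_n_train_steps=500)

# alternative: save after every 30 minutes of training (placeholder value)
time_checkpoint = ModelCheckpoint(train_time_interval=timedelta(minutes=30))

trainer = Trainer(logger=wandb_logger, callbacks=[step_checkpoint])
```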

If you’re looking for more specific information, I highly recommend checking out the official docs:

Hope this helps.


Also, if you want to delete artifacts after training, you can use the public API (wandb.Api):

"""
Deletes all model artifact versions that do not have an alias attached.

By default, this means wandb will delete all but the "latest" or "best" models.

Set dry_run = False to actually delete.
"""
import wandb

project_name = 'demo-project'
entity = '_scott'
dry_run = True

api = wandb.Api(overrides={"project": project_name, "entity": entity})
project = api.project(project_name)
for artifact_type in project.artifacts_types():
    # only model artifacts are considered; other types are left alone
    if artifact_type.type != 'model':
        continue
    for artifact_collection in artifact_type.collections():
        for version in api.artifact_versions(artifact_type.type, artifact_collection.name):
            if len(version.aliases) > 0:
                # print out the name of the one we are keeping
                print(f'KEEPING {version.name}')
            else:
                print(f'DELETING {version.name}')
                if not dry_run:
                    version.delete()
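The same idea can be narrowed to the checkpoints of a single run, so it could also be called periodically while training is still running: the version currently aliased "latest" is never touched, only older alias-less ones. The sketch below is a hypothetical helper, `prune_model_versions`, not part of wandb or Lightning. It takes the API object as a parameter so the selection logic can be tried out without a W&B account. It assumes a recent wandb release in which `Api.artifacts(type_name, name)` lists the versions of a collection (older releases expose the same listing as `Api.artifact_versions`, which the script above uses).

```python
from typing import Any, List


def prune_model_versions(api: Any, artifact_path: str, dry_run: bool = True) -> List[str]:
    """Delete every alias-less version of one model artifact collection.

    Hypothetical helper (not part of wandb). `api` is expected to be a
    wandb.Api() instance; `artifact_path` is "entity/project/collection".
    Versions that still carry an alias such as "latest" or "best" are kept.
    Returns the names of the versions selected for deletion.
    """
    # Assumes Api.artifacts(type_name, name) is available in your wandb version.
    # Materialize the listing first so deletions don't disturb pagination.
    versions = list(api.artifacts(type_name="model", name=artifact_path))

    selected: List[str] = []
    for version in versions:
        if version.aliases:
            print(f"KEEPING {version.name}")
            continue
        print(f"DELETING {version.name}")
        selected.append(version.name)
        if not dry_run:
            version.delete()
    return selected
```

For a Lightning run, a call might look like `prune_model_versions(wandb.Api(), f"{entity}/{project}/model-{run_id}", dry_run=False)`, where `entity`, `project` and `run_id` are your own values and the `model-<run id>` collection name is an assumption about how WandbLogger names its checkpoint artifacts. Running it first with the default `dry_run=True` shows what would be removed.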

Source for this snippet:

Hi Joseph, thanks for your question! Would the solutions proposed by Matteo and Scott work for you?

Hi there, I wanted to follow up on this request. Please let us know if we can be of further assistance or if your issue has been resolved.

Hi @_scott, thanks for the code. As I mentioned here, this no longer appears to delete the artifacts, even with dry run disabled.
