How do I keep only the last checkpoint artifact in wandb?
I am using Lightning's ModelCheckpoint to periodically save my checkpoints to wandb as artifacts. These artifacts are really large, so keeping multiple checkpoint versions on wandb uses up storage very quickly.
But I can't just checkpoint at the end of training: my GPUs occasionally terminate, so I need to checkpoint periodically.
How do I make sure that only the last checkpoint artifact is kept on wandb?
In this example, the checkpoint is saved at the end of each epoch, but you can set the interval to whatever value you want. If you want to save checkpoints based on steps or time instead, set every_n_train_steps or train_time_interval, respectively (Lightning treats these options as mutually exclusive with every_n_epochs, so set only one of them).
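As a sketch of this kind of configuration, assuming Lightning 2.x (`lightning.pytorch`), the helper below builds ModelCheckpoint arguments with exactly one save trigger: every_n_epochs, every_n_train_steps, or train_time_interval (a `datetime.timedelta`). It defaults to once per epoch. The helper names `checkpoint_kwargs` and `make_callback` and the `"checkpoints"` directory are hypothetical. With no `monitor`, `save_top_k=1` keeps only the most recent checkpoint file on disk; on its own this is not expected to limit the artifact versions the logger uploads to wandb.

```python
from datetime import timedelta


def checkpoint_kwargs(every_n_epochs=None, every_n_train_steps=None,
                      train_time_interval=None):
    """Build ModelCheckpoint keyword arguments with a single save trigger.

    Hypothetical helper: Lightning rejects combinations of the three
    trigger arguments, so this raises if more than one is given.
    """
    triggers = {
        "every_n_epochs": every_n_epochs,
        "every_n_train_steps": every_n_train_steps,
        "train_time_interval": train_time_interval,  # a timedelta
    }
    chosen = {k: v for k, v in triggers.items() if v is not None}
    if len(chosen) > 1:
        raise ValueError(f"set only one checkpoint trigger, got {sorted(chosen)}")
    if not chosen:
        # Fall back to saving once per epoch.
        chosen = {"every_n_epochs": 1}
    # "checkpoints" is a placeholder directory; save_top_k=1 without a
    # monitor keeps only the newest checkpoint file locally.
    return {"dirpath": "checkpoints", "save_top_k": 1, **chosen}


def make_callback(**trigger):
    # Imported lazily so checkpoint_kwargs works without Lightning installed.
    # On older releases the import is pytorch_lightning.callbacks instead.
    from lightning.pytorch.callbacks import ModelCheckpoint

    return ModelCheckpoint(**checkpoint_kwargs(**trigger))
```

For example, `make_callback(train_time_interval=timedelta(minutes=30))` would checkpoint on a wall-clock schedule, which can be handy when GPUs may terminate partway through an epoch.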
If you're looking for more specific information, I highly recommend checking out the official docs:
Also, if you want to delete old artifact versions after training, you can use wandb.Api. For example:
```python
"""
Deletes every model artifact version that has no alias attached.
With Lightning's WandbLogger this usually keeps only the versions
aliased "latest" or "best" and deletes the rest.
Set dry_run = False to actually delete.
"""
import wandb

project_name = "demo-project"
entity = "_scott"  # replace with your own entity
dry_run = True

api = wandb.Api(overrides={"project": project_name, "entity": entity})
project = api.project(project_name)

for artifact_type in project.artifacts_types():
    # Only model artifacts are pruned; datasets, tables, etc. are skipped.
    if artifact_type.type != "model":
        continue
    for artifact_collection in artifact_type.collections():
        for version in api.artifact_versions(artifact_type.type, artifact_collection.name):
            if len(version.aliases) > 0:
                # Aliased versions ("latest", "best", ...) are kept.
                print(f"KEEPING {version.name}")
            else:
                print(f"DELETING {version.name}")
                if not dry_run:
                    version.delete()
```
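The script above keeps every version that carries any alias, so both "latest" and "best" survive. If you only want the most recent checkpoint, one option is to keep versions by a specific alias instead. The sketch below separates the selection logic (plain Python, testable offline) from the wandb calls. `versions_to_delete`, `prune_collection`, and the `Version` stand-in are hypothetical names. The sketch assumes `Artifact.delete(delete_aliases=True)` is available in your wandb version, which is needed because a version aliased only "best" cannot otherwise be deleted.

```python
from collections import namedtuple

# Hypothetical stand-in for a wandb artifact version: only .name and
# .aliases are read, so real versions and test data both work.
Version = namedtuple("Version", ["name", "aliases"])


def versions_to_delete(versions, keep_aliases=("latest",)):
    """Return the versions that carry none of the aliases in keep_aliases."""
    keep = set(keep_aliases)
    return [v for v in versions if not keep.intersection(v.aliases)]


def prune_collection(entity, project, collection, keep_aliases=("latest",),
                     dry_run=True):
    """Delete model versions in one collection that lack every kept alias.

    Hypothetical helper; wandb is imported here so versions_to_delete can
    be used without wandb installed.
    """
    import wandb

    api = wandb.Api()
    # Same call as the script above, with a full entity/project/name path.
    versions = api.artifact_versions("model", f"{entity}/{project}/{collection}")
    doomed = versions_to_delete(versions, keep_aliases)
    for version in doomed:
        print(f"DELETING {version.name}")
        if not dry_run:
            # delete_aliases=True also drops aliases such as "best";
            # without it, deleting an aliased version raises an error.
            version.delete(delete_aliases=True)
    return doomed
```

For example, `prune_collection("my-entity", "demo-project", "model-abc123")` (all three names hypothetical) only prints what it would remove until you pass `dry_run=False`. Because wandb moves the "latest" alias to the newest version of a collection, at least one checkpoint should always survive, which matters if a run can be interrupted.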