Easiest way to load the best model checkpoint after training w/ pytorch lightning

I have a notebook based on Supercharge your Training with PyTorch Lightning + Weights & Biases, and I’m wondering what the easiest approach is to load a model from the best checkpoint after training finishes.

I’m assuming that after training, the “model” instance will just have the weights of the most recent epoch, which might not be the most accurate model (e.g. if it started overfitting).

Specifically, I was looking for an easy way to get the directory where the checkpoint artifacts are stored, which in my case looks like this: ./MnistKaggle/1vzsgin6/checkpoints, where 1vzsgin6 is the run id auto-generated by wandb.

One (clunky) way to do it would be:

wandb_logger = WandbLogger(project="MnistKaggle")
checkpoint_dir_path = None

# after_save_checkpoint is called with the ModelCheckpoint callback instance
def my_after_save_checkpoint(checkpoint_callback):
    global checkpoint_dir_path  # without this, the assignment only creates a local variable
    checkpoint_dir_path = checkpoint_callback.dirpath

wandb_logger.after_save_checkpoint = my_after_save_checkpoint

# Now find the checkpoint file in the checkpoint_dir_path directory and load the model from that.

Is there an easier way? I was sort of expecting the WandbLogger object to have a method like get_save_checkpoint_dirpath(), but I’m not seeing anything.

Thanks in advance for any help!
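For what it’s worth, Lightning’s ModelCheckpoint callback already records the best checkpoint path after training, so a small helper along these lines may be all that’s needed (the helper name is mine, not from the thread; it assumes a ModelCheckpoint callback with a monitor metric is attached to the Trainer):

```python
def load_best_after_fit(trainer, model_cls):
    """After trainer.fit(), reload the best checkpoint tracked by the
    ModelCheckpoint callback. `model_cls` is the LightningModule subclass
    that was trained; `trainer.checkpoint_callback` is Lightning's handle
    to the ModelCheckpoint callback attached to the Trainer."""
    best_path = trainer.checkpoint_callback.best_model_path
    # e.g. ./MnistKaggle/<run_id>/checkpoints/epoch=...-step=....ckpt
    return model_cls.load_from_checkpoint(best_path)
```

If no monitor metric is configured, best_model_path just points at the most recently saved checkpoint, so the monitor argument matters here.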

Hi @tleyden , happy to help. Please review the following resource on model checkpointing and retrieval.

A common flow is to log a model checkpoint as in the example, and then also log a “best model” artifact. Since artifacts are versioned, you don’t have to worry about renaming each new “best model” artifact. At the end of your run you then have not only an artifact history of your model at each checkpoint, but also a versioned history of all the best models.
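That flow might be sketched roughly like this (the artifact name, file path, and helper here are illustrative, not from the thread):

```python
BEST_ALIASES = ["latest", "best"]  # each upload becomes a new version (v0, v1, ...)

def log_best_checkpoint(ckpt_path, project="MnistKaggle"):
    """Log a checkpoint file as a versioned 'model' artifact; re-running this
    with a new file moves the 'best'/'latest' aliases to the new version."""
    import wandb  # imported here so the module stays importable without wandb
    run = wandb.init(project=project)
    artifact = wandb.Artifact("mnist-model", type="model")
    artifact.add_file(ckpt_path)
    run.log_artifact(artifact, aliases=BEST_ALIASES)
    run.finish()
```

Because the aliases travel with each new version, downstream code can always refer to mnist-model:best without tracking version numbers.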

Hi @tleyden since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!

Thanks for the tip about the “latest/best” aliases, I hadn’t seen that. So if I understand correctly, this downloads the model checkpoint locally via the API, which is somewhat redundant since I assume it’s already saved locally, but it provides more control in terms of being able to specify those aliases.
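For completeness, the API download route might look like this (the entity and artifact names are placeholders, and the helpers are mine):

```python
def artifact_ref(entity, project, name, alias="best"):
    """Build the 'entity/project/name:alias' reference the wandb public API expects."""
    return f"{entity}/{project}/{name}:{alias}"

def download_best(entity, project, name="mnist-model"):
    """Fetch the artifact currently tagged ':best' and return its local directory."""
    import wandb  # imported here so artifact_ref stays usable without wandb
    api = wandb.Api()
    artifact = api.artifact(artifact_ref(entity, project, name))
    return artifact.download()
```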