I have a notebook based on Supercharge your Training with PyTorch Lightning + Weights & Biases, and I'm wondering what the easiest approach is to load a model from the best checkpoint after training finishes.
I'm assuming that after training the "model" instance will just have the weights of the most recent epoch, which might not be the most accurate model (e.g. if it started overfitting).
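To make the concern concrete: Lightning's `ModelCheckpoint` callback tracks a monitored metric with `mode="min"` or `mode="max"` and records the best checkpoint separately from whatever weights are in memory when training ends. The stdlib-only sketch below mimics that selection over a hypothetical list of per-epoch validation losses; `best_epoch` is a made-up helper, not a Lightning API.

```python
from typing import Sequence


def best_epoch(values: Sequence[float], mode: str = "min") -> int:
    """Return the index of the best epoch for a monitored metric.

    Hypothetical helper mirroring the idea behind ModelCheckpoint's
    `mode` argument: "min" for losses, "max" for accuracies.
    """
    if not values:
        raise ValueError("no metric values recorded")
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    pick = min if mode == "min" else max
    # Ties resolve to the earliest epoch, because min/max return the first best item.
    return pick(range(len(values)), key=lambda i: values[i])


# Hypothetical validation losses: steady improvement, then overfitting.
val_losses = [0.40, 0.25, 0.18, 0.21, 0.27]
best = best_epoch(val_losses)
last = len(val_losses) - 1
print(f"best epoch: {best}, last epoch: {last}")
```

Here the weights left in memory would come from the last epoch, while the best validation loss was reached two epochs earlier.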
Specifically, I was looking for an easy way to get the directory where the checkpoint artifacts are stored, which in my case looks like this: `./MnistKaggle/1vzsgin6/checkpoints`, where `1vzsgin6` is the run id auto-generated by wandb.
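Going by the path above, the checkpoints end up under `<save_dir>/<project>/<run id>/checkpoints`, with `save_dir` being the current directory here. Assuming that layout holds, the directory can be rebuilt from the project name and run id. With a `WandbLogger`, the run id should also be available as `wandb_logger.version` (worth double-checking against your Lightning version). Both helper names below are hypothetical.

```python
from pathlib import Path
from typing import Optional


def checkpoint_dir(project: str, run_id: str, save_dir: str = ".") -> Path:
    """Rebuild <save_dir>/<project>/<run_id>/checkpoints (layout assumed from the post)."""
    return Path(save_dir) / project / run_id / "checkpoints"


def latest_checkpoint_dir(project: str, save_dir: str = ".") -> Optional[Path]:
    """Return the checkpoints dir of the most recently modified run, or None.

    Handy if the run id was not recorded; assumes the same layout as above.
    """
    project_dir = Path(save_dir) / project
    if not project_dir.is_dir():
        return None
    candidates = [p for p in project_dir.glob("*/checkpoints") if p.is_dir()]
    if not candidates:
        return None
    # A directory's mtime changes when files are added to or removed from it.
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Usage with the values from the post; with a live logger you could pass
# wandb_logger.version as the run id instead (assumption, see above).
print(checkpoint_dir("MnistKaggle", "1vzsgin6"))
```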
One (clunky) way to do it would be:
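```python
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project="MnistKaggle")
checkpoint_dir_path = None
def my_after_save_checkpoint(checkpoint_callback):
    # `global` is needed so this updates the module-level variable, not a local one
    global checkpoint_dir_path
    checkpoint_dir_path = checkpoint_callback.dirpath
# Caveat: this replaces WandbLogger's own hook (which handles log_model uploads)
wandb_logger.after_save_checkpoint = my_after_save_checkpoint
# Now find the checkpoint file in the checkpoint_dir_path directory and load the model from that.
```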
Is there an easier way? I was sorta expecting the `WandbLogger` object to have an easy method like `get_save_checkpoint_dirpath()`, but I'm not seeing anything.
Thanks in advance for any help!