I have a notebook based on Supercharge your Training with PyTorch Lightning + Weights & Biases, and I’m wondering what the easiest approach is to load the model from the best checkpoint after training finishes.
I’m assuming that after training the “model” instance will just have the weights of the most recent epoch, which might not be the most accurate model (in case it started overfitting etc).
Specifically, I was looking for an easy way to get the directory where the checkpoint artifacts are stored, which in my case looks like this:
1vzsgin6 is the run id auto-generated by wandb.
One (clunky) way to do it would be:
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="MnistKaggle")
checkpoint_dir_path = None

def my_after_save_checkpoint(checkpoint):
    # Without `global`, the assignment would only create a local variable.
    global checkpoint_dir_path
    checkpoint_dir_path = checkpoint.dirpath

wandb_logger.after_save_checkpoint = my_after_save_checkpoint
# Now find the checkpoint file in the checkpoint_dir_path directory and load the model from that.
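The last step in that comment (finding the checkpoint file once the directory is known) could look roughly like the sketch below. The helper name `find_latest_checkpoint` is made up for illustration, and it assumes the checkpoint files end in `.ckpt`. It only picks the most recently modified file, which is not necessarily the best-scoring one.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file in checkpoint_dir, or None.

    Hypothetical helper: assumes checkpoints use the .ckpt extension and
    that "latest" (by modification time) is an acceptable stand-in for "best".
    """
    ckpt_files = list(Path(checkpoint_dir).glob("*.ckpt"))
    if not ckpt_files:
        return None
    # Pick the file with the newest modification timestamp.
    return max(ckpt_files, key=lambda p: p.stat().st_mtime)
```

The returned path could then be passed to the model class's `load_from_checkpoint` class method, for example `MyModel.load_from_checkpoint(path)`, where `MyModel` stands in for whatever LightningModule subclass the notebook defines.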
Is there an easier way? I was sort of expecting the WandbLogger object to have an easy method like get_save_checkpoint_dirpath(), but I’m not seeing anything.
Thanks in advance for any help!