I couldn’t find a best practice for saving the model definition to wandb, including the forward method.
Most of my research is done by changing the forward function, so it is an important piece of data I want to track.
I tried calling inspect.getsource() on the model class; however, it doesn't seem to work in IPython.
I am aware that I can save the whole notebook/file, but that also saves a lot of auxiliary information, which makes it hard to compare just the models.
Please let me know how you would approach this issue.
I would recommend keeping your model definition (including the forward method) in an external .py file such as model.py. Once it is in its own file, you can use wandb's log_code function (run.log_code) to log your model definition. This lets you retrieve the model code just as you would any other artifact: you can download() your logged model definition and import it like any other .py file.
You also get the nice side effect of versioning your model definitions as they change, with few changes needed in your downstream scripts!
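The log-and-reload workflow above could look roughly like this. It is a minimal sketch, assuming the model lives in a file called model.py under the directory you pass as root. `run.log_code(root=..., include_fn=...)` is wandb's method on a run object; the helper names (`include_model_file`, `log_model_code`, `load_model_module`) are hypothetical, and the artifact name in the usage comments is a placeholder you would replace with the one shown in your project.

```python
import importlib.util
import os
import sys

# Hypothetical set of files that make up "the model definition".
MODEL_FILES = {"model.py"}


def include_model_file(path, root=None):
    """Filter for run.log_code: keep only the model definition file(s).

    Depending on the wandb version, include_fn is called with the path alone
    or with (path, root); the default of None accepts either convention.
    """
    return os.path.basename(path) in MODEL_FILES


def log_model_code(run, root="."):
    """Log only the model definition as a code artifact on an existing wandb run."""
    return run.log_code(root=root, include_fn=include_model_file)


def load_model_module(path, module_name="logged_model"):
    """Import a model.py from an arbitrary path, e.g. one fetched with download()."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register it like a normal import
    spec.loader.exec_module(module)
    return module


# Usage (assumes wandb is installed and you are logged in):
#   run = wandb.init(project="my-project")
#   log_model_code(run)
# Later, in another script:
#   artifact = wandb.Api().artifact("<entity>/<project>/<code-artifact-name>:latest")
#   model_dir = artifact.download()
#   Model = load_model_module(os.path.join(model_dir, "model.py")).Model
```

Passing the run in as an argument keeps the helpers free of a hard wandb import, so the filtering and loading logic can be checked on its own.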
PyTorch has a built-in solution for this in newer versions (torch.package, added in 1.9 I think), and these model packages can be logged to wandb as an artifact.
You may also want to check out the torch.package page in the PyTorch 1.10.1 documentation.
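A rough sketch of the torch.package route, assuming a recent PyTorch in which `torch` and standard-library modules are implicitly treated as extern. `PackageExporter.intern`, `PackageExporter.save_pickle` and `PackageImporter.load_pickle` are torch.package APIs; the module name `sketch_model`, `MODEL_SRC` and the helper functions are hypothetical. As far as I know, the model class has to live in an importable module rather than `__main__`, so the sketch writes one to a temp directory. The wandb step is left as comments and assumes an active run.

```python
import importlib
import os
import sys
import tempfile

import torch
from torch.package import PackageExporter, PackageImporter

# Hypothetical model module; in practice this would be your own model.py.
MODEL_SRC = '''
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        # The part of the code we want to track alongside the weights.
        return torch.relu(self.linear(x))
'''


def make_model(tmp_dir, module_name="sketch_model"):
    """Write MODEL_SRC to <tmp_dir>/<module_name>.py, import it and build a Model."""
    with open(os.path.join(tmp_dir, f"{module_name}.py"), "w") as f:
        f.write(MODEL_SRC)
    sys.path.insert(0, tmp_dir)
    module = importlib.import_module(module_name)
    return module.Model()


def package_model(model, path, module_pattern):
    """Save the model's source code and weights into one torch.package archive."""
    with PackageExporter(path) as exporter:
        exporter.intern(module_pattern)  # copy our model source into the package
        exporter.save_pickle("model", "model.pkl", model)


def load_packaged_model(path):
    """Load the model using the code stored in the package, not local files."""
    importer = PackageImporter(path)
    return importer.load_pickle("model", "model.pkl")


if __name__ == "__main__":
    tmp = tempfile.mkdtemp()
    model = make_model(tmp)
    pkg_path = os.path.join(tmp, "model_package.pt")
    package_model(model, pkg_path, "sketch_model")
    # Tracking it in wandb (assumes an active `run`):
    #   artifact = wandb.Artifact("model-package", type="model")
    #   artifact.add_file(pkg_path)
    #   run.log_artifact(artifact)
    loaded = load_packaged_model(pkg_path)
    x = torch.randn(3, 4)
    print(torch.allclose(model(x), loaded(x)))
```

Because the interned source of the model module is stored inside the archive, the forward method that produced the weights travels with them. The archive is a zip file, so the stored source should be extractable and easy to diff between logged versions.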