After calling wandb.watch while training, I saved my model object by pickling it with torch.save(). Then, in a new script, I reloaded the model object and ran a similar training script to continue training for more epochs. I get an error that says:
ValueError: You can only call `wandb.watch` once per model. Pass a new instance of the model if you need to call wandb.watch again in your code.
I can’t figure out which attribute of the model I should delete so that wandb no longer sees it as a model that has already been watched. Or perhaps it’s more that, on the server side, wandb recognizes this object as the same object as before and already has its hash or something?
Either way, I would like advice on how to either (a) have a conditional block like
if not [somehow check if the model has been watched]:
    wandb.watch(model)
or a way to “clean” the model of its wandb.watch history, like
del model._hidden_property_created_by_wandb
Thanks for any advice
Sam
Hey @samlapp, in your wandb.watch() call, are you specifying any arguments? It is intended behavior to only call wandb.watch once per model.
I recommend calling wandb.unwatch() on the model before pickling it. This would remove the hooks that are being saved. You can do this with a line similar to the following:
wandb.unwatch(model)
Please let me know if this helps!
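To make the ordering explicit, here is a minimal sketch of "unwatch first, then save". The helper name unwatch_then_save and its save_fn parameter are made up for this illustration and are not part of wandb. The only thing assumed of `run` is an unwatch(model) method, which both the wandb module and the object returned by wandb.init() provide. Whether unwatching alone is enough may depend on how watch was configured.

```python
import pickle


def _pickle_to_path(obj, path):
    """Default saver for the sketch: plain pickle to a file path."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def unwatch_then_save(run, model, path, save_fn=_pickle_to_path):
    """Remove wandb's hooks from `model`, then serialize it.

    `run` is anything exposing unwatch(model), e.g. the wandb module
    or the run object returned by wandb.init().
    `save_fn(model, path)` does the actual saving; torch.save has a
    compatible (obj, f) signature and can be passed in directly.
    This helper is hypothetical, not a wandb API.
    """
    run.unwatch(model)      # detach hooks before they can be pickled
    save_fn(model, path)    # only now write the model to disk
```

With PyTorch this would be called as unwatch_then_save(wandb, model, "model.pt", save_fn=torch.save), where the file name is just a placeholder.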
Hi @samlapp,
We wanted to follow up with you regarding your support request as we have not heard back from you. Please let us know if we can be of further assistance or if your issue has been resolved.
Best,
Uma
Hi samlapp, since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!
Apologies for the delay. The .unwatch() solution hasn’t solved the issue for me. For instance, if I call model.train(...), then wandb.unwatch(model), then model.train(...) again, I still get the same error:
ValueError: You can only call `wandb.watch` once per model. Pass a new instance of the model if you need to call wandb.watch again in your code.
Here’s how I’m calling it:
wandb_session = wandb.init(...)
wandb_session.watch(
    model, log="all", log_freq=log_freq, log_graph=True
)
and
wandb_session.unwatch(model)
I think the issue is specific to using log_graph=True. When I set it to False, I no longer have the error.
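For anyone who needs log_graph=True on a reloaded model, here is a rough, unverified sketch of the "clean the model" idea from the original question. It rests on an assumption: that the graph hook marks each module it visits with a private flag named _wandb_watch_called and raises the ValueError when it finds that flag already set. That name is an implementation detail that may differ or disappear between wandb releases, so check it against your installed version. The helper name rewatch_reloaded_model is made up for this example.

```python
def rewatch_reloaded_model(run, model, **watch_kwargs):
    """Reset the (assumed) wandb watch flag, then call run.watch again.

    ASSUMPTION: wandb's graph hook sets a private attribute
    `_wandb_watch_called` on each module it has seen. This is not a
    documented API; verify it for your wandb version before relying on it.
    `run` is anything exposing watch(model, **kwargs), e.g. the object
    returned by wandb.init().
    """
    # torch.nn.Module.modules() yields the model and all of its submodules;
    # fall back to the object itself for wrappers that lack that method.
    if callable(getattr(model, "modules", None)):
        modules = list(model.modules())
    else:
        modules = [model]

    for module in modules:
        if getattr(module, "_wandb_watch_called", False):
            module._wandb_watch_called = False  # pretend it was never watched

    # Forward every keyword (log, log_freq, log_graph, ...) unchanged and
    # hand back whatever run.watch returns.
    return run.watch(model, **watch_kwargs)
```

Usage would look like rewatch_reloaded_model(wandb_session, model, log="all", log_freq=log_freq, log_graph=True). If the flag turns out not to exist in your version, the loop does nothing and the call behaves exactly like wandb_session.watch.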