I’m using wandb logging together with PyTorch and PyTorch Lightning, and something in the wandb code seems to prevent me from compiling my model to TorchScript with the JIT. Here’s the error:
```
torch.jit.frontend.UnsupportedNodeError: Set aren't supported:
  File "/home/peter/catkin_ws/src/venv/lib/python3.8/site-packages/wandb/wandb_torch.py", line 355
        # hook has been processed
        self._graph_hooks -= {id(module)}
                             ~ <--- HERE
        if not self._graph_hooks:
```
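For context, here is a plain-PyTorch sketch of what seems to be going on. It assumes a PyTorch version where `torch.jit.script` also compiles the forward hooks registered on a module, which would explain why wandb's hook code shows up in the traceback. `set_literal_hook` and `try_script` are hypothetical names: the hook only mimics the set literal from line 355, and a single `nn.Linear` stands in for `LitMLP`.

```python
import torch
from torch import nn


def set_literal_hook(module, inputs, output):
    # Hypothetical stand-in for wandb's graph hook: like line 355 of
    # wandb_torch.py, its body builds a set literal, which the
    # TorchScript frontend cannot parse.
    processed = {id(module)}


def try_script(model):
    """Return (succeeded, first line of the error) for torch.jit.script(model)."""
    try:
        torch.jit.script(model)
        return True, ""
    except Exception as exc:
        lines = str(exc).splitlines() or [""]
        return False, lines[0]


# A single Linear layer stands in for LitMLP.
model = nn.Linear(10, 10)
handle = model.register_forward_hook(set_literal_hook)
ok_with_hook, error = try_script(model)
print(ok_with_hook, error)

# Once the hook is removed, only forward() is left to compile.
handle.remove()
ok_without_hook, _ = try_script(model)
print(ok_without_hook)
```

With the hook attached, scripting should fail on the set literal. After `handle.remove()`, the same module scripts cleanly.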
Here’s a minimal reproduction:

```python
#!/usr/bin/env python
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger


class LitMLP(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer_1(x)


wandb_logger = WandbLogger(project="torchscript_debugging")
trainer = pl.Trainer(logger=wandb_logger)
model = LitMLP()

# uncomment the line below and the model will fail to compile
# wandb_logger.watch(model)

script = model.to_torchscript()
print(script(torch.ones(10)))
```
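One possible workaround is to export a separate, hook-free copy. This is a sketch and has not been tested against wandb itself: load the trained weights into a fresh instance and script that one, because `state_dict()` carries only tensors and not hooks. As before, `nn.Linear` stands in for `LitMLP`, and `set_literal_hook` is a hypothetical stand-in for whatever `watch` attaches.

```python
import torch
from torch import nn


def set_literal_hook(module, inputs, output):
    # Hypothetical stand-in for the hook watch() attaches; its set
    # literal is what TorchScript cannot compile.
    processed = {id(module)}


# The "watched" model: used as usual, but it carries the hook.
watched = nn.Linear(10, 10)
watched.register_forward_hook(set_literal_hook)

# Workaround sketch: copy the weights into a fresh, hook-free instance
# and script that one instead. state_dict() holds tensors only, no hooks.
export_copy = nn.Linear(10, 10)
export_copy.load_state_dict(watched.state_dict())
scripted = torch.jit.script(export_copy)

x = torch.ones(10)
print(torch.allclose(watched(x), scripted(x)))
```

In the reproduction above, this would roughly mean `clean = LitMLP()`, then `clean.load_state_dict(model.state_dict())`, then `clean.to_torchscript()`.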
Thanks for the reproduction! This is a known issue, and an internal ticket is already tracking this bug. I’m going to bump its priority for you and will let you know as soon as there’s news on its progress.