Wandb not compatible with TorchScript

I’m using wandb logging together with PyTorch and PyTorch Lightning, and it seems that some of the code in wandb prevents me from compiling my model to TorchScript with the JIT. Here’s the error:

torch.jit.frontend.UnsupportedNodeError: Set aren't supported:
  File "/home/peter/catkin_ws/src/venv/lib/python3.8/site-packages/wandb/wandb_torch.py", line 355
        
            # hook has been processed
            self._graph_hooks -= {id(module)}
                                 ~ <--- HERE
        
            if not self._graph_hooks:

Is there a workaround for this?
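The traceback points at a set literal ({id(module)}), and TorchScript's frontend does not compile Python sets. The wandb code is presumably being compiled at all because watch attaches a forward hook to the module, and torch.jit.script compiles a module's forward hooks along with its forward method. A minimal sketch of that mechanism without wandb or Lightning, where set_hook is a hypothetical stand-in for the wandb hook:

```python
import torch
from torch import nn
from torch.jit.frontend import UnsupportedNodeError


def set_hook(module, inputs, output):
    # Same construct as the wandb line in the traceback: a set literal.
    ids = {id(module)}


model = nn.Linear(10, 10)
# Hypothetical stand-in for the forward hook that watch() attaches.
model.register_forward_hook(set_hook)

try:
    torch.jit.script(model)  # scripting also compiles forward hooks, so this fails
    scripted_ok = True
except UnsupportedNodeError as err:
    scripted_ok = False
    print("TorchScript rejected the hook:", err)
```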

Hey @petermitrano,

I’m sorry you are facing this issue. Would it be possible for you to share a minimal reproduction? It will help us debug this for you.

Thanks,
Ramit

This demonstrates the problem:

#!/usr/bin/env python

import torch
from torch import nn

import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger

class LitMLP(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer_1(x)


wandb_logger = WandbLogger(project="torchscript_debugging")

trainer = pl.Trainer(logger=wandb_logger)

model = LitMLP()

# uncomment the line below and the model will fail to compile
# wandb_logger.watch(model)

script = model.to_torchscript()

print(script(torch.ones(10)))
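Until the wandb code changes, one ordering that should sidestep the error is to script the model before any hooks are attached. A minimal sketch with plain PyTorch, assuming the failure comes only from forward hooks added by watch (set_hook is again a hypothetical stand-in for them):

```python
import torch
from torch import nn


def set_hook(module, inputs, output):
    ids = {id(module)}  # not scriptable, like the hook in the traceback


model = nn.Linear(10, 10)

# 1. Script first, while the module has no forward hooks attached.
scripted = torch.jit.script(model)

# 2. Attach the (hypothetical) logging hook afterwards; it only runs on eager calls.
model.register_forward_hook(set_hook)

x = torch.ones(10)
out = scripted(x)
print(out.shape)
```

In the Lightning reproduction above, the analogous change would be to call model.to_torchscript() before wandb_logger.watch(model). The scripted module is a separate object, so hooks registered on the eager model afterwards are not compiled into it.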

Hi @petermitrano,

Thanks for the reproduction! This is a known issue, and there is already an internal ticket tracking this bug. I’m going to bump its priority for you and will let you know as soon as there is progress.

Thanks,
Ramit
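If the model has to be scripted after watch has already run (for example, to export trained weights), another option is to detach the forward hooks only while scripting. This sketch relies on the private _forward_hooks attribute of nn.Module, which is a PyTorch implementation detail rather than a public API; forward_hooks_detached and set_hook are hypothetical names. If forward pre-hooks are registered as well, _forward_pre_hooks would need the same treatment, since scripting compiles those too.

```python
import contextlib

import torch
from torch import nn


def set_hook(module, inputs, output):
    ids = {id(module)}  # not scriptable, like the hook in the traceback


@contextlib.contextmanager
def forward_hooks_detached(model: nn.Module):
    """Temporarily empty the private _forward_hooks dict of every submodule."""
    saved = {}
    for m in model.modules():
        saved[m] = dict(m._forward_hooks)
        m._forward_hooks.clear()
    try:
        yield model
    finally:
        # Put the hooks back (same ids, same order) even if scripting raised.
        for m, hooks in saved.items():
            m._forward_hooks.update(hooks)


model = nn.Sequential(nn.Linear(10, 10))
model[0].register_forward_hook(set_hook)  # hypothetical stand-in for watch()

with forward_hooks_detached(model):
    scripted = torch.jit.script(model)

x = torch.ones(10)
print(scripted(x).shape)
```

Because the hooks are restored into the same dictionaries, the handles returned by register_forward_hook can still remove them later.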
