Wandb.watch with PyTorch Lightning not logging


I moved all of my wandb calls out of the training loop and into a PyTorch Lightning (PL) Callback. All of my wandb.log() calls work properly, but the gradients and parameters tabs in my wandb dashboard are empty. I checked two threads:

  • Wandb.watch with pytorch not logging anything
  • When is one supposed to run wandb.watch so that weights and biases tracks params and gradients?

For the first thread, the link to the run has expired, and I don’t fully understand the context of the solution "… was using forward() instead of __call__()".
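My best guess at what that answer refers to, shown as a PyTorch-only sketch: forward hooks registered on an nn.Module fire only when the module is called as model(x), which goes through nn.Module.__call__; calling model.forward(x) directly skips them. I am assuming wandb.watch relies on module hooks of this kind, so a model invoked through forward() would be partly invisible to it. The nn.Linear model here is a hypothetical stand-in, not code from my project.

```python
import torch
from torch import nn

# Hypothetical stand-in model, not from the original project.
model = nn.Linear(4, 2)

calls = []

def record_call(module, inputs, output):
    # A forward hook: runs only when the module is invoked through __call__.
    calls.append(type(module).__name__)

handle = model.register_forward_hook(record_call)

x = torch.randn(3, 4)
model(x)          # goes through nn.Module.__call__, so the hook fires
model.forward(x)  # bypasses __call__, so the hook does NOT fire

handle.remove()
print(calls)  # ['Linear']
```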

For the second thread, wandb.log is called from PL’s Callback hook on_train_batch_end, which runs after the training step (including its backward pass), so wandb.log should already be called after a backward pass.
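To check the ordering argument, here is a plain-PyTorch sketch that mimics, in simplified form, the order I assume Lightning follows with automatic optimization: forward, backward, optimizer step, then a batch-end hook. The function on_train_batch_end below is a hypothetical plain function named after the PL hook, and the logged list stands in for wandb.log. Because it runs after backward(), every parameter already has a .grad to inspect.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 1)  # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

logged = []  # stands in for calls to a logger such as wandb.log

def on_train_batch_end(module: nn.Module) -> None:
    # Hypothetical batch-end hook: backward() has already run,
    # so each parameter's .grad is populated.
    grad_norms = {
        f"grad_norm/{name}": p.grad.norm().item()
        for name, p in module.named_parameters()
        if p.grad is not None
    }
    logged.append(grad_norms)

for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)  # forward via __call__
    loss.backward()                              # gradients reach the parameters
    optimizer.step()
    on_train_batch_end(model)                    # batch-end hook runs last

print(logged[-1])
```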

Below is the relevant part of the Callback. At the start of training (on_fit_start), I initialize the wandb run and call wandb.watch; after each training batch (on_train_batch_end), I log all the metrics.

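import wandb
from typing import Any, Mapping, Sequence
from pytorch_lightning import Callback, LightningModule, Trainer

class PatentLoggerCallback(Callback):
    # Omitted non-relevant code

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Start the wandb run (arguments omitted) and hook into the model.
        # log="all" is assumed here: it asks wandb to track gradients and parameters.
        wandb.init()
        wandb.watch(pl_module, log="all")

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Mapping[str, Any],
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int = 0,  # default: newer PL versions no longer pass it
    ) -> None:
        # outputs is what training_step returned: a dict with a 'metrics' entry.
        metrics = outputs['metrics']
        # Log all metrics in one call so they share the same wandb step.
        wandb.log({f'train/{metric}': value for metric, value in metrics.items()})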

I would have liked to provide a Google Colab notebook for reproducibility, but there is a lot of code involved. The next best thing I can offer is this Jupyter Notebook, which runs through my entire code. I’m not expecting you to clone the repository, but if you do, make sure you’re on the “nakama” branch.


In my model’s forward pass, I had written return torch.rand(input_ids.shape[0], requires_grad=True) so I could debug the other wandb.log calls faster by skipping the expensive computation. Because that output didn’t depend on any of the model’s parameters, no gradients flowed back into the watched model and none of its parameters were updated. The only tensor the backward pass reached was the one created by torch.rand(), hence no histograms.
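A small PyTorch-only sketch of that failure mode, using a hypothetical nn.Linear stand-in rather than the real model. Gradient hooks registered on the parameters (which I assume is roughly how a gradient-watching tool observes them) never fire when the loss comes from a torch.rand() tensor, and fire only once the output actually goes through the model.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # hypothetical stand-in for the watched model
x = torch.randn(8, 4)

seen = []
for p in model.parameters():
    # Tensor hooks fire only when a gradient for that parameter is computed.
    p.register_hook(lambda grad: seen.append(grad.shape))

# Debug shortcut: an output that ignores the model entirely.
fake_out = torch.rand(x.shape[0], requires_grad=True)
fake_out.sum().backward()
hooks_after_fake = len(seen)
no_grad_params = all(p.grad is None for p in model.parameters())

# Real forward pass through the model: gradients now reach its parameters.
real_out = model(x).squeeze(-1)
real_out.sum().backward()

print(hooks_after_fake, len(seen))  # 0 2
```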
