How to log custom criterion function?

We can use wandb.watch(model, criterion, ...) to log a model and its loss function.
But my loss function is not something simple like criterion = nn.CrossEntropyLoss().

Rather, here’s how I calculate my loss:

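    # Inside the training loop
    # `set_to_none=True` boosts performance
    optimizer.zero_grad(set_to_none=True)
    masks_pred = model(imgs)  # raw logits, shape (N, n_classes, H, W)

    # Dice loss works on per-class probabilities and one-hot targets
    probs = F.softmax(masks_pred, dim=1).float()
    ground_truth = F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float()

    # Total loss = criterion term + Dice term
    loss = criterion(masks_pred, masks) + dice_loss(probs, ground_truth)
    loss.backward()
    optimizer.step()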

As you can see, the loss is the sum of two terms: criterion and dice_loss.
What should I pass to wandb.watch for the criterion argument?

Hi @vroomerify,

Thanks for reaching out. wandb.watch expects a PyTorch loss function for the criterion argument. To use a custom, composite loss, you can wrap it in a class that subclasses torch.nn.Module and pass an instance of that class.
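As a minimal sketch of that suggestion (not code from the thread), both terms can be wrapped in a single module. `CEDiceLoss` and `soft_dice_loss` are hypothetical names. `soft_dice_loss` is a common soft Dice formulation standing in for the original `dice_loss`, whose definition was not shared. The cross-entropy term assumes `criterion` was `nn.CrossEntropyLoss()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical stand-in for the thread's dice_loss.

    probs and target have shape (N, C, H, W). Returns 1 minus the soft Dice
    coefficient, averaged over batch and classes.
    """
    dims = (2, 3)  # sum over the spatial dimensions
    intersection = (probs * target).sum(dim=dims)
    cardinality = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()


class CEDiceLoss(nn.Module):
    """Cross-entropy plus Dice loss, exposed as one nn.Module criterion."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.n_classes = n_classes
        self.ce = nn.CrossEntropyLoss()  # assumed original criterion

    def forward(self, logits: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        # logits: (N, C, H, W) raw scores; masks: (N, H, W) int64 class ids
        probs = F.softmax(logits, dim=1).float()
        ground_truth = F.one_hot(masks, self.n_classes).permute(0, 3, 1, 2).float()
        return self.ce(logits, masks) + soft_dice_loss(probs, ground_truth)
```

In the training loop you would build `criterion = CEDiceLoss(model.n_classes)` once, pass it as `wandb.watch(model, criterion)`, and replace the two-term expression with `loss = criterion(masks_pred, masks)`. Keeping the softmax and one-hot steps inside `forward` means the loop no longer needs to compute `probs` and `ground_truth` itself.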

Thanks,
Ramit

Hi Vedant,

We wanted to follow up with you regarding your support request as we have not heard back from you. Please let us know if we can be of further assistance or if your issue has been resolved.

Best,
Weights & Biases

Hi Vedant, since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!