We can use `wandb.watch(model, criterion, ...)` to log a model and a loss function.
But my loss function is not something as simple as `criterion = nn.CrossEntropyLoss()`.
Rather, here’s how I calculate my loss:
```python
import torch.nn.functional as F

# `set_to_none=True` boosts performance
optimizer.zero_grad(set_to_none=True)
masks_pred = model(imgs)  # logits, shape (N, n_classes, H, W)
probs = F.softmax(masks_pred, dim=1).float()
# one-hot masks: (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
ground_truth = F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float()
loss = criterion(masks_pred, masks) + dice_loss(probs, ground_truth)
loss.backward()
optimizer.step()
```
As you can see, the loss is the sum of two terms: `criterion` and the `dice_loss` function.
What should I pass to `wandb.watch` for the `criterion` argument?