How to log custom criterion function?

We can use …, criterion, ...) to log a model together with its loss function.
But my loss is not a single built-in criterion such as criterion = nn.CrossEntropyLoss().

Rather, here’s how I calculate my loss:

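    # `set_to_none=True` boosts performance
    optimizer.zero_grad(set_to_none=True)
    masks_pred = model(imgs)

    # Per-pixel class probabilities, shape (N, C, H, W)
    probs = F.softmax(masks_pred, dim=1).float()
    # One-hot masks: (N, H, W) -> (N, H, W, C) -> (N, C, H, W) to match probs
    ground_truth = F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float()

    # Total loss = criterion term + Dice term
    loss = criterion(masks_pred, masks) + dice_loss(probs, ground_truth)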

As you can see, the loss is the sum of two terms: the criterion term and the dice_loss term.
What should I pass to … for the criterion argument?

Hi @vroomerify,

Thanks for reaching out. … expects a PyTorch loss function as the criterion parameter. You can turn a composite loss like yours into a single criterion by subclassing torch.nn.Module and computing both terms in its forward method.
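A minimal sketch of that idea, assuming the inputs from the question: logits of shape (N, C, H, W) and integer masks of shape (N, H, W). The class name CEDiceLoss and its dice_loss method are hypothetical, and the soft Dice term is a generic multiclass formulation standing in for the poster's own dice_loss, which is not shown. Cross-entropy is assumed as the first term, matching the nn.CrossEntropyLoss example above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEDiceLoss(nn.Module):
    """Cross-entropy plus soft Dice loss, packaged as one criterion.

    Hypothetical stand-in: the Dice term is a generic multiclass soft Dice,
    not necessarily identical to the poster's `dice_loss`.
    """

    def __init__(self, n_classes: int, eps: float = 1e-6):
        super().__init__()
        self.n_classes = n_classes
        self.eps = eps
        self.ce = nn.CrossEntropyLoss()

    def dice_loss(self, probs: torch.Tensor, target_1h: torch.Tensor) -> torch.Tensor:
        # Sum over batch and spatial dims, keeping the class dim: shape (C,)
        dims = (0, 2, 3)
        inter = (probs * target_1h).sum(dims)
        union = probs.sum(dims) + target_1h.sum(dims)
        dice = (2 * inter + self.eps) / (union + self.eps)
        # Average the per-class Dice score; the loss is 1 - Dice
        return 1 - dice.mean()

    def forward(self, logits: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        # logits: (N, C, H, W) raw scores; masks: (N, H, W) int64 class ids
        probs = F.softmax(logits, dim=1).float()
        target_1h = F.one_hot(masks, self.n_classes).permute(0, 3, 1, 2).float()
        return self.ce(logits, masks) + self.dice_loss(probs, target_1h)


torch.manual_seed(0)
criterion = CEDiceLoss(n_classes=3)
logits = torch.randn(2, 3, 8, 8, requires_grad=True)
masks = torch.randint(0, 3, (2, 8, 8))

loss = criterion(logits, masks)
loss.backward()  # gradients flow through both terms
print(loss.item())
```

In the training loop, the two-term expression then collapses to loss = criterion(masks_pred, masks), and the same criterion object is what you would pass as the criterion argument. Keeping both terms inside one nn.Module means any code that receives the criterion sees the full loss rather than only the cross-entropy part.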


