How can I log the best value of a metric/loss in the wandb summary using PyTorch Lightning?

The wandb documentation has examples like this: wandb.run.summary["best_accuracy"] = best_accuracy
I'm not sure how to do something like that from training_step/validation_step using self.log, since the epoch-level metric values are calculated automatically under the hood of the API.
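For reference, this is roughly what that docs example looks like in a plain wandb training loop outside Lightning (a minimal sketch; train_one_epoch and the project name are placeholders):

import wandb

run = wandb.init(project='my-project')  # placeholder project name

best_accuracy = 0.0
for epoch in range(10):
    accuracy = train_one_epoch()  # hypothetical training function
    run.log({'accuracy': accuracy})
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # run.summary holds a single value per key, so overwriting it here
        # makes the run page show the best accuracy rather than the last one.
        run.summary['best_accuracy'] = best_accuracy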

Hi,
Can you give a bit more info about what you're trying to do? Have you tried doing the above in your own LightningModule? You could also try adding a Callback which has access to the logger, as in the sketch below.
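Something along these lines (a rough, untested sketch; BestMetricCallback, the default metric name, and the best_ summary key are made up for illustration, and it assumes WandbLogger so that trainer.logger.experiment is the underlying wandb run):

import math
from pytorch_lightning.callbacks import Callback

class BestMetricCallback(Callback):
    """Mirror the best value of a logged metric into the wandb run summary."""

    def __init__(self, metric_name='val_mae', mode='min'):
        self.metric_name = metric_name
        self.mode = mode
        self.best = math.inf if mode == 'min' else -math.inf

    def on_validation_epoch_end(self, trainer, pl_module):
        # callback_metrics holds everything logged via self.log.
        value = trainer.callback_metrics.get(self.metric_name)
        if value is None:
            return
        value = float(value)
        improved = value < self.best if self.mode == 'min' else value > self.best
        if improved:
            self.best = value
            # With WandbLogger, trainer.logger.experiment is the wandb run.
            trainer.logger.experiment.summary[f'best_{self.metric_name}'] = value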

I have the following code:

def training_step(self, batch, *args, **kwargs):  # type: ignore
    pred = self(batch).squeeze(-1)
    # Per-step loss; on_epoch=True also logs the epoch-level aggregate.
    loss = self.loss(pred, batch['p'], batch['u_out']).mean()
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    # Same pattern for the metric (MAE).
    score = self.metric(pred, batch['p'], batch['u_out']).mean()
    self.log('train_mae', score, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

As you can see, I only compute the metric and the loss per step in training_step (and validation_step); the epoch-level values are aggregated and logged automatically by PyTorch Lightning and wandb.

I suppose I could return the loss and the metric from training_step, aggregate the epoch-level values manually in training_epoch_end, and then log the best value there, along the lines of the sketch below.
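Something like this, roughly (untested; self.best_mae is a made-up attribute that would need to be initialized to float('inf') in __init__, and it assumes WandbLogger so that self.logger.experiment is the wandb run):

import torch

def training_step(self, batch, *args, **kwargs):
    pred = self(batch).squeeze(-1)
    loss = self.loss(pred, batch['p'], batch['u_out']).mean()
    score = self.metric(pred, batch['p'], batch['u_out']).mean()
    # Return the metric alongside the loss so it reaches training_epoch_end.
    return {'loss': loss, 'mae': score}

def training_epoch_end(self, outputs):
    # Aggregate the per-step MAE values into an epoch-level mean.
    epoch_mae = torch.stack([o['mae'] for o in outputs]).mean().item()
    if epoch_mae < self.best_mae:  # lower MAE is better
        self.best_mae = epoch_mae
        # Write the best value straight into the wandb run summary.
        self.logger.experiment.summary['best_train_mae'] = epoch_mae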

But I hope there is a better solution.

One thing you can do is use wandb.run.define_metric("train_mae", summary="min") (MAE is lower-is-better, so "min" keeps the best value; for a metric like accuracy you would use "max").

This will automatically store both the latest and the best value in the summary, which you can then filter and sort on.
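In a Lightning setup this can be wired up through the logger, for example (a sketch; the project name is a placeholder):

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project='my-project')  # placeholder project name

# Register the rule once, before training starts; from then on the run
# summary tracks the minimum logged value of 'train_mae'.
wandb_logger.experiment.define_metric('train_mae', summary='min')

trainer = pl.Trainer(logger=wandb_logger)

Note that with on_step=True and on_epoch=True, Lightning may log the metric under suffixed names (train_mae_step / train_mae_epoch), in which case you would call define_metric on those names, or use a glob pattern such as 'train_mae*'.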