How can I log the best value of a metric/loss to the wandb summary using PyTorch Lightning?

The wandb documentation has examples like this: wandb.run.summary["best_accuracy"] = best_accuracy
I'm not sure how to do something similar in training_step/validation_step with self.log, given that the epoch-level metric values are computed automatically under the hood of the API.
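One way to mirror that documentation pattern in Lightning is to keep the best epoch-level value yourself and write it to the run summary only when it improves. The sketch below keeps that bookkeeping in plain Python and writes into any dict-like `summary`, so it runs without wandb. In a Lightning module the summary would presumably be `self.logger.experiment.summary` (with `WandbLogger`, `experiment` is the underlying wandb Run), and the value would be read from `trainer.callback_metrics` in an epoch-end hook. The class name `BestMetricTracker` and the `best_` key prefix are hypothetical.

```python
import math
from typing import MutableMapping


class BestMetricTracker:
    """Keep the best epoch-level value of one metric and copy it into a summary.

    Hypothetical helper: `summary` can be any dict-like object. With wandb it would
    presumably be the run's summary (e.g. logger.experiment.summary in Lightning).
    """

    def __init__(self, name: str, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.name = name
        self.mode = mode
        # Start from the worst possible value so the first update always counts as best.
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, value: float, summary: MutableMapping[str, float]) -> bool:
        """Record one epoch-level value; return True if it is a new best."""
        value = float(value)  # also accepts 0-dim tensors, which support float()
        if self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best
        if improved:
            self.best = value
            summary[f"best_{self.name}"] = value
        return improved


# Usage with a plain dict standing in for wandb.run.summary.
summary: dict = {}
tracker = BestMetricTracker("val_mae", mode="min")
for epoch_mae in [0.9, 0.7, 0.8, 0.6]:
    tracker.update(epoch_mae, summary)
print(summary)  # {'best_val_mae': 0.6}
```

Use `mode="min"` for error metrics such as MAE and `mode="max"` for metrics such as accuracy.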

Hi,
Can you give a bit more info about what you're trying to do? Have you tried doing the above in your own LightningModule? You could also try adding a Callback, which has access to the logger.


I have the following code:

def training_step(self, batch, *args, **kwargs):  # type: ignore
    pred = self(batch).squeeze(-1)
    loss = self.loss(pred, batch['p'], batch['u_out']).mean()
    # With on_step=True and on_epoch=True, Lightning logs this as train_loss_step and train_loss_epoch.
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    score = self.metric(pred, batch['p'], batch['u_out']).mean()
    self.log('train_mae', score, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

As you can see, I calculate the metric and the loss only in the training step (and the validation step); the epoch-level values are aggregated and logged automatically by PyTorch Lightning and wandb.

I suppose I could return the loss and the metric from training_step, compute the epoch-level values manually in training_epoch_end, and then log the best value.

But I hope there is a better solution.
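The manual route described above can be sketched as a small accumulator. It is a hypothetical helper, not a Lightning API: `add` would be called from `training_step` with a detached value such as `score.detach().item()` plus the batch size, and `end_epoch` from the epoch-end hook (`training_epoch_end` in Lightning 1.x; that hook was removed in Lightning 2.0, where `on_train_epoch_end` is the usual place). It weights each step by batch size, which is roughly how Lightning reduces `on_epoch=True` values by default, and tracks the lowest epoch mean, since lower is better for MAE.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class EpochAccumulator:
    """Collect per-step values during an epoch and reduce them at epoch end.

    Hypothetical helper: call `add` once per step and `end_epoch` once per epoch.
    """

    steps: List[Tuple[float, int]] = field(default_factory=list)
    best: Optional[float] = None  # lowest epoch mean seen so far

    def add(self, value: float, batch_size: int) -> None:
        # Store a plain float so no tensor (or autograd graph) is kept between steps.
        self.steps.append((float(value), batch_size))

    def end_epoch(self) -> Tuple[float, float]:
        """Return (epoch_mean, best_so_far) and clear the step buffer."""
        if not self.steps:
            raise ValueError("no steps recorded this epoch")
        total = sum(n for _, n in self.steps)
        # Batch-size-weighted mean of the per-step values.
        epoch_mean = sum(v * n for v, n in self.steps) / total
        self.steps.clear()
        if self.best is None or epoch_mean < self.best:
            self.best = epoch_mean
        return epoch_mean, self.best


acc = EpochAccumulator()
# Epoch 1: two batches of sizes 4 and 2.
acc.add(1.0, 4)
acc.add(4.0, 2)
print(acc.end_epoch())  # (2.0, 2.0)
# Epoch 2: one batch of size 6; the best value stays at 2.0.
acc.add(3.0, 6)
print(acc.end_epoch())  # (3.0, 2.0)
```

The returned best value could then be written to the summary, for example with `self.logger.experiment.summary["best_train_mae"] = best` when using `WandbLogger` (the key name is only an illustration).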


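One thing you can do is use wandb.run.define_metric("train_mae", summary="min"). Use "min" for an error metric such as MAE, where lower is better, and "max" for metrics such as accuracy.

This automatically stores the aggregated value in the run summary, so you can filter and sort runs by it. Note that when a metric is logged with both on_step=True and on_epoch=True, Lightning logs it under the keys train_mae_step and train_mae_epoch, so with the training_step above the epoch-level key to pass is train_mae_epoch.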


It turns out that the wandb logger doesn't provide a way to define the metric or its summary. In the end I used a simple hack:

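# self.mae must be initialised first, e.g. self.mae = float('inf') in __init__.
# .item() stores a plain float, so the `score` tensor and its autograd history are not held across steps.
self.mae = min(self.mae, score.item())
self.log('mae', self.mae, on_step=False, on_epoch=True, prog_bar=True, logger=True)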
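A caveat worth checking with this hack: because it runs inside training_step, `self.mae` is a running minimum over individual batches, and `on_epoch=True` then averages those per-step values. The sketch below emulates that in plain Python, assuming equal batch sizes and Lightning's default mean reduction, and compares it with the best-so-far epoch-level MAE. The function names and sample numbers are hypothetical.

```python
from statistics import fmean
from typing import List, Sequence


def hack_logged_values(epochs: Sequence[Sequence[float]]) -> List[float]:
    """Emulate the hack: a running min over *batches*, logged with on_epoch=True,
    so each epoch reports the mean of that running min."""
    running_min = float("inf")
    logged = []
    for batch_scores in epochs:
        per_step = []
        for score in batch_scores:
            running_min = min(running_min, score)
            per_step.append(running_min)
        logged.append(fmean(per_step))  # epoch reduction: mean of the per-step values
    return logged


def best_epoch_values(epochs: Sequence[Sequence[float]]) -> List[float]:
    """Best-so-far of the epoch-mean MAE, which is what 'best MAE' usually means."""
    best = float("inf")
    out = []
    for batch_scores in epochs:
        best = min(best, fmean(batch_scores))
        out.append(best)
    return out


# One unusually good batch in epoch 1 dominates the hack's value in later epochs.
epochs = [[1.0, 0.25, 1.0], [0.5, 0.5, 0.5]]
print(hack_logged_values(epochs))  # [0.5, 0.25]
print(best_epoch_values(epochs))   # [0.75, 0.5]
```

Here the hack reports 0.25 for epoch 2 although no epoch had a mean MAE below 0.5. With `WandbLogger`, `self.logger.experiment` is the wandb Run object, so calling `self.logger.experiment.define_metric("train_mae_epoch", summary="min")` may be a cleaner alternative worth trying.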
