Hi, I am trying to log three losses to wandb; however, it is only showing a chart for the iteration loss, not for the validation and training losses. I really appreciate any help you can provide.
Hi @ankurb125! I looked over the logs for one of your runs, and I do not see the line print(f'Epoch: {epoch} Validation Loss: {val_loss:.3f}') being printed, which should appear right before you log the validation loss to wandb.
I also see that your run is marked as crashed. Could you please confirm that your run doesn't crash before you start logging the validation loss?
I think I see what is happening here. You are logging:
wandb.log({'Iteration Loss': loss.item()}, step=epoch * len(train_dataloader) + i)
and then trying to log:
wandb.log({'Validation Loss': val_loss}, step=epoch)
and wandb.log({'Training Loss': average_loss}, step=epoch)
The step value in your first call, epoch * len(train_dataloader) + i, is larger than the step values you pass in the two calls after it, and because of that those values are completely ignored.
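To make this concrete, here is a rough sketch of the pattern I think your loop follows (the loop structure and names like num_epochs and train_step are my assumptions based on your snippets, not your actual code):

```python
for epoch in range(num_epochs):
    for i, batch in enumerate(train_dataloader):
        loss = train_step(batch)  # placeholder for your training step
        # With, say, len(train_dataloader) == 100, epoch 0 logs at
        # steps 0..99, epoch 1 at steps 100..199, and so on.
        wandb.log({'Iteration Loss': loss.item()},
                  step=epoch * len(train_dataloader) + i)

    # By the time these run, wandb's step is already at
    # epoch * len(train_dataloader) + (len(train_dataloader) - 1),
    # while step=epoch is only 0, 1, 2, ... -- far smaller, so both
    # of these calls get dropped.
    wandb.log({'Validation Loss': val_loss}, step=epoch)
    wandb.log({'Training Loss': average_loss}, step=epoch)
```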
What I mean is that wandb can only log values at a monotonically increasing step; the step cannot go back and forth.
For example, you can log values at step = 1, 2, 3, 4, 5, or at step = 2, 4, 6, 8, 10. But in your case you first log at a large step, step=epoch * len(train_dataloader) + i, and then try to log at step=epoch, which is smaller than the step that came before it, so those calls are ignored. On the next iteration, the step in wandb.log({'Iteration Loss': loss.item()}, step=epoch * len(train_dataloader) + i) is again larger than the plain step=epoch, so the epoch-level calls are ignored every time and only 'Iteration Loss' is recorded.
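One way to fix this is to log the epoch-level metrics at the same global step you use for the iteration loss, so the step never decreases. Here is a minimal sketch under the same assumptions as above (num_epochs, train_step, and evaluate are placeholders for your own code):

```python
import wandb

wandb.init(project="my-project")  # hypothetical project name

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        loss = train_step(batch)  # placeholder for your training step
        running_loss += loss.item()
        global_step = epoch * len(train_dataloader) + i
        wandb.log({'Iteration Loss': loss.item()}, step=global_step)

    average_loss = running_loss / len(train_dataloader)
    val_loss = evaluate(model)  # placeholder for your validation pass

    # Reuse the last global_step of the epoch: logging at the same step
    # merges into the same row, and the next epoch's first step is larger,
    # so the step stays monotonically increasing.
    wandb.log({'Training Loss': average_loss,
               'Validation Loss': val_loss}, step=global_step)
```

Alternatively, wandb.define_metric lets you give the epoch-level metrics their own x-axis, e.g. wandb.define_metric('Validation Loss', step_metric='epoch') and then logging the epoch value alongside the losses, so they no longer compete with the iteration step.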
Hi @ankurb125! No worries at all. I will close this ticket for tracking purposes on our side, but once you get the resources to test this out, please feel free to start a new thread and refer back to this one if you are still seeing the issue on your side!