Log three losses

Hi, I am trying to log three losses to wandb; however, it only shows a chart for the iteration loss, not for the validation and training losses. I would really appreciate any help you can provide.

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    dataloader_iterator = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}", total=len(train_dataloader))
    best_validation_loss = float('inf')
    for i, batch in dataloader_iterator:
      count += 1
      images, box_coord, labels = batch
      images, box_coord, labels = images.to(fabric.device), box_coord.to(fabric.device), labels.to(fabric.device)
      optimizer.zero_grad()
      outputs = model((images, box_coord))
      euclidean_distance = torch.sqrt(torch.sum((outputs - labels) ** 2, dim=1))
      loss = euclidean_distance.mean()
      fabric.backward(loss)
      fabric.clip_gradients(model, optimizer, clip_val=clip_val)
      optimizer.step()
      running_loss += loss.item()
      scheduler.step(running_loss)
      wandb.log({'Iteration Loss': loss.item()}, step=epoch * len(train_dataloader) + i)
      if count % validation_frequency == 0:
            model.eval()
            val_loss = evaluate_model(model, val_dataloader)
            print(f'Epoch: {epoch} Validation Loss: {val_loss:.3f}')
            wandb.log({'Validation Loss': val_loss}, step=epoch)
            running_loss = 0
            if val_loss < best_validation_loss:
                best_validation_loss = val_loss
                torch.save(model.state_dict(), best_model_path)
                print("Best model checkpoint saved.")
            model.train()
      if (count % accumulation_steps == 0) or (count == len(train_dataloader)):
            optimizer.step()
            optimizer.zero_grad()      
    running_loss = running_loss/accumulation_steps
    average_loss = running_loss / num_epochs
    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {average_loss:.4f}")
    wandb.log({' Training Loss': average_loss}, step=epoch)
    torch.cuda.empty_cache()
wandb.finish()

Hi @ankurb125! Thank you for writing in! Could you please send us the link to the workspace where you are seeing this behavior?

Greetings of the day, Artsiom,
https://wandb.ai/ankurb125/vit_training?workspace=user-ankurb125

Hi @ankurb125! I looked over the logs for one of your runs here, and I do not see:
print(f'Epoch: {epoch} Validation Loss: {val_loss:.3f}')
printed anywhere, even though it should be printed right before you log the validation loss to wandb.

I also see that your run is marked as crashed. Could you please confirm that your run does not crash before you start logging the validation loss?

Hi there, I wanted to follow up on this request. Please let us know if we can be of further assistance or if your issue has been resolved.

Greetings of the day, Artsiom,
Sorry for not getting back to you sooner. I am still seeing the same behavior (the script is running now).

I think I see what is happening here. When you are logging:
wandb.log({'Iteration Loss': loss.item()}, step=epoch * len(train_dataloader) + i)
and then try logging:
wandb.log({'Validation Loss': val_loss}, step=epoch)
and
wandb.log({' Training Loss': average_loss}, step=epoch)

The step value in your first call, epoch * len(train_dataloader) + i, is larger than the step values you try to log after it, and because of that those other values are completely ignored.

What I mean by that is that wandb can only log values at a monotonically increasing step; the step cannot go back and forth.

For example, you can log values at step = 1, 2, 3, 4, 5, or at step = 2, 4, 6, 8, 10. But in your case, you first log at a large step, step=epoch * len(train_dataloader) + i, and then try to log at step=epoch, which is smaller, so that call is ignored. On the next iteration, the step in wandb.log({'Iteration Loss': loss.item()}, step=epoch * len(train_dataloader) + i) is again larger than the plain step=epoch, so the epoch-level calls keep getting ignored and only 'Iteration Loss' is recorded.
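For reference, here is a minimal sketch of one way to keep the step monotonic, logging every metric against a single running counter (global_step is just an example name, not something from your script):

global_step = 0  # one always-increasing step shared by all metrics
for epoch in range(num_epochs):
    for i, batch in enumerate(train_dataloader):
        # ... forward pass, backward pass, optimizer step as in your loop ...
        global_step += 1
        wandb.log({'Iteration Loss': loss.item()}, step=global_step)
        if global_step % validation_frequency == 0:
            val_loss = evaluate_model(model, val_dataloader)
            # logged at the same global_step, so the step never goes backwards
            wandb.log({'Validation Loss': val_loss}, step=global_step)
    # epoch-level metric, still logged at the current global_step
    wandb.log({'Training Loss': average_loss, 'epoch': epoch}, step=global_step)

Alternatively, you can drop the step argument entirely and let wandb increment the step on every wandb.log call, or use wandb.define_metric to plot the validation and training losses against a custom x-axis such as 'epoch'.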

Hi! Wanted to follow up with you regarding this thread!

Hi,
I am extremely sorry for the slow response.
I will try to look into it based on your suggestions.
Thank you so much.

No problem! Any updates on your side?

Hi Artsiom,
There is a problem with resources; once it is solved, I will try the things you suggested and let you know.

Hi @ankurb125! No worries at all, I will close this ticket for tracking purposes on our side, but once you get the resources to test it out please feel free to start a new thread and refer to this one if you are still seeing this concern on your side!