Wandb.watch() when using mixed precision and torch.cuda.amp.GradScaler()

I have a PyTorch project where I'm using mixed-precision training with gradient scaling. When using wandb.watch() to log model gradients, is it possible to unscale them with something like scaler.unscale_() at some point in the code prior to logging? My code looks something like the below.

import torch
import wandb

wandb.init(project="my_project", name='my_run', config=config, mode='online')
model = Net()
wandb.watch(model, log='all')
optimiser = my_optim(model.parameters(), lr=lr)

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs):
    for input, target in train_loader:
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            pred = model(input)
            loss = loss_fn(pred, target)
        scaler.scale(loss).backward()
        scaler.step(optimiser)
        scaler.update()
        optimiser.zero_grad(set_to_none=True)

Hi @dt_90, thanks for writing in! I wanted to follow up on this request and see whether you've already tried it and ran into any issues. Also, I was wondering what the use case would be for logging the unscaled gradients instead?

Hi @dt_90, just checking in to see whether you're still experiencing any issues with this, and whether you could provide some more information about the errors you get. If possible, sharing a minimal code example would greatly help us reproduce the issue and further assist you with this. Thanks!

Hi @dt_90, since we haven't heard back from you with any additional information regarding this issue, we will close this ticket for now. However, please let us know if the issue still persists for you and we will be happy to keep investigating.