I have a PyTorch project where I'm using mixed precision training with gradient scaling. When using wandb.watch()
to log model gradients, is it possible to unscale them, e.g. with scaler.unscale_(), at some point in the code prior to logging? My code looks something like this:
import torch
import wandb

wandb.init(project="my_project", name='my_run', config=config, mode='online')

model = Net()
wandb.watch(model, log='all')  # log gradients and parameters

optimiser = my_optim(model.parameters(), lr=lr)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs):
    for input, target in train_loader:
        # forward pass under autocast
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            pred = model(input)
            loss = loss_fn(pred, target)
        # backward on the scaled loss produces scaled gradients
        scaler.scale(loss).backward()
        # step() unscales the gradients internally before the optimiser update
        scaler.step(optimiser)
        scaler.update()
        optimiser.zero_grad(set_to_none=True)
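
For reference, this is roughly where I imagine the call would go. It is just a sketch of the modified per-batch loop (not the full script), assuming scaler.unscale_(optimiser) can be called between backward() and step(), and assuming the gradients wandb logs reflect the .grad tensors at that point:

        # inside the per-batch loop, after the forward pass
        scaler.scale(loss).backward()
        # unscale .grad in place; scaler.step() detects this and will not unscale again
        scaler.unscale_(optimiser)
        # hoping wandb.watch() would pick up the unscaled gradients from here on
        scaler.step(optimiser)
        scaler.update()
        optimiser.zero_grad(set_to_none=True)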