wandb.watch() not logging gradients in the base case

Hello,
I have a project where I am trying to log gradients with wandb.watch, but no gradients are logged. What am I missing?
Here is the training code snippet:

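import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup  # assumed source; diffusers.optimization has the same helper

# model, optimizer, train_dataloader and val_dataloader are defined earlier (not shown)

# 1. Start a W&B Run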
wandb.init(
    project="vqvae",
    notes="This is experimenting with batch sizes",
    tags=["baseline", "vqvae_simple"],
)
num_epochs = 100
wandb.log({"Beginning_epochs": num_epochs})

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=200,
    num_training_steps=(len(train_dataloader) * num_epochs),
)
accelerator = Accelerator()
train_dataloader, val_dataloader, model, optimizer, lr_scheduler = accelerator.prepare(
    train_dataloader, val_dataloader, model, optimizer, lr_scheduler
)
model.train()
global_step = 0
wandb.watch(model, log='all')
for epoch in range(num_epochs):
    progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description(f"Epoch {epoch}")
    for batch in train_dataloader:
        optimizer.zero_grad()
        model.zero_grad()
        
        x_recon, latent_loss, _ = model(batch)
        another_reconstruction_loss = F.mse_loss(batch, x_recon)
        loss = another_reconstruction_loss  + 0.25 * latent_loss
        global_step += 1
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
        wandb.log(logs)  # log plain floats rather than the graph-attached loss tensor
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

Hello, this is a reply from our support bot, designed to assist you with your Weights & Biases-related queries. To reach a human, please reply to this message.

To ensure gradients are logged using wandb.watch, you need to follow a few best practices and check for compatibility issues, especially when integrating with other libraries like Hugging Face’s Accelerator. Here’s a breakdown to help troubleshoot and ensure gradients are logged correctly:

  1. Correct Placement of wandb.watch: The wandb.watch call should come after wandb.init() and after the model is created, but before the training loop, so that wandb can hook into the model and log gradients. Your code correctly places wandb.watch(model, log='all') after wandb.init() and the model initialization, and before the training loop starts.

  2. Compatibility with Hugging Face’s Accelerator: Check whether there are known compatibility issues between wandb.watch and Hugging Face’s Accelerator; none are mentioned in the material available to this bot. Note that wandb.watch is designed to work with PyTorch models and collects gradients through hooks on the model, so if Accelerator wraps the model in a way that wandb cannot hook into, gradients may not be logged. It’s worth checking the documentation and forums for both wandb and Accelerator for any known issues or updates.

  3. Logging Gradients: The log='all' parameter in wandb.watch is set correctly; it logs both gradients and parameters. If gradients still do not appear, check that the model’s gradients are actually being computed: after the backward pass, iterate over model.named_parameters() and confirm that each parameter’s .grad is not None.

  4. Use of Accelerator: The use of Hugging Face’s Accelerator might require additional steps or considerations for logging to work correctly. Since Accelerator is designed to simplify running models on different devices or in mixed precision, ensure that it doesn’t interfere with how wandb hooks into the model for gradient logging. You might need to consult the Accelerator documentation for any additional steps required to expose gradients to external libraries.

  5. Ensure Proper Integration: Double-check that wandb and any other libraries (like Hugging Face’s Accelerator) are up to date. Sometimes, compatibility issues are resolved in newer versions of the libraries.

  6. Debugging Tips: If you’re still facing issues, try simplifying the setup. Temporarily remove the Accelerator and see if wandb.watch logs gradients as expected. This can help isolate whether the issue is with wandb integration or the interaction with Accelerator.

  7. Community and Support: If the issue persists, consider reaching out to the wandb community or support. The wandb community forum and GitHub issues are good places to ask for help, as there might be others who have faced similar issues.
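Points 3 and 6 can be combined into one quick check: run a single training step in plain PyTorch, without Accelerator, and confirm that every trainable parameter received a gradient. This is a minimal sketch under stated assumptions: `TinyAutoencoder`, `params_without_grad`, and `grad_norms` are hypothetical stand-ins that are not part of the original code, and a random tensor replaces a real batch. To apply it, substitute your own model and one batch from `train_dataloader`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAutoencoder(nn.Module):
    """Hypothetical stand-in for the VQ-VAE: returns (x_recon, latent_loss, z)."""

    def __init__(self, dim=8, hidden=4):
        super().__init__()
        self.encoder = nn.Linear(dim, hidden)
        self.decoder = nn.Linear(hidden, dim)

    def forward(self, x):
        z = self.encoder(x)
        latent_loss = z.pow(2).mean()  # placeholder for the real VQ latent loss
        return self.decoder(z), latent_loss, z


def params_without_grad(model):
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


def grad_norms(model):
    """L2 norm of each parameter's gradient, keyed by parameter name."""
    return {
        name: p.grad.norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }


torch.manual_seed(0)
model = TinyAutoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.randn(16, 8)  # random data standing in for a real batch

# One training step in plain PyTorch, with no Accelerator involved.
model.train()
optimizer.zero_grad()
x_recon, latent_loss, _ = model(batch)
loss = F.mse_loss(x_recon, batch) + 0.25 * latent_loss
loss.backward()

# Inspect gradients right after backward(), before the optimizer step.
missing = params_without_grad(model)
norms = grad_norms(model)
print("parameters without a gradient:", missing)
for name, value in norms.items():
    print(f"{name}: grad norm {value:.4f}")

optimizer.step()
```

If `missing` is empty and the norms are non-zero, gradients are being computed and the problem is on the logging side. One setting worth ruling out in that case is `log_freq`: `wandb.watch` logs gradients only every `log_freq` batches, and the default is 1000, so a short run may show no gradient histograms at all. Calling `wandb.watch(model, log="all", log_freq=1)` while debugging removes that variable.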

In summary, your code snippet seems correctly set up for logging gradients with wandb.watch. If gradients are not being logged, the issue might lie in the interaction with Hugging Face’s Accelerator or a specific model configuration. Double-check compatibility, ensure all libraries are up to date, and consider reaching out to community forums for further assistance.

-WandBot 🤖

Based on the bot’s comment, I moved wandb.watch() before the Accelerator setup and called model.train() before wandb.watch(). Neither change fixed the problem.