Distributed data parallel with PyTorch Lightning

Hi there,

I’m running over multiple GPUs using Lightning’s DDP wrapper, and each GPU process creates a new experiment on wandb. Is it possible to track only the rank 0 process? I can see how to do this in vanilla PyTorch, but not while using the Lightning WandbLogger.
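For reference, this is roughly what I mean by the vanilla PyTorch pattern (just a minimal sketch; it assumes the script is launched with torchrun so the RANK environment variable is set, and the project name is a placeholder):

import os
import wandb

# Each DDP process knows its global rank (e.g. via the RANK env var set by
# torchrun), so wandb.init() can be guarded to run on rank 0 only.
rank = int(os.environ.get("RANK", 0))
if rank == 0:
    wandb.init(project="ddp_test")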

Thanks!

Hi @chris-pedersen, thank you for reaching out with your question. While there is not currently an option to explicitly restrict tracking to the rank-zero process, when using WandbLogger with multiple GPUs only one Run should be created in W&B, since the logger only initialises the run on rank 0.
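As a reference point, the expected pattern is to let WandbLogger create the run and pass the logger straight to the Trainer, rather than calling wandb.init() manually in the script; a minimal sketch (the project name and device count are placeholders):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# WandbLogger creates the wandb run lazily and only on the rank-0 process,
# so no explicit wandb.init() call is needed in a DDP script.
logger = WandbLogger(project="my-project")
trainer = Trainer(devices=4, strategy="ddp", logger=logger)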

Would you mind sharing the URL of the workspace where you see multiple Runs being created, the versions of PTL and wandb you are currently running, and a snippet of code showing how the WandbLogger and the PTL Trainer are configured?

Thanks,
Francesco

Hi @chris-pedersen , I wanted to follow up on this request. Please let us know if we can be of further assistance or if your issue has been resolved.

Hi @chris-pedersen, since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!

Thanks for getting back to me, and apologies for the delay! Sure, an example snippet is below:

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import transforms 
import wandb
import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

def create_mnist_dataloaders(batch_size, image_size=28, num_workers=2):
    # Resize MNIST images and convert them to tensors
    preprocess = transforms.Compose([transforms.Resize(image_size),
                                     transforms.ToTensor()])

    train_dataset = MNIST(root="/scratch/cp3759/mnist/",
                          train=True,
                          download=False,
                          transform=preprocess)
    test_dataset = MNIST(root="/scratch/cp3759/mnist/",
                         train=False,
                         download=False,
                         transform=preprocess)

    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),
            DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers))

# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

train_loader, test_loader = create_mnist_dataloaders(batch_size=64, image_size=28)

wandb.init(project="ddp_test",entity="chris-pedersen",dir="/scratch/cp3759/thermalizer_data/wandb_data")
logger = WandbLogger()

# init the autoencoder
model = LitAutoEncoder(encoder, decoder)
wandb.watch(model)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(accelerator="auto",
                    limit_train_batches=100,
                    max_epochs=10,
                    logger=logger,
                    enable_checkpointing=False,
                    enable_progress_bar=False,
                    devices=4,
                    strategy="ddp")
trainer.fit(model=model,
            train_dataloaders=train_loader,
            )

My Lightning version is 2.1.3 and my wandb version is 0.13.10.

Some things in the above snippet are specific to my system, particularly where the data is loaded from. The wandb dashboard for this run is here - for this single job, 4 experiments are created, but 3 of them are not tracking anything on the dashboard. Do you know of any way around this?
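One workaround I have been considering, though I'm not sure it is the recommended approach, is guarding the module-level wandb calls so they only run in the main process, roughly like this (it assumes Lightning's DDP launcher sets LOCAL_RANK to a non-zero value in the re-launched processes):

import os
import wandb

# Only create a run from the main process; the processes Lightning re-launches
# for ddp should have LOCAL_RANK set to a non-zero value (assumption).
if os.environ.get("LOCAL_RANK", "0") == "0":
    wandb.init(project="ddp_test", entity="chris-pedersen")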

Thanks again!

Hi @fmamberti-wandb, just bumping this!

Hey, any updates @fmamberti-wandb? This is also important for me…