Multi-GPU training


I usually train my model on a single GPU with Slurm and PyTorch Lightning. The training phase takes around 2 days.

I want to use 2 GPUs in order to speed up the training. I have read the documentation, specifically the chapter "Log distributed training experiments", but I'm still confused about which method to use: one process or many processes. I feel that either method could suit me.

Ideally, I want a single wandb run. I split my batch into 2 sub-batches and send each one to a GPU; the loss is then the mean of the losses from each GPU. I want the multi-GPU management to be abstracted away when looking at the loss curve on the wandb platform. Is this possible?

hey @asmabrazi27 -

Yes, it is possible to train your model using multiple GPUs in such a way that you have a single W&B run that logs the mean loss across the GPUs. PyTorch Lightning has built-in support for distributed training.

For your use case, you can use the Distributed Data Parallel (DDP) strategy provided by PyTorch Lightning. DDP is a method where each process operates on a complete copy of the model and only a subset of the data. This is the recommended approach for multi-GPU training with PyTorch Lightning and is well-suited for your requirement of having a single W&B run.
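As a sanity check on the "mean of losses" requirement: when the batch is split into equal-size sub-batches, averaging the per-GPU mean losses reproduces the full-batch mean loss. A toy illustration in plain Python (no GPUs involved, made-up numbers):

```python
# With DDP, each GPU computes the mean loss over its own sub-batch;
# averaging those per-GPU means gives back the full-batch mean loss,
# provided the sub-batches are the same size.
per_sample_losses = [0.9, 0.7, 0.5, 0.3]  # losses for a batch of 4 samples

# Full-batch mean loss (single-GPU view)
full_mean = sum(per_sample_losses) / len(per_sample_losses)

# DDP view: split the batch into 2 equal sub-batches, one per GPU
gpu0, gpu1 = per_sample_losses[:2], per_sample_losses[2:]
mean0 = sum(gpu0) / len(gpu0)
mean1 = sum(gpu1) / len(gpu1)
ddp_mean = (mean0 + mean1) / 2

print(full_mean, ddp_mean)  # both 0.6
```

So the averaged metric you see in W&B matches what a single-GPU run over the full batch would have reported.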

Here’s how you can set it up:

  1. Configure your Trainer: When creating the Trainer object in PyTorch Lightning, specify the number of GPUs you want to use and set the strategy parameter to 'ddp'. This will automatically handle the distribution of your data across the GPUs and the synchronization of the gradients.
trainer = pl.Trainer(
    gpus=2,  # Number of GPUs (newer Lightning versions use accelerator='gpu', devices=2 instead)
    strategy='ddp',  # Use Distributed Data Parallel
    logger=wandb_logger,  # Your Weights & Biases logger
    # ... other parameters
)
  2. Logging: When using DDP, each process will execute the same code, but you should ensure that only the process with rank 0 logs information to W&B to avoid duplicate logs. PyTorch Lightning’s self.log method automatically takes care of this when using DDP, so you don’t need to manually handle it.
def training_step(self, batch, batch_idx):
    # Your training code here
    loss = ...
    self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
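For intuition, the rank-0 behavior that self.log handles for you can be sketched in plain Python. This is a toy simulation of the pattern, not actual Lightning internals, and `log_if_rank_zero` is a hypothetical helper:

```python
# Conceptual sketch of the rank-0 logging guard used in DDP setups:
# every process runs the same code, but only rank 0 emits log entries,
# so the W&B run receives each metric exactly once.
def log_if_rank_zero(rank, message, sink):
    """Append message to sink only on the coordinating process (rank 0)."""
    if rank == 0:
        sink.append(message)

logs = []
for rank in range(2):  # simulate 2 GPU processes running the same code
    log_if_rank_zero(rank, "train_loss=0.6", logs)

print(logs)  # only one entry, contributed by rank 0
```

With sync_dist=True, Lightning additionally reduces (averages) the metric across processes before rank 0 logs it, which is exactly the single-run, mean-loss behavior you described.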

Please let me know if this is helpful or if you have further questions.

Hi @asmabrazi27, since we have not heard back from you we are going to close this request. If you would like to re-open the conversation, please let us know!