Thanks for getting back to me, and apologies for the delay! Sure, an example snippet is below:
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import transforms
import wandb
import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
def create_mnist_dataloaders(batch_size, image_size=28, num_workers=2):
    preprocess = transforms.Compose([transforms.Resize(image_size),
                                     transforms.ToTensor()])
    train_dataset = MNIST(root="/scratch/cp3759/mnist/",
                          train=True,
                          download=False,
                          transform=preprocess)
    test_dataset = MNIST(root="/scratch/cp3759/mnist/",
                         train=False,
                         download=False,
                         transform=preprocess)
    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),
            DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers))
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
train_loader, test_loader = create_mnist_dataloaders(batch_size=64, image_size=28)
wandb.init(project="ddp_test", entity="chris-pedersen", dir="/scratch/cp3759/thermalizer_data/wandb_data")
logger = WandbLogger()
# init the autoencoder
model = LitAutoEncoder(encoder, decoder)
wandb.watch(model)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(accelerator="auto",
                    limit_train_batches=100,
                    max_epochs=10,
                    logger=logger,
                    enable_checkpointing=False,
                    enable_progress_bar=False,
                    devices=4,
                    strategy="ddp")
trainer.fit(model=model, train_dataloaders=train_loader)
My lightning version is 2.1.3, and my wandb version is 0.13.10.
Some things in the above snippet are specific to my system, particularly where the data is loaded from. The wandb dashboard for this run is here: for this single job, four wandb experiments are created (one per DDP process, I assume), but three of them aren't tracking anything on the dashboard. Is there any way around this, do you know?
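One thing I wasn't sure about: should I be guarding the explicit wandb.init() call so it only runs on the rank-zero process? A minimal sketch of what I had in mind is below; the LOCAL_RANK/NODE_RANK environment-variable check is just my assumption about how Lightning's DDP launcher marks the relaunched child processes, so it may not be the intended approach.

import os
import wandb
from pytorch_lightning.loggers import WandbLogger

# Assumption: with strategy="ddp", Lightning relaunches the script for the
# non-zero ranks and sets LOCAL_RANK/NODE_RANK in their environment, so the
# original process is the only one where both are unset (or "0").
is_rank_zero = (os.environ.get("LOCAL_RANK", "0") == "0"
                and os.environ.get("NODE_RANK", "0") == "0")

if is_rank_zero:
    # Create the wandb run only once, on the rank-zero process
    wandb.init(project="ddp_test",
               entity="chris-pedersen",
               dir="/scratch/cp3759/thermalizer_data/wandb_data")

logger = WandbLogger()

Would something like that be the recommended pattern, or is there a cleaner way to let WandbLogger handle this itself?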
Thanks again!