Thanks for getting back to me, and apologies for the delay! Sure, an example snippet is below:
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision import transforms
import wandb
import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
def create_mnist_dataloaders(batch_size, image_size=28, num_workers=2):
    preprocess = transforms.Compose([transforms.Resize(image_size),
                                     transforms.ToTensor()])
    train_dataset = MNIST(root="/scratch/cp3759/mnist/",
                          train=True,
                          download=False,
                          transform=preprocess)
    test_dataset = MNIST(root="/scratch/cp3759/mnist/",
                         train=False,
                         download=False,
                         transform=preprocess)
    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),
            DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers))
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop;
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # self.log sends metrics to the Trainer's logger (WandbLogger here)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
train_loader, test_loader = create_mnist_dataloaders(batch_size=64, image_size=28)
wandb.init(project="ddp_test", entity="chris-pedersen", dir="/scratch/cp3759/thermalizer_data/wandb_data")
logger = WandbLogger()
# init the autoencoder
model = LitAutoEncoder(encoder, decoder)
wandb.watch(model)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(accelerator="auto",
                    limit_train_batches=100,
                    max_epochs=10,
                    logger=logger,
                    enable_checkpointing=False,
                    enable_progress_bar=False,
                    devices=4,
                    strategy="ddp")
trainer.fit(model=model, train_dataloaders=train_loader)
My PyTorch Lightning version is 2.1.3 and my wandb version is 0.13.10.
Some things in the above snippet are specific to my system, in particular the paths the data is loaded from. The wandb dashboard for this run is here - for this single job, four experiments are created, but three of them are not tracking anything on the dashboard. Is there any way around this, do you know?
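In case a concrete example helps, below is the kind of workaround I'm wondering about: dropping the manual wandb.init call and letting WandbLogger create the run instead (my understanding is that it only does so on global rank zero under DDP), with the init arguments passed through the logger. I'm not sure whether this is the recommended pattern, so please treat it as a sketch rather than something I know to be correct:

# possible workaround (untested assumption on my part): no manual wandb.init;
# WandbLogger creates the run, and should only do so on global rank zero
logger = WandbLogger(project="ddp_test",
                     entity="chris-pedersen",
                     save_dir="/scratch/cp3759/thermalizer_data/wandb_data")
model = LitAutoEncoder(encoder, decoder)
logger.watch(model)  # replaces the direct wandb.watch(model) call
trainer = L.Trainer(accelerator="auto",
                    limit_train_batches=100,
                    max_epochs=10,
                    logger=logger,
                    enable_checkpointing=False,
                    enable_progress_bar=False,
                    devices=4,
                    strategy="ddp")
trainer.fit(model=model, train_dataloaders=train_loader)

As far as I understand, entity here is simply forwarded to wandb.init through the logger's keyword arguments, but please correct me if that's not how it should be passed.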
Thanks again!