Hi!
I am currently using wandb with PyTorch Lightning and Ray. When I launch a Ray run, the last five or so epochs are not logged to wandb, even though the terminal prints that they are being logged (and when I look at the run logs on wandb, those print statements are cut off).
Here is my callback code:
class CustomRayWandbCallback(Callback):
    """Callback that logs losses and plots to Wandb."""

    # FIXME: if loss is not the MSE (because loss has regularization, or loss=L1),
    # then sqrt(loss) is not the RMSE

    def on_after_backward(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            if param.requires_grad:
                gradients = param.grad.detach().cpu()
                wandb.log({"gradients": wandb.Histogram(gradients)})

    def on_train_epoch_end(self, trainer, pl_module):
        """Save train predictions, targets, and parameters as histograms.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            Unused. Here for API compliance.
        pl_module : Proteo LightningModule
            Lightning's module for training.
        """
        train_preds = torch.vstack(pl_module.train_preds).detach().cpu()
        train_targets = torch.vstack(pl_module.train_targets).detach().cpu()
        params = torch.concat([p.flatten() for p in pl_module.parameters()]).detach().cpu()
        train_loss = pl_module.trainer.callback_metrics["train_loss"]
        train_RMSE = math.sqrt(train_loss)
        # Log the first graph ([0, :]) of x0, x1, and x2 to see if oversmoothing is happening,
        # i.e. if all features across one person are the same
        x0 = torch.vstack(pl_module.x0).detach().cpu()[0, :]
        x1 = torch.vstack(pl_module.x1).detach().cpu()[0, :]
        x2 = torch.vstack(pl_module.x2).detach().cpu()[0, :]
        multiscale = torch.vstack(pl_module.multiscale).detach().cpu()
        multiscale = (
            torch.norm(multiscale, dim=1).detach().cpu()
        )  # Take the norm of the features across the 3 layers per person to get one value per person
        wandb.log(
            {
                "train_loss": train_loss,
                "train_RMSE": train_RMSE,
                "train_preds": wandb.Histogram(train_preds),
                "train_targets": wandb.Histogram(train_targets),
                "parameters (weights+biases)": wandb.Histogram(params),
                "x0": wandb.Histogram(x0),
                "x1": wandb.Histogram(x1),
                "x2": wandb.Histogram(x2),
                "multiscale norm for all people": wandb.Histogram(multiscale),
                "epoch": pl_module.current_epoch,
            }
        )
        if pl_module.config.y_val in CONTINOUS_Y_VALS:
            if train_loss < pl_module.min_train_loss:
                pl_module.min_train_loss = train_loss
                scatter_plot_data = [
                    [pred, target] for (pred, target) in zip(train_preds, train_targets)
                ]
                table = wandb.Table(data=scatter_plot_data, columns=["pred", "target"])
                wandb.log(
                    {
                        "Regression Scatter Plot Train": wandb.plot.scatter(
                            table, "pred", "target", title="Train Pred vs Train Target Scatter Plot"
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )
        elif pl_module.config.y_val in BINARY_Y_VALS_MAP:
            train_preds_sigmoid = torch.sigmoid(train_preds)
            predicted_classes = (train_preds_sigmoid > 0.5).int()
            train_accuracy = (predicted_classes == train_targets).float().mean().item()
            # Convert tensors to numpy arrays and ensure they are integers
            train_targets_np = train_targets.numpy().astype(int).flatten()
            predicted_classes_np = predicted_classes.numpy().astype(int).flatten()
            wandb.log(
                {
                    "train_preds_sigmoid": wandb.Histogram(train_preds_sigmoid),
                    "train_accuracy": train_accuracy,
                    "conf_mat train": wandb.plot.confusion_matrix(
                        probs=None,
                        y_true=train_targets_np,
                        preds=predicted_classes_np,
                        class_names=['Control (nfl), 0 (cdr)', 'Carrier (nfl), >0 (cdr)'],
                    ),  # TODO: hacky
                    "epoch": pl_module.current_epoch,
                }
            )
        elif pl_module.config.y_val in MULTICLASS_Y_VALS_MAP:
            softmax_probs = F.softmax(train_preds, dim=1)
            class_preds = torch.argmax(softmax_probs, dim=1)
            train_accuracy = (class_preds == train_targets).float().mean().item()
            class_preds_np = class_preds.numpy().astype(int).flatten()
            train_targets_np = train_targets.numpy().astype(int).flatten()
            wandb.log(
                {
                    "train_preds_softmax": wandb.Histogram(softmax_probs),
                    "train_preds_class": wandb.Histogram(class_preds),
                    "train_accuracy": train_accuracy,
                    "conf_matrix train": wandb.plot.confusion_matrix(
                        probs=None,
                        y_true=train_targets_np,
                        preds=class_preds_np,
                        class_names=['0', '0.5', '1', '2', '3'],  # TODO: Hardcoded
                    ),
                    "epoch": pl_module.current_epoch,
                }
            )
        pl_module.train_preds.clear()  # free memory
        pl_module.train_targets.clear()
        pl_module.x0.clear()
        pl_module.x1.clear()
        pl_module.x2.clear()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save val predictions and targets as histograms and log confusion matrix.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            Lightning's trainer object.
        pl_module : Proteo LightningModule
            Lightning's module for training.
        """
        if not trainer.sanity_checking:
            val_preds = torch.vstack(pl_module.val_preds).detach().cpu()
            val_targets = torch.vstack(pl_module.val_targets).detach().cpu()
            val_loss = pl_module.trainer.callback_metrics["val_loss"]
            print(f"wandb logging at epoch {pl_module.current_epoch}")
            # Log histograms and metrics
            wandb.log(
                {
                    "val_loss": val_loss,
                    "val_preds": wandb.Histogram(val_preds),
                    "val_targets": wandb.Histogram(val_targets),
                    "epoch": pl_module.current_epoch,
                }
            )
            if pl_module.config.y_val in CONTINOUS_Y_VALS:
                if val_loss < pl_module.min_val_loss:
                    pl_module.min_val_loss = val_loss
                    print("min_loss_val =", pl_module.min_val_loss)
                    print("val_loss =", val_loss)
                    print("epoch =", pl_module.current_epoch)
                    print("val_preds =", val_preds)
                    scatter_plot_data = [
                        [pred, target] for (pred, target) in zip(val_preds, val_targets)
                    ]
                    table = wandb.Table(data=scatter_plot_data, columns=["pred", "target"])
                    wandb.log(
                        {
                            "Regression Scatter Plot Val": wandb.plot.scatter(
                                table, "pred", "target", title="Val Pred vs Val Target Scatter Plot"
                            ),
                            "epoch": pl_module.current_epoch,
                        }
                    )
            elif pl_module.config.y_val in BINARY_Y_VALS_MAP:
                val_preds_sigmoid = torch.sigmoid(val_preds)
                # Note this assumes binary classification
                predicted_classes = (val_preds_sigmoid > 0.5).int()
                val_accuracy = (predicted_classes == val_targets).float().mean().item()
                # Convert tensors to numpy arrays and ensure they are integers
                val_targets_np = val_targets.numpy().astype(int).flatten()
                predicted_classes_np = predicted_classes.numpy().astype(int).flatten()
                wandb.log(
                    {
                        "val_preds_sigmoid": wandb.Histogram(val_preds_sigmoid),
                        "val_accuracy": val_accuracy,
                        "conf_mat val": wandb.plot.confusion_matrix(
                            probs=None,
                            y_true=val_targets_np,
                            preds=predicted_classes_np,
                            class_names=[
                                'Control (nfl), 0 (cdr)',
                                'Carrier (nfl), >0 (cdr)',
                            ],  # TODO: hacky
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )
            elif pl_module.config.y_val in MULTICLASS_Y_VALS_MAP:
                softmax_probs = F.softmax(val_preds, dim=1)
                class_preds = torch.argmax(softmax_probs, dim=1)
                val_accuracy = (class_preds == val_targets).float().mean().item()
                class_preds_np = class_preds.numpy().astype(int).flatten()
                val_targets_np = val_targets.numpy().astype(int).flatten()
                wandb.log(
                    {
                        "val_preds_softmax": wandb.Histogram(softmax_probs),
                        "val_preds_class": wandb.Histogram(class_preds),
                        "val_accuracy": val_accuracy,
                        "conf_matrix val": wandb.plot.confusion_matrix(
                            probs=None,
                            y_true=val_targets_np,
                            preds=class_preds_np,
                            class_names=['0', '0.5', '1', '2', '3'],  # TODO: Hardcoded
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )
        pl_module.val_preds.clear()  # free memory
        pl_module.val_targets.clear()
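
One detail I noticed while reviewing the callback: every wandb.log call made without an explicit step argument advances W&B's internal step counter, and on_after_backward logs a gradient histogram for every parameter on every batch, while the epoch-level metrics above are logged only once per epoch alongside an "epoch" key. Below is a minimal, self-contained sketch of how the x-axis could be pinned to that "epoch" key with wandb.define_metric; the project name is a placeholder and offline mode is only so the snippet runs anywhere (this is my own assumption, not something from the Ray or Lightning docs).

import wandb

# Hypothetical, minimal sketch: plot every metric against the logged "epoch" key
# instead of W&B's auto-incremented internal step.
run = wandb.init(project="proteo-example", mode="offline")  # placeholder settings

wandb.define_metric("epoch")
wandb.define_metric("*", step_metric="epoch")  # use "epoch" as the x-axis for all metrics

for epoch in range(3):
    # Epoch-level metrics now line up on the "epoch" axis, even if other
    # wandb.log calls (e.g. per-batch gradient histograms) advance the internal step.
    wandb.log({"train_loss": 1.0 / (epoch + 1), "epoch": epoch})

run.finish()

I have not confirmed whether this is related to the missing final epochs, but it at least decouples the per-epoch charts from how often the per-batch histograms advance the internal step.
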
And here is my wandb setup:
def train_func(train_loop_config):
    """Train one neural network with Lightning.

    Configure Lightning with Ray.

    Parameters
    ----------
    train_loop_config : Configuration parameters for training, from the sweep.
    """
    torch.set_float32_matmul_precision('medium')  # for performance
    config = read_config_from_file(CONFIG_FILE)
    pl.seed_everything(config.seed)

    # Update default config with sweep-specific train_loop_config
    # Update model specific parameters: hacky - is there a better way?
    model = train_loop_config['model']
    train_loop_config_model = {
        'hidden_channels': train_loop_config['hidden_channels'],
        'heads': train_loop_config['heads'],
        'num_layers': train_loop_config['num_layers'],
    }
    config[model].update(train_loop_config_model)
    # Remove keys that were already updated in nested configuration
    for key in train_loop_config_model:
        train_loop_config.pop(key)
    config.update(train_loop_config)

    setup_wandb(  # wandb.init, but for ray
        config.dict(),  # Transform Config object into dict for wandb
        project=config.project,
        api_key_file=os.path.join(config.root_dir, config.wandb_api_key_path),
        # Directory passed to dir needs to exist, otherwise wandb saves in /tmp
        dir=os.path.join(config.root_dir, config.output_dir),
        mode="offline" if config.wandb_offline else "online",
    )

    train_dataset, test_dataset = proteo_train.construct_datasets(config)
    train_loader, test_loader = proteo_train.construct_loaders(config, train_dataset, test_dataset)
    avg_node_degree = proteo_train.compute_avg_node_degree(test_dataset)
    pos_weight = 1.0  # default value
    focal_loss_weight = [1.0]  # default value
    if config.y_val in BINARY_Y_VALS_MAP:
        pos_weight = proteo_train.compute_pos_weight(test_dataset, train_dataset)
    elif config.y_val in MULTICLASS_Y_VALS_MAP:
        focal_loss_weight = proteo_train.compute_focal_loss_weight(
            config, test_dataset, train_dataset
        )
    y_mean, y_std = proteo_train.compute_mean_std(config, test_dataset, train_dataset)

    # For wandb logging of top proteins
    protein_file_data = proteo_train.read_protein_file(train_dataset.processed_dir, config)
    protein_names = protein_file_data['Protein']
    metrics = protein_file_data['Metric']
    top_proteins_data = [[protein, metric] for protein, metric in zip(protein_names, metrics)]

    module = proteo_train.Proteo(
        config,
        in_channels=train_dataset.feature_dim,  # 1 dim of input
        out_channels=train_dataset.label_dim,  # 1 dim of result
        avg_node_degree=avg_node_degree,
        pos_weight=pos_weight,
        focal_loss_weight=focal_loss_weight,
    )

    if config.y_val in Y_VALS_TO_NORMALIZE:
        wandb.log(
            {
                "histogram original": wandb.Image(
                    os.path.join(
                        train_dataset.processed_dir,
                        f'{config.y_val}_{config.sex}_{config.mutation}_{config.modality}_orig_histogram.jpg',
                    )
                )
            }
        )
    wandb.log(
        {
            "histogram": wandb.Image(
                os.path.join(
                    train_dataset.processed_dir,
                    f'{config.y_val}_{config.sex}_{config.mutation}_{config.modality}_histogram.jpg',
                )
            ),
            "adjacency": wandb.Image(
                os.path.join(
                    train_dataset.processed_dir,
                    f"adjacency_{config.adj_thresh}_num_nodes_{config.num_nodes}_mutation_{config.mutation}_{config.modality}_sex_{config.sex}.jpg",
                )
            ),
            "top_proteins": wandb.Table(
                columns=["Protein", "Metric"], data=top_proteins_data
            ),  # note this is in order from most to least different
            "parameters": wandb.Table(
                columns=["Medium", "Mutation", "Target", "Sex", "Avg Node Degree", "Mean", "Std"],
                data=[
                    [
                        config.modality,
                        config.mutation,
                        config.y_val,
                        config.sex,
                        avg_node_degree,
                        y_mean,
                        y_std,
                    ]
                ],
            ),
        }
    )

    # Define Lightning's Trainer that will be wrapped by Ray's TorchTrainer
    trainer = pl.Trainer(
        devices='auto',
        accelerator='auto',
        strategy=ray_lightning.RayDDPStrategy(),
        callbacks=[
            proteo_callbacks_ray.CustomRayWandbCallback(),
            proteo_callbacks_ray.CustomRayReportLossCallback(),
            TuneReportCheckpointCallback(
                metrics={"val_loss": "val_loss", "train_loss": "train_loss"},
                filename="checkpoint.ckpt",
                on="validation_end",
            ),
            # proteo_callbacks_ray.CustomRayCheckpointCallback(
            #     checkpoint_every_n_epochs=config.checkpoint_every_n_epochs,
            # ),
        ],
        # How Ray interacts with PyTorch Lightning
        plugins=[ray_lightning.RayLightningEnvironment()],
        enable_progress_bar=False,
        max_epochs=config.epochs,
        log_every_n_steps=config.log_every_n_steps,
        deterministic=True,
    )
    trainer = ray_lightning.prepare_trainer(trainer)

    # FIXME: When a trial errors, Wandb still shows it as "running".
    trainer.fit(module, train_loader, test_loader)
    time.sleep(5)  # Wait for wandb to finish logging
    wandb.finish()
Maybe the trial is finishing too quickly for the last logs to be sent to wandb? I never see "wandb: Waiting for W&B process to finish… (success)." printed in the terminal when I run with Ray, so perhaps something is wrong with the configuration.
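
To rule out the trial process exiting before the W&B background process has flushed everything (and to address the FIXME above about errored trials staying "running"), I am considering wrapping the fit call in try/finally so wandb.finish() always runs. A minimal standalone sketch of the pattern, with placeholder logging standing in for trainer.fit and an offline run so it can be executed anywhere (the function and project names are made up):

import time
import wandb

def run_trial_with_clean_finish():
    """Hypothetical sketch: guarantee wandb.finish() runs even if training raises."""
    run = wandb.init(project="proteo-example", mode="offline")  # placeholder settings
    try:
        # Stand-in for trainer.fit(module, train_loader, test_loader).
        for epoch in range(3):
            wandb.log({"val_loss": 1.0 / (epoch + 1), "epoch": epoch})
    finally:
        # finish() closes the run and waits for queued data to be flushed; in online
        # mode this is when "wandb: Waiting for W&B process to finish" is printed.
        time.sleep(5)  # grace period, mirroring the sleep in train_func above
        run.finish()

if __name__ == "__main__":
    run_trial_with_clean_finish()
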
EDIT: I have followed the suggestions in [Tune][wandb] Not all logs and artifacts get uploaded to wandb when a Tune experiment finishes · Issue #33129 · ray-project/ray · GitHub, but the issue still occurs.