Wandb with Ray and PyTorch Lightning not logging last 5 or so epochs

Hi!
I am currently using wandb with PyTorch Lightning and Ray. When I launch a Ray run, the last 5 or so epochs of my run are not logged to wandb, even though the terminal prints that they are being logged (when I look at the run logs on wandb, these print statements are cut off).

Here is my callback code:

import math

import torch
import torch.nn.functional as F
import wandb
from pytorch_lightning import Callback

# CONTINOUS_Y_VALS, BINARY_Y_VALS_MAP, and MULTICLASS_Y_VALS_MAP are project-level
# constants imported from elsewhere in my code (omitted here).


class CustomRayWandbCallback(Callback):
    """Callback that logs losses and plots to Wandb."""
    # FIXME: if loss is not the MSE (because loss has regularization, or loss=L1),
    # then sqrt(loss) is not the RMSE
    def on_after_backward(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            if param.requires_grad:
                gradients = param.grad.detach().cpu()
                wandb.log({"gradients": wandb.Histogram(gradients)})

    def on_train_epoch_end(self, trainer, pl_module):
        """Save train predictions, targets, and parameters as histograms.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            Unused. Here for API compliance.
        pl_module : Proteo LightningModule
            Lightning's module for training.
        """
        train_preds = torch.vstack(pl_module.train_preds).detach().cpu()
        train_targets = torch.vstack(pl_module.train_targets).detach().cpu()
        params = torch.concat([p.flatten() for p in pl_module.parameters()]).detach().cpu()
        train_loss = pl_module.trainer.callback_metrics["train_loss"]
        train_RMSE = math.sqrt(train_loss)
        # Log the first graph ([0, :]) of x0, x1, and x2 to see if oversmoothing is happening, aka if all features across 1 person are the same
        x0 = torch.vstack(pl_module.x0).detach().cpu()[0, :]
        x1 = torch.vstack(pl_module.x1).detach().cpu()[0, :]
        x2 = torch.vstack(pl_module.x2).detach().cpu()[0, :]
        multiscale = torch.vstack(pl_module.multiscale).detach().cpu()
        multiscale = (
            torch.norm(multiscale, dim=1).detach().cpu()
        )  # Take the norm of the features across the 3 layers per person to get one value per person
        wandb.log(
            {
                "train_loss": train_loss,
                "train_RMSE": train_RMSE,
                "train_preds": wandb.Histogram(train_preds),
                "train_targets": wandb.Histogram(train_targets),
                "parameters (weights+biases)": wandb.Histogram(params),
                "x0": wandb.Histogram(x0),
                "x1": wandb.Histogram(x1),
                "x2": wandb.Histogram(x2),
                "multiscale norm for all people": wandb.Histogram(multiscale),
                "epoch": pl_module.current_epoch,
            }
        )
        if pl_module.config.y_val in CONTINOUS_Y_VALS:
            if train_loss < pl_module.min_train_loss:
                pl_module.min_train_loss = train_loss
                scatter_plot_data = [
                    [pred, target] for (pred, target) in zip(train_preds, train_targets)
                ]
                table = wandb.Table(data=scatter_plot_data, columns=["pred", "target"])
                wandb.log(
                    {
                        f"Regression Scatter Plot Train": wandb.plot.scatter(
                            table, "pred", "target", title=f"Train Pred vs Train Target Scatter Plot"
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )
        elif pl_module.config.y_val in BINARY_Y_VALS_MAP:
            train_preds_sigmoid = torch.sigmoid(train_preds)
            predicted_classes = (train_preds_sigmoid > 0.5).int()
            train_accuracy = (predicted_classes == train_targets).float().mean().item()
            # Convert tensors to numpy arrays and ensure they are integers
            train_targets_np = train_targets.numpy().astype(int).flatten()
            predicted_classes_np = predicted_classes.numpy().astype(int).flatten()
            wandb.log(
                {
                    "train_preds_sigmoid": train_preds_sigmoid,
                    "train_accuracy": train_accuracy,
                    "conf_mat train": wandb.plot.confusion_matrix(
                        probs=None,
                        y_true=train_targets_np,
                        preds=predicted_classes_np,
                        class_names=['Control (nfl), 0 (cdr)', 'Carrier (nfl), >0 (cdr)'],
                    ),  # TODO: hacky
                    "epoch": pl_module.current_epoch,
                }
            )
        elif pl_module.config.y_val in MULTICLASS_Y_VALS_MAP:
            softmax_probs = F.softmax(train_preds, dim=1)
            class_preds = torch.argmax(softmax_probs, dim=1)
            train_accuracy = (class_preds == train_targets).float().mean().item()
            class_preds_np = class_preds.numpy().astype(int).flatten()
            train_targets_np = train_targets.numpy().astype(int).flatten()
            wandb.log(
                {
                    "val_preds_softmax": wandb.Histogram(softmax_probs),
                    "val_preds_class": wandb.Histogram(class_preds),
                    "train_accuracy": train_accuracy,
                    "conf_matrix train": wandb.plot.confusion_matrix(
                        probs=None,
                        y_true=train_targets_np,
                        preds=class_preds_np,
                        class_names=['0', '0.5', '1', '2', '3'],  # TODO: Hardcoded
                    ),
                    "epoch": pl_module.current_epoch,
                }
            )
        pl_module.train_preds.clear()  # free memory
        pl_module.train_targets.clear()
        pl_module.x0.clear()
        pl_module.x1.clear()
        pl_module.x2.clear()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save val predictions and targets as histograms and log confusion matrix.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            Lightning's trainer object.
        pl_module : Proteo LightningModule
            Lightning's module for training.
        """
        if not trainer.sanity_checking:
            val_preds = torch.vstack(pl_module.val_preds).detach().cpu()
            val_targets = torch.vstack(pl_module.val_targets).detach().cpu()
            val_loss = pl_module.trainer.callback_metrics["val_loss"]

            print(f"wandb logging at epoch {pl_module.current_epoch}")
            # Log histograms and metrics
            wandb.log(
                {
                    "val_loss": val_loss,
                    "val_preds": wandb.Histogram(val_preds),
                    "val_targets": wandb.Histogram(val_targets),
                    "epoch": pl_module.current_epoch,
                }
            )
            if pl_module.config.y_val in CONTINOUS_Y_VALS:
                if val_loss < pl_module.min_val_loss:
                    pl_module.min_val_loss = val_loss
                    print("min_loss_val =", pl_module.min_val_loss)
                    print("val_loss =", val_loss)
                    print("epoch =", pl_module.current_epoch)
                    print("val_preds =", val_preds)
                    scatter_plot_data = [
                        [pred, target] for (pred, target) in zip(val_preds, val_targets)
                    ]
                    table = wandb.Table(data=scatter_plot_data, columns=["pred", "target"])
                    wandb.log(
                        {
                            "Regression Scatter Plot Val": wandb.plot.scatter(
                                table, "pred", "target", title=f"Val Pred vs Val Target Scatter Plot"
                            ),
                            "epoch": pl_module.current_epoch,
                        }
                    )
            elif pl_module.config.y_val in BINARY_Y_VALS_MAP:
                val_preds_sigmoid = torch.sigmoid(val_preds)
                # Note this assumes binary classification
                predicted_classes = (val_preds_sigmoid > 0.5).int()
                val_accuracy = (predicted_classes == val_targets).float().mean().item()
                # Convert tensors to numpy arrays and ensure they are integers
                val_targets_np = val_targets.numpy().astype(int).flatten()
                predicted_classes_np = predicted_classes.numpy().astype(int).flatten()

                wandb.log(
                    {
                        "val_preds_sigmoid": wandb.Histogram(val_preds_sigmoid),
                        "val_accuracy": val_accuracy,
                        "conf_mat val": wandb.plot.confusion_matrix(
                            probs=None,
                            y_true=val_targets_np,
                            preds=predicted_classes_np,
                            class_names=[
                                'Control (nfl), 0 (cdr)',
                                'Carrier (nfl), >0 (cdr)',
                            ],  # TODO: hacky
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )
            elif pl_module.config.y_val in MULTICLASS_Y_VALS_MAP:
                softmax_probs = F.softmax(val_preds, dim=1)
                class_preds = torch.argmax(softmax_probs, dim=1)
                val_accuracy = (class_preds == val_targets).float().mean().item()
                class_preds_np = class_preds.numpy().astype(int).flatten()
                val_targets_np = val_targets.numpy().astype(int).flatten()
                wandb.log(
                    {
                        "val_preds_softmax": wandb.Histogram(softmax_probs),
                        "val_preds_class": wandb.Histogram(class_preds),
                        "val_accuracy": val_accuracy,
                        "conf_matrix val": wandb.plot.confusion_matrix(
                            probs=None,
                            y_true=val_targets_np,
                            preds=class_preds_np,
                            class_names=['0', '0.5', '1', '2', '3'],  # TODO: Hardcoded
                        ),
                        "epoch": pl_module.current_epoch,
                    }
                )

        pl_module.val_preds.clear()  # free memory
        pl_module.val_targets.clear()
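
One thing I have not tried yet, but that might be relevant if wandb's internal step counter is part of the problem, is tying every logged key to an explicit epoch step metric. A minimal sketch, assuming the run has already been initialized (e.g. by setup_wandb):

import wandb

# Sketch only (not in my current code): make "epoch" the step metric for all keys,
# so entries are ordered by the epoch value I log rather than wandb's internal step.
wandb.define_metric("epoch")
wandb.define_metric("*", step_metric="epoch")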

And here is my wandb setup:

import os
import time

import pytorch_lightning as pl
import torch
import wandb
from ray.air.integrations.wandb import setup_wandb
from ray.train import lightning as ray_lightning  # RayDDPStrategy, RayLightningEnvironment, prepare_trainer

# Imports of project modules (proteo_train, proteo_callbacks_ray, read_config_from_file,
# CONFIG_FILE, the *_Y_VALS* constants) and of TuneReportCheckpointCallback are omitted here.


def train_func(train_loop_config):
    """Train one neural network with Lightning.

    Configure Lightning with Ray.

    Parameters
    ----------
    train_loop_config : dict
        Configuration parameters for this trial, supplied by the Ray Tune sweep.
    """
    torch.set_float32_matmul_precision('medium')  # for performance
    config = read_config_from_file(CONFIG_FILE)
    pl.seed_everything(config.seed)

    # Update default config with sweep-specific train_loop_config
    # Update model specific parameters: hacky - is there a better way?
    model = train_loop_config['model']
    train_loop_config_model = {
        'hidden_channels': train_loop_config['hidden_channels'],
        'heads': train_loop_config['heads'],
        'num_layers': train_loop_config['num_layers'],
    }
    config[model].update(train_loop_config_model)
    # Remove keys that were already updated in nested configuration
    for key in train_loop_config_model:
        train_loop_config.pop(key)
    config.update(train_loop_config)

    setup_wandb(  # wandb.init, but for ray
        config.dict(),  # Transform Config object into dict for wandb
        project=config.project,
        api_key_file=os.path.join(config.root_dir, config.wandb_api_key_path),
        # The directory passed to dir needs to exist; otherwise wandb saves to /tmp
        dir=os.path.join(config.root_dir, config.output_dir),
        mode="offline" if config.wandb_offline else "online",
    )

    train_dataset, test_dataset = proteo_train.construct_datasets(config)
    train_loader, test_loader = proteo_train.construct_loaders(config, train_dataset, test_dataset)

    avg_node_degree = proteo_train.compute_avg_node_degree(test_dataset)
    pos_weight = 1.0  # default value
    focal_loss_weight = [1.0]  # default value
    if config.y_val in BINARY_Y_VALS_MAP:
        pos_weight = proteo_train.compute_pos_weight(test_dataset, train_dataset)
    elif config.y_val in MULTICLASS_Y_VALS_MAP:
        focal_loss_weight = proteo_train.compute_focal_loss_weight(
            config, test_dataset, train_dataset
        )
    y_mean, y_std = proteo_train.compute_mean_std(config, test_dataset, train_dataset)
    # For wandb logging top proteins
    protein_file_data = proteo_train.read_protein_file(train_dataset.processed_dir, config)
    protein_names = protein_file_data['Protein']
    metrics = protein_file_data['Metric']
    top_proteins_data = [[protein, metric] for protein, metric in zip(protein_names, metrics)]

    module = proteo_train.Proteo(
        config,
        in_channels=train_dataset.feature_dim,  # 1 dim of input
        out_channels=train_dataset.label_dim,  # 1 dim of result
        avg_node_degree=avg_node_degree,
        pos_weight=pos_weight,
        focal_loss_weight=focal_loss_weight,
    )
    if config.y_val in Y_VALS_TO_NORMALIZE:
        wandb.log(
            {
                "histogram original": wandb.Image(
                    os.path.join(
                        train_dataset.processed_dir,
                        f'{config.y_val}_{config.sex}_{config.mutation}_{config.modality}_orig_histogram.jpg',
                    )
                )
            }
        )
    wandb.log(
        {
            "histogram": wandb.Image(
                os.path.join(
                    train_dataset.processed_dir,
                    f'{config.y_val}_{config.sex}_{config.mutation}_{config.modality}_histogram.jpg',
                )
            ),
            "adjacency": wandb.Image(
                os.path.join(
                    train_dataset.processed_dir,
                    f"adjacency_{config.adj_thresh}_num_nodes_{config.num_nodes}_mutation_{config.mutation}_{config.modality}_sex_{config.sex}.jpg",
                )
            ),
            "top_proteins": wandb.Table(
                columns=["Protein", "Metric"], data=top_proteins_data
            ),  # note this is in order from most to least different
            "parameters": wandb.Table(
                columns=["Medium", "Mutation", "Target", "Sex", "Avg Node Degree", "Mean", "Std"],
                data=[
                    [
                        config.modality,
                        config.mutation,
                        config.y_val,
                        config.sex,
                        avg_node_degree,
                        y_mean,
                        y_std,
                    ]
                ],
            ),
        }
    )

    # Define Lightning's Trainer that will be wrapped by Ray's TorchTrainer
    trainer = pl.Trainer(
        devices='auto',
        accelerator='auto',
        strategy=ray_lightning.RayDDPStrategy(),
        callbacks=[
            proteo_callbacks_ray.CustomRayWandbCallback(),
            proteo_callbacks_ray.CustomRayReportLossCallback(),
            TuneReportCheckpointCallback(
                metrics={"val_loss": "val_loss", "train_loss": "train_loss"},
                filename=f"checkpoint.ckpt",
                on="validation_end",
            ),
            # proteo_callbacks_ray.CustomRayCheckpointCallback(
            #    checkpoint_every_n_epochs=config.checkpoint_every_n_epochs,
            # ),
        ],
        # How ray interacts with pytorch lightning
        plugins=[ray_lightning.RayLightningEnvironment()],
        enable_progress_bar=False,
        max_epochs=config.epochs,
        log_every_n_steps=config.log_every_n_steps,
        deterministic=True,
    )
    trainer = ray_lightning.prepare_trainer(trainer)
    # FIXME: When a trial errors, Wandb still shows it as "running".
    trainer.fit(module, train_loader, test_loader)
    time.sleep(5)  # Wait for wandb to finish logging
    wandb.finish()
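
For completeness, here is a variant of the end of train_func that I have been considering (a sketch, not what I am currently running) to make sure wandb.finish() always runs, even when a trial errors:

    # Sketch only: close the wandb run even if trainer.fit raises, so the run is
    # marked finished instead of staying "running" (see the FIXME above).
    try:
        trainer.fit(module, train_loader, test_loader)
    finally:
        time.sleep(5)  # give the wandb writer a moment to flush
        wandb.finish()  # waits for pending data to be synced in online mode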

Maybe the script finishes too quickly for the last logs to be sent to wandb? I never see “wandb: Waiting for W&B process to finish… (success).” printed in the terminal when I run with Ray, so perhaps something is wrong with the configuration.
EDIT: I have followed the suggestions here: [Tune][wandb] Not all logs and artifacts get uploaded to wandb when a Tune experiment finishes · Issue #33129 · ray-project/ray · GitHub but the issue is still occurring.

Hi @louisacornelis Good day and thank you for reaching out to us. Happy to help you with this!

I just have a few questions. Does this mean that if your script logs 10 epochs, only 5 of them reach wandb? May I also know your current SDK version? You can get it by running wandb --version. Thank you!

Hi @paulo-sabile! Yes, this is exactly what this means. My SDK version is wandb, version 0.15.12.

Thanks so much,
Louisa

Hi @louisacornelis Good day! I have reviewed this with my team, and we’d like to ask if you can share the current version of Ray that you are using?

Hello! It is ray, version 2.34.0.

Thank you @louisacornelis Would it be okay if you could try using Ray 2.4.0 or 2.5.0 and check whether the issue exists on those versions?

Additionally, can you also try adding a sleep at the end of the training loop + lowering # of checkpoints to upload?
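
As a rough sketch (you will need to adapt it to your own setup), the number of checkpoints that are kept can be limited through Ray's CheckpointConfig, for example:

# Rough sketch: keep only the most recent checkpoints so fewer artifacts compete
# with metric uploads at the end of each trial.
from ray.train import CheckpointConfig, RunConfig

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(num_to_keep=2),  # keep the last 2 checkpoints
)
# then pass run_config to your TorchTrainer(..., run_config=run_config)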

Hi @paulo-sabile. I lowered the # of checkpoints and added a sleep at the end of the training loop. Unfortunately I cannot use those Ray versions, since downgrading would require me to refactor my code. Are there any other potential solutions?
Thanks!

Hi @louisacornelis Could you check in the logs what’s printed for Checkpoint? Is the filesystem set to s3 or local?

Hi @paulo-sabile. The filesystem is local. Here is a complete log in wandb from one run:

 16 GPU available: True (cuda), used: True
 17 TPU available: False, using: 0 TPU cores
 18 HPU available: False, using: 0 HPUs
 19 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
 20   | Name  | Type  | Params | Mode
 21 ----------------------------------------
 22 0 | model | GATv4 | 502 K  | train
 23 ----------------------------------------
 24 502 K     Trainable params
 25 0         Non-trainable params
 26 502 K     Total params
 27 2.009     Total estimated model params size (MB)
 28 27        Modules in train mode
 29 0         Modules in eval mode
 30 Metric train_loss does not exist in `trainer.callback_metrics.
 31 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000000)
 32 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000001)
 33 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000002)
 34 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000003)
 35 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000004)
 36 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000005)
...
318 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000286)
319 Checkpoint successfully created at: Checkpoint(filesystem=local, path=/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-07_20-28-00/model=gat-v4,seed=29468_29_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.1000,l1_lambda=0.0021,lr=0.0016,lr_scheduler=LambdaLR_2024-08-07_20-28-01/checkpoint_000287)

I cut out the middle checkpoints but that is how the logs start and end. Let me know if you mean other logs. Thanks!

Thanks for sharing this @louisacornelis

To further investigate the issue, it would help if you could run a demo snippet with WANDB_DEBUG=True and send us the debug logs, as that gives us more verbose output.

Once this is enabled, could you please share the debug-internal.log and debug.log for the affected run? These files are under the local folder wandb/run-<date>_<time>-<run-id>/logs in the directory where you’re running your code. We can review whether the logs point to why this issue occurs.
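
For example, assuming you launch training from a Python script, the flag could be set before wandb is initialized:

# Example only: enable verbose wandb debug logging before the run starts.
import os

os.environ["WANDB_DEBUG"] = "true"  # must be set before wandb.init / setup_wandb

# If the training function runs on remote Ray workers, the variable may also need to be
# forwarded to them, e.g. via ray.init(runtime_env={"env_vars": {"WANDB_DEBUG": "true"}}).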

Hi Paulo!
I have attached my logs. Thanks so much.
Louisa
debug-internal.log
debug.log

Hi @louisacornelis Hope you have been well. Back in July, you reported an issue with Ray and PyTorch Lightning not logging the last 5 or so epochs.

Our team has been investigating this but hasn’t been able to reproduce the problem. I just wanted to check in and see whether the issue is still occurring for you and your team, especially if you can try logging new experiments using the latest version of wandb. Thank you!

We’ll close this for now, but please feel free to reach out if you’d like to re-open it and we’ll review further. Thanks!