Sweeps not showing val_loss with Keras

I’ve defined a sweep config:

sweep_config = {
    # Sweep Method
    "method": "random",

    # Metrics to track and optimise
    "metric": {
        "name" : "val_loss",
        "goal": "minimize"
    },
...

Then I wrote my training loop:

loss_fn = "mse"
batch_size=64
epochs=20
patience=4
min_delta=0
min_epoch=4

early_stop = train_utils.CustomStopper(
    monitor="val_loss",
    patience=patience,
    min_delta=min_delta,
    verbose=1,
    min_epoch=min_epoch,
)

def train(config=None):
    with wandb.init(config=config):
        config = wandb.config # sweep agent passes in a config
        wandb.config["training_params"] = training_params

        preproc_instance = build_preprocessing_instance(config.sequence_length, config.scaler)
    
        model = build_model(
            config.sequence_length,
            config.lstm_layers,
            config.lstm_neurons,
            config.lstm_activation,
            config.dense_layers,
            config.dense_neurons,
            config.dense_activation
        )
    
        x_train, y_train, x_test, y_test = build_dataset(model)

        rsquare = RSquare()
        
        model.compile(optimizer="Adam", loss=loss_fn, metrics=["mae", rsquare])
        train_history = model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(x_test, y_test),
            callbacks=[
                early_stop,
                WandbMetricsLogger(),
            ],
        )
        
    wandb.finish()
    
    return model, train_history

On my Runs page I can see that val_loss is logged, but on the Sweeps page it comes up as null, so none of the comparison plots render properly.

I figured I must be specifying the metric name incorrectly, but I also tried “mse” and that didn’t work either. How do I get the sweep to track the metric properly?

I’ve figured it out. The built-in WandbMetricsLogger logs epoch-level metrics under an epoch/ prefix, producing odd-looking names like epoch/epoch and epoch/val_loss, so there is never a plain val_loss key for the sweep to pick up.
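To make the mismatch concrete, here is a dependency-free sketch (no wandb or Keras required). It assumes the sweep looks up its metric by exact key in each run’s summary, and it only approximates the built-in logger’s naming from the keys seen above. The helper functions and metric values are hypothetical, made up for illustration. It also shows why “mse” didn’t help: Keras reports a loss passed as "mse" under the names loss and val_loss, not mse.

```python
# Dependency-free sketch: mimics the two naming schemes, no wandb/Keras needed.
# The "epoch/" scheme approximates the built-in logger based on the keys seen
# in this thread; all function names and values here are hypothetical.

def builtin_logger_keys(logs, epoch):
    """Approximate the built-in logger: each Keras metric gets an "epoch/" prefix."""
    row = {f"epoch/{name}": value for name, value in logs.items()}
    row["epoch/epoch"] = epoch  # the extra counter key observed in the thread
    return row


def custom_logger_keys(logs):
    """Mimic the custom callback: log the Keras `logs` dict unchanged."""
    return dict(logs)


def sweep_metric_value(summary, metric_name):
    """Return what a sweep would read for its metric, or None if the key is missing."""
    return summary.get(metric_name)


# Made-up `logs` dict shaped like what Keras passes to on_epoch_end for
# loss="mse", metrics=["mae"] with validation_data set.
logs = {"loss": 0.12, "mae": 0.25, "val_loss": 0.15, "val_mae": 0.28}

builtin_summary = builtin_logger_keys(logs, epoch=19)
custom_summary = custom_logger_keys(logs)

print(sweep_metric_value(builtin_summary, "val_loss"))        # None (shows up as null)
print(sweep_metric_value(builtin_summary, "epoch/val_loss"))  # 0.15
print(sweep_metric_value(custom_summary, "val_loss"))         # 0.15
print(sweep_metric_value(custom_summary, "mse"))              # None: Keras calls it "loss"
```

The lookup is deliberately an exact dictionary match, since a sweep metric name is a plain key rather than a pattern.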

So I wrote my own callback that logs the Keras metrics under their plain names, and now the sweep tracks val_loss properly:

import wandb
from tensorflow import keras

class WandbMetricsLogger(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        wandb.log(logs or {})  # plain keys: loss, val_loss, mae, ...
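An alternative that should also work, though it isn’t what I ended up doing: keep the built-in WandbMetricsLogger and point the sweep at the prefixed key it actually writes. This sketch assumes the built-in logger names validation loss epoch/val_loss, as observed above; only the "metric" entry differs from my original config, and the elided parts stay elided.

```python
# Sketch assuming the built-in WandbMetricsLogger's "epoch/" prefix (as observed
# in this thread). Only the "metric" entry changes from the original config.
sweep_config = {
    # Sweep Method
    "method": "random",

    # Metrics to track and optimise
    "metric": {
        # The built-in logger prefixes epoch-level metrics with "epoch/",
        # so validation loss is logged as "epoch/val_loss", not "val_loss".
        "name": "epoch/val_loss",
        "goal": "minimize",
    },
    # ... the rest of your config (e.g. "parameters") stays as it was
}

print(sweep_config["metric"]["name"])  # epoch/val_loss
```

The trade-off: the custom callback keeps the familiar Keras names everywhere, while this option needs no extra code but ties the sweep config to the built-in logger’s naming scheme.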