Hello WandB community,
I am having trouble with something that I believe should be pretty straightforward. Here is the link to the Dashboard
I am training on ten different splits of some data and logging some metrics, pretty standard stuff. Specifically, I am logging training and validation metrics. Whenever the validation Area Under the Curve (Val AUC) surpasses the previous highest value, I compute some metrics on the test set. After the training script finishes, I download all the information as .csv files.
The problem arises when comparing plots, because the x-axis (Step) is different for each one. For example, in the run Pleasant-elevator-24 the highest Val AUC happens at Step 530, but the logged value of Test auc is at 530.

What I would like is to have the x-axis Steps of all the plots synced.
The code that logs these metrics is the following:

```python
import os

import torch
import wandb

# Runs once per epoch inside the training loop; model, args, cur, metrics,
# test_loader and summary() are defined elsewhere in my script.
if max_auc < metrics[0] and epoch > 5:  # metrics[0] is the current Val AUC
    max_auc = metrics[0]
    torch.save(model.state_dict(), os.path.join(args.results_dir, "s_{}_{}_checkpoint.pt".format(round(metrics[0], 3), cur)))
    # By default, each wandb.log call below advances the W&B step by one
    wandb.log({"P-R curve": wandb.plot.pr_curve(metrics[1], metrics[2], labels=['Negative', 'Positive'])})
    results_dict, test_error, test_auc, acc_logger, metrics_test = summary(model, test_loader, args.n_classes)
    wandb.log({"P-R curve test": wandb.plot.pr_curve(metrics_test[1], metrics_test[2], labels=['Negative', 'Positive'])})
    wandb.log({'Test auc': metrics_test[0],
               'Test bal acc': metrics_test[3],
               'Test sensitivity': metrics_test[4],
               'Test specificity': metrics_test[5]})
```
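To check my understanding, here is a toy model of what I think is happening. It is plain Python, not the real wandb library: the `FakeRun` class is hypothetical and only mimics the default behavior as I read the W&B docs, namely that every `wandb.log` call without `commit=False` or an explicit `step` writes its keys at the current step and then advances the step by one. It also assumes Val AUC is logged in its own `wandb.log` call earlier in the same epoch (that part of my script is not shown above), and the P-R curve values are placeholders.

```python
# Toy model of W&B's default step counter (NOT the real wandb library).
# Assumption: every log() call writes its keys at the current step and then
# advances the step by one, like wandb.log without commit=False or step=.


class FakeRun:
    def __init__(self):
        self.step = 0
        self.history = {}  # metric name -> list of steps where it was logged

    def log(self, data):
        for key in data:
            self.history.setdefault(key, []).append(self.step)
        self.step += 1


def separate_calls(run):
    """One epoch with a new best Val AUC, logged with four separate calls."""
    run.log({"Val AUC": 0.9})
    run.log({"P-R curve": "placeholder"})
    run.log({"P-R curve test": "placeholder"})
    run.log({"Test auc": 0.88, "Test bal acc": 0.8})


def single_call(run):
    """Same epoch, but everything merged into one dict and logged once."""
    metrics = {"Val AUC": 0.9}
    metrics.update({"P-R curve": "placeholder", "P-R curve test": "placeholder"})
    metrics.update({"Test auc": 0.88, "Test bal acc": 0.8})
    run.log(metrics)


a, b = FakeRun(), FakeRun()
separate_calls(a)
single_call(b)
print(a.history["Val AUC"], a.history["Test auc"])  # [0] [3]
print(b.history["Val AUC"], b.history["Test auc"])  # [0] [0]
```

If this model is right, the fix for future runs would be to collect everything for an epoch into one dict and call `wandb.log` once, to pass `commit=False` to all but the last call of the epoch, or to log an `epoch` key and use it as the x-axis with `wandb.define_metric("Test auc", step_metric="epoch")`.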
Is there a way to solve this issue without re-running the experiments?
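One workaround I am considering for the runs I already have is to realign the exported .csv files myself: move every Test metric back to the most recent step at which Val AUC was logged, since in my script the test metrics are always computed right after a new best Val AUC. This is only a sketch under assumptions: the column names (`_step`, `Val AUC`, `Test auc`), the one-row-per-step layout with blank cells, and the numbers in `EXAMPLE_CSV` are all made up for illustration, and `realign_test_metrics` is a hypothetical helper, not a W&B function.

```python
import csv
import io

# Made-up export: one row per W&B step, blank cells where a metric was not
# logged at that step. Column names and values are illustrative assumptions.
EXAMPLE_CSV = """_step,Val AUC,Test auc
10,0.81,
11,,
12,0.86,
13,,
14,,
15,,0.84
"""


def realign_test_metrics(rows, anchor="Val AUC", test_prefix="Test"):
    """Move each Test* value back to the latest step where `anchor` was logged."""
    aligned = {}  # step -> {metric name: value}
    last_anchor_step = None
    for row in sorted(rows, key=lambda r: int(r["_step"])):
        step = int(row["_step"])
        if row.get(anchor):  # empty string means "not logged at this step"
            last_anchor_step = step
            aligned.setdefault(step, {})[anchor] = float(row[anchor])
        for key, value in row.items():
            if key.startswith(test_prefix) and value and last_anchor_step is not None:
                aligned.setdefault(last_anchor_step, {})[key] = float(value)
    return aligned


rows = list(csv.DictReader(io.StringIO(EXAMPLE_CSV)))
aligned = realign_test_metrics(rows)
print(aligned[12])  # {'Val AUC': 0.86, 'Test auc': 0.84}
```

This only helps with plots I make myself from the exported files; it does not change what the dashboard shows.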
Maybe I am getting something wrong here. I appreciate all the help!