The snippet:
# Evaluate on the test set. nn.test returns a 3-tuple; only the third
# element (the predictions) is used here.
_, _, y_pred = nn.test(test_img, test_labe)
# NOTE(review): `cm` is not used within this snippet; presumably it is
# consumed later in the script -- confirm, or drop it.
cm = confusion_matrix(test_labe, y_pred)

# Log confusion matrix to WandB.
# str() lets the check work whether the flag was parsed as a string
# ("True"/"false") or as a bool; calling .lower() on a bool would raise
# AttributeError.
if str(args.use_wandb).lower() == "true":
    # NOTE(review): both calls below log a confusion-matrix chart for the
    # same predictions under different keys; consider keeping only one.
    wandb.sklearn.plot_confusion_matrix(test_labe, y_pred, labels=the_labels)
    wandb.log({
        "confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=test_labe,
            preds=y_pred,
            class_names=the_labels,
        ),
    })
    wandb.finish()