I’m training a model and trying to add a confusion matrix that should be displayed in my wandb (Weights & Biases) dashboard, but I got a bit lost. The matrix itself works and I can print it, but it never shows up in wandb. As far as I can tell everything should be fine, but it isn’t. Can you please help me? I’m new to all this. Thanks a lot!
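import torch
import wandb

nb_classes = 7
confusion_matrix = torch.zeros(nb_classes, nb_classes)
model_ft.eval()  # evaluation mode: disables dropout, uses running batch-norm stats
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)
        # rows = true class, columns = predicted class
        for t, p in zip(classes.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

# Logged once, after the whole validation pass
wandb.log({'matrix': confusion_matrix})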