Plotting confusion matrix

I’m training a model and trying to add a confusion matrix that would be displayed in my wandb workspace, but I got a bit lost. The matrix itself works and I can print it, but it doesn’t show up in wandb. Everything should be OK, except it isn’t. Can you please help me? I’m new to all this. Thanks a lot!

nb_classes = 7

confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)

    for t, p in zip(classes.view(-1), preds.view(-1)):
        confusion_matrix[t.long(), p.long()] += 1
wandb.log({'matrix' : confusion_matrix})

Hi @diakonua, please share a link to your workspace so we can investigate why the matrix is not loading. Thank you.

Hi, wandb is not to blame, it’s definitely my fault. In my wandb workspace I don’t see the plot at all, yet the code prints the matrix. I then thought of using wandb’s own confusion matrix, but I can’t define it properly. Could you please help with that? Sorry for such a basic problem and for not figuring out the code on my own, but I’m really lost here and new to all of this.

wandb’s code

confusion_matrix = wandb.plot.confusion_matrix(
    y_true=ground_truth,
    preds=predictions,
    class_names=class_names)
    
wandb.log({"confusion_matrix": confusion_matrix })

my piece of code

def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  
            else:
                model.eval()   

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                from sklearn.metrics import f1_score
                f1_score = f1_score(labels.cpu().data, preds.cpu(), average=None)
                       
                wandb.log({'F1 score' : f1_score})

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            wandb.log({'epoch loss': epoch_loss,
                    'epoch acc': epoch_acc})
            
            data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
            table = wandb.Table(data=data, columns=["step", "height"])
            wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})

        
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, f1_score))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
        print()
    
    nb_classes = 2

    confusion_matrix = torch.zeros(nb_classes, nb_classes)
    with torch.no_grad():
          for i, (inputs, classes) in enumerate(dataloaders['val']):
              inputs = inputs.to(device)
              classes = classes.to(device)
              outputs = model_ft(inputs)
              _, preds = torch.max(outputs, 1)
                    
          for t, p in zip(classes.view(-1), preds.view(-1)):
              confusion_matrix[t.long(), p.long()] += 1
                        
    sns.heatmap(confusion_matrix, annot=True)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('f1_score: {}'.format(f1_score))
   
    model.load_state_dict(best_model_wts)
    return model

Hi @diakonua, from your code block above, it’s not clear where you log your confusion matrix to wandb. I see that you’ve logged the F1 score, epoch loss, and line-plot1, but the seaborn heatmap at the end is never passed to wandb.log. Below are three examples of logging a confusion matrix to wandb.

from sklearn.metrics import confusion_matrix
import plotly.express as px
import seaborn as sns
import wandb

# Seaborn: heatmap of a confusion matrix computed with scikit-learn
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cfm = confusion_matrix(y_true, y_pred)
cfm1 = sns.heatmap(cfm, annot=True)

# Plotly: heatmap of an arbitrary example matrix
z = [[.1, .3, .5, .7, .9],
     [1, .8, .6, .4, .2],
     [.2, 0, .5, .7, .9],
     [.9, .8, .4, .2, 0],
     [.3, .4, .5, .7, 1]]

cfm2 = px.imshow(z, text_auto=True)

# wandb built-in confusion matrix chart, built from the labels and predictions above
class_names = ["zero","one","two"]
cfm3 = wandb.plot.confusion_matrix(
    y_true=y_true,
    preds=y_pred,
    class_names=class_names)

wandb.init(entity="<entity>", project="<project-name>")
wandb.log({"cfm1": cfm1})
wandb.log({"cfm2": cfm2})
wandb.log({"cfm3": cfm3})
wandb.finish()
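
One more thing from your training function: the loop that fills the matrix sits outside the batch loop, so it only sees the last validation batch. Here is a minimal sketch of collecting labels and predictions across all batches and passing them to wandb.plot.confusion_matrix. The helper names (flatten_batches, count_matrix, log_confusion_matrix) and the toy batches are made up for illustration, and the sketch assumes wandb.init has already been called before logging.

```python
def flatten_batches(batches):
    """Concatenate per-batch (labels, predictions) pairs into two flat lists."""
    y_true, y_pred = [], []
    for labels, preds in batches:
        # int() also accepts 0-d torch tensors, so 1-D tensors work here too.
        y_true.extend(int(t) for t in labels)
        y_pred.extend(int(p) for p in preds)
    return y_true, y_pred


def count_matrix(y_true, y_pred, num_classes):
    """Count matrix: rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def log_confusion_matrix(y_true, y_pred, class_names, key="confusion_matrix"):
    """Log the built-in chart; assumes an active run (wandb.init was called)."""
    import wandb  # imported here so the helpers above run without wandb

    chart = wandb.plot.confusion_matrix(
        y_true=y_true, preds=y_pred, class_names=class_names
    )
    wandb.log({key: chart})


# Two toy validation batches of (labels, predictions), for illustration only.
batches = [([0, 1, 1], [0, 1, 0]), ([1, 0], [1, 0])]
y_true, y_pred = flatten_batches(batches)
matrix = count_matrix(y_true, y_pred, num_classes=2)
print(matrix)
```

In your validation loop, you would append (classes.view(-1).cpu(), preds.view(-1).cpu()) to batches inside the for loop, right after computing preds, and call log_confusion_matrix once the loop has finished.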

There are additional examples in our documentation. Hope this helps, and please let us know if you have additional questions.
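
If you would rather keep the count matrix you already build with torch.zeros, passing the raw tensor to wandb.log will not, as far as I know, produce a chart, because wandb treats it as numeric data. One option is to turn it into a wandb.Table of (actual, predicted, count) rows. This is only a sketch: matrix_to_rows and log_matrix_table are hypothetical helper names, the class names are placeholders, and it assumes an active run when logging.

```python
def matrix_to_rows(matrix, class_names):
    """Turn a square count matrix into [actual, predicted, count] rows."""
    rows = []
    for i, actual in enumerate(class_names):
        for j, predicted in enumerate(class_names):
            # int() also converts a 0-d torch tensor entry to a plain number.
            rows.append([actual, predicted, int(matrix[i][j])])
    return rows


def log_matrix_table(matrix, class_names, key="confusion_table"):
    """Log the rows as a table; assumes wandb.init has already been called."""
    import wandb  # imported here so matrix_to_rows runs without wandb

    table = wandb.Table(
        columns=["actual", "predicted", "count"],
        data=matrix_to_rows(matrix, class_names),
    )
    wandb.log({key: table})


# Placeholder 2-class matrix and class names, for illustration only.
rows = matrix_to_rows([[2, 0], [1, 2]], ["cat", "dog"])
print(rows)
```

The logged table can then be inspected in the workspace like any other logged table.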