使用Sklearn在Pytorch上进行的精确度、召回率和F1得分。

6 浏览
0 Comments

使用Sklearn在Pytorch上进行的精确度、召回率和F1得分。

我一直在浏览样本,但不知道如何将精确度(precision)、召回率(recall)和F1度量指标整合到我的模型中。我的代码如下:

for epoch in range(num_epochs):
    # 计算准确率(堆叠教程中没有n_total)
    n_correct = 0
    n_total = 0
    for i, (words, labels) in enumerate(train_loader):
        words = words.to(device)
        labels = labels.to(dtype=torch.long).to(device)
        # 前向传播
        outputs = model(words)
        loss = criterion(outputs, labels)
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 前馈教程解决方案
        _, predicted = torch.max(outputs, 1)
        n_correct += (predicted == labels).sum().item()
        n_total += labels.shape[0]
    accuracy = 100 * n_correct/n_total
    # 推送到matplotlib
    train_losses.append(loss.item())
    train_epochs.append(epoch)
    train_acc.append(accuracy)
    # 损失和准确率
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.2f}, Acc: {accuracy:.2f}')

0
0 Comments

问题背景:如何使用Sklearn和Pytorch计算Precision、Recall和F1 Score?

解决方法:

1. 在每个epoch开始时,初始化两个空列表,一个用于真实标签,一个用于预测标签。

2. 在epoch循环期间,将预测和标签分别添加到对应的列表中。

3. 在epoch结束时,使用Sklearn的precision_recall_fscore_support函数,以预测标签和真实标签作为输入,计算Precision、Recall和F1 Score。

注意事项:

1. 可能需要使用类似于flatten函数将列表进行展开。

2. 了解并使用torch.no_grad()函数,以在计算指标时避免梯度计算。

代码示例:

import torch
from sklearn.metrics import precision_recall_fscore_support
predicted_labels, ground_truth_labels = [], []
for epoch in range(num_epochs):
    ...
    _, predicted = torch.max(outputs, 1)
    predicted_labels.append(predicted.cpu().detach().numpy())
    ground_truth_labels.append(labels.cpu().detach().numpy())
    ...
predicted_labels = [item for sublist in predicted_labels for item in sublist]
ground_truth_labels = [item for sublist in ground_truth_labels for item in sublist]
precision, recall, f1_score, _ = precision_recall_fscore_support(predicted_labels, ground_truth_labels)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1_score)

参考链接:

- [Flattening a list of numpy arrays](https://stackoverflow.com/questions/33711985)

- [What is the use of torch.no_grad() in PyTorch?](https://datascience.stackexchange.com/questions/32651/what-is-the-use-of-torch-no-grad-in-pytorch)

0