使用Sklearn在Pytorch上进行的精确度、召回率和F1得分。
使用Sklearn在Pytorch上进行的精确度、召回率和F1得分。
我一直在浏览样本,但不知道如何将精确度(precision)、召回率(recall)和F1度量指标整合到我的模型中。我的代码如下:
for epoch in range(num_epochs): # 计算准确率(堆叠教程中没有n_total) n_correct = 0 n_total = 0 for i, (words, labels) in enumerate(train_loader): words = words.to(device) labels = labels.to(dtype=torch.long).to(device) # 前向传播 outputs = model(words) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() # 前馈教程解决方案 _, predicted = torch.max(outputs, 1) n_correct += (predicted == labels).sum().item() n_total += labels.shape[0] accuracy = 100 * n_correct/n_total # 推送到matplotlib train_losses.append(loss.item()) train_epochs.append(epoch) train_acc.append(accuracy) # 损失和准确率 if (epoch+1) % 10 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.2f}, Acc: {accuracy:.2f}')
问题背景:如何使用Sklearn和Pytorch计算Precision、Recall和F1 Score?
解决方法:
1. 在每个epoch开始时,初始化两个空列表,一个用于真实标签,一个用于预测标签。
2. 在epoch循环期间,将预测和标签分别添加到对应的列表中。
3. 在epoch结束时,使用Sklearn的precision_recall_fscore_support函数,以预测标签和真实标签作为输入,计算Precision、Recall和F1 Score。
注意事项:
1. 可能需要使用类似于flatten函数将列表进行展开。
2. 了解并使用torch.no_grad()函数,以在计算指标时避免梯度计算。
代码示例:
import torch from sklearn.metrics import precision_recall_fscore_support predicted_labels, ground_truth_labels = [], [] for epoch in range(num_epochs): ... _, predicted = torch.max(outputs, 1) predicted_labels.append(predicted.cpu().detach().numpy()) ground_truth_labels.append(labels.cpu().detach().numpy()) ... predicted_labels = [item for sublist in predicted_labels for item in sublist] ground_truth_labels = [item for sublist in ground_truth_labels for item in sublist] precision, recall, f1_score, _ = precision_recall_fscore_support(predicted_labels, ground_truth_labels) print("Precision:", precision) print("Recall:", recall) print("F1 Score:", f1_score)
参考链接:
- [Flattening a list of numpy arrays](https://stackoverflow.com/questions/33711985)
- [What is the use of torch.no_grad() in PyTorch?](https://datascience.stackexchange.com/questions/32651/what-is-the-use-of-torch-no-grad-in-pytorch)