Statistical Learning: Model Evaluation and Selection -- Precision and Recall (Python Code)
2021/12/20 14:20:53
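Model performance can be measured in several ways. The most common metric for a classifier is accuracy: for a given test set, the ratio of the number of correctly classified samples to the total number of samples.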
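The definition of accuracy can be checked by hand. The sketch below is only an illustration: the helper name `accuracy` and the imbalanced toy labels are assumptions made up for this example. It also shows that a classifier which never predicts the positive class can still score a high accuracy, which is one reason per-class metrics are useful.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels (hypothetical helper)."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)

# Made-up imbalanced test set: 95 negatives (0) and 5 positives (1)
y_true = [0] * 95 + [1] * 5
# A trivial "classifier" that always predicts the negative class
y_pred = [0] * 100

acc = accuracy(y_true, y_pred)
print(acc)  # high accuracy even though no positive sample is found
```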
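For binary classification, the usual metrics are precision and recall. The class of interest is treated as the positive class and all other classes as the negative class. Each prediction on the test set is either correct or incorrect, which gives four possible outcomes:
TP ---- a positive sample predicted as positive (true positive)
FN ---- a positive sample predicted as negative (false negative)
FP ---- a negative sample predicted as positive (false positive)
TN ---- a negative sample predicted as negative (true negative)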
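The four outcomes can be counted directly from a pair of label lists. This is a minimal sketch: the function name `count_outcomes` and the toy labels are assumptions introduced for illustration, not part of the original code.

```python
def count_outcomes(y_true, y_pred, positive=1):
    """Count TP, FN, FP, TN for a binary problem (hypothetical helper)."""
    tp = fn = fp = tn = 0
    for t, p in zip(y_true, y_pred):
        if t == positive and p == positive:
            tp += 1   # positive predicted as positive
        elif t == positive:
            fn += 1   # positive predicted as negative
        elif p == positive:
            fp += 1   # negative predicted as positive
        else:
            tn += 1   # negative predicted as negative
    return tp, fn, fp, tn

# Toy labels, made up for illustration
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 0, 0, 1]

tp, fn, fp, tn = count_outcomes(y_true, y_pred)
print(tp, fn, fp, tn)
```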
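Precision is defined as:
P = TP / (TP + FP)
That is, among all samples predicted as positive, the fraction whose true class is positive.
Recall is defined as:
R = TP / (TP + FN)
That is, among all truly positive samples (the denominator), the fraction that were predicted as positive.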
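Both definitions translate into a few lines of code. In the sketch below the helper name `precision_recall` and the counts are made up for illustration, and returning 0.0 when a denominator is zero is an assumed convention, not something the definitions themselves prescribe.

```python
def precision_recall(tp, fp, fn):
    """Precision and recall from raw counts (hypothetical helper).

    Assumption: a metric with a zero denominator is reported as 0.0.
    """
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall

# Made-up counts: 30 true positives, 10 false positives, 20 false negatives
p, r = precision_recall(tp=30, fp=10, fn=20)
print(p, r)
```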
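Precision and recall often pull in opposite directions, so a combined metric, the F-score, is used as well.
F1 is the F-score with beta = 1, that is, the harmonic mean of precision and recall:
F1 = 2TP / (2TP + FP + FN)
The higher the F1, the better the model balances precision and recall.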
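The general F-score is F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). The sketch below, with a hypothetical helper `f_beta` and made-up counts, checks that beta = 1 reproduces the count-based F1 formula and that beta = 2 shifts the score toward recall.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score from precision and recall (hypothetical helper).

    Assumption: the score is reported as 0.0 when the denominator is zero.
    """
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom > 0 else 0.0

# Made-up counts for illustration
tp, fp, fn = 30, 10, 20
p = tp / (tp + fp)   # precision
r = tp / (tp + fn)   # recall

f1_from_pr = f_beta(p, r)                       # harmonic mean of P and R
f1_from_counts = 2 * tp / (2 * tp + fp + fn)    # count-based F1
f2 = f_beta(p, r, beta=2)                       # weights recall more heavily
print(f1_from_pr, f1_from_counts, f2)
```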
Build a logistic regression classifier and evaluate it with the metrics above:
from sklearn import datasets
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import numpy as np
import matplotlib.pyplot as plt

# Load the handwritten digits dataset (10 classes)
mnist = datasets.load_digits()
# Standardize the features
mnist.data = StandardScaler().fit_transform(mnist.data)
# Hold out 30% of the samples as the test set
X_train, X_test, y_train, y_test = train_test_split(
    mnist.data, mnist.target, test_size=0.3, random_state=0)
# Logistic regression: create the estimator and fit it on the training data
model = LogisticRegression().fit(X_train, y_train)
y_pre = model.predict(X_test)
# Accuracy
acc = accuracy_score(y_test, y_pre)
# Macro-averaged precision
macro = metrics.precision_score(y_test, y_pre, average="macro")
# Micro-averaged precision
micro = metrics.precision_score(y_test, y_pre, average="micro")
# Macro-averaged F1
f1 = metrics.f1_score(y_test, y_pre, average="macro")
# Support-weighted F1
f1_weight = metrics.f1_score(y_test, y_pre, average="weighted")
# F-beta (with beta=1 this equals the macro F1)
fbeta = metrics.fbeta_score(y_test, y_pre, average="macro", beta=1)
print(acc, macro, micro, f1, f1_weight, fbeta)
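The difference between the `average="macro"` and `average="micro"` options used above is easier to see on a tiny example. The sketch below uses made-up three-class labels: macro averaging takes the unweighted mean of the per-class precisions, while micro averaging pools all decisions, which for single-label multiclass data gives the same value as accuracy.

```python
from sklearn.metrics import accuracy_score, precision_score

# Toy three-class labels, made up for illustration
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 0]

per_class = precision_score(y_true, y_pred, average=None)   # one value per class
macro = precision_score(y_true, y_pred, average="macro")    # mean of per-class values
micro = precision_score(y_true, y_pred, average="micro")    # pooled over all samples
acc = accuracy_score(y_true, y_pred)

print(per_class, macro, micro, acc)
```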
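Plot the confusion matrix to see at a glance how the predictions are distributed against the true labels:
Official documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html#sklearn.metrics.confusion_matrix
Note: plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in 1.2; use ConfusionMatrixDisplay instead.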
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

# Plot the confusion matrix.
# The older helper below was removed in scikit-learn 1.2 and can no longer be imported:
#   plot_confusion_matrix(model, X_test, y_test)
#   plt.show()
cm = confusion_matrix(y_test, y_pre, labels=model.classes_)
disp = ConfusionMatrixDisplay(cm, display_labels=model.classes_)
disp.plot()
plt.show()
# Print the confusion matrix
print(cm)
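The matrix is 10 x 10, with one row per true class and one column per predicted class.
In the confusion matrix printed above, the numbers on the diagonal are the correctly predicted samples. For a given class, the off-diagonal entries in its row are samples of that class predicted as another class (FN), and the off-diagonal entries in its column are samples of other classes predicted as that class (FP).
Precision P can therefore be computed from the confusion matrix:
P = TP / (TP + FP)
The sum of each column is the total number of samples predicted as that class (TP + FP), and the diagonal entry is TP (positives predicted as positive).
Correspondingly, when the denominator is the row sum (TP + FN), the result is the recall R.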
# Precision per class: diagonal divided by the column sums
precision = np.diag(cm) / np.sum(cm, axis=0)
# Recall per class: diagonal divided by the row sums
recall = np.diag(cm) / np.sum(cm, axis=1)
# F1 per class
f1_score = 2*precision*recall / (precision+recall)
print("precision: \n", precision, "\nrecall: \n", recall)
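The column-sum and row-sum rule can be cross-checked against scikit-learn on a small example. The labels below are made up for illustration; the check relies on `confusion_matrix` placing true classes on the rows and predicted classes on the columns.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Toy three-class labels, made up for illustration
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 0]

cm = confusion_matrix(y_true, y_pred)
precision = np.diag(cm) / cm.sum(axis=0)   # TP / (TP + FP), column sums
recall = np.diag(cm) / cm.sum(axis=1)      # TP / (TP + FN), row sums

# Compare with the per-class values reported by scikit-learn
same_precision = np.allclose(precision, precision_score(y_true, y_pred, average=None))
same_recall = np.allclose(recall, recall_score(y_true, y_pred, average=None))
print(precision, recall, same_precision, same_recall)
```

If a class is never predicted, its column sum is zero and the division yields `nan`, so real data may need a guard for that case.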
This prints the precision and recall of each class.
The classification report function is imported and used as follows:
from sklearn.metrics import confusion_matrix, classification_report

# Classification report
cr = classification_report(y_test, y_pre)
print(cr)
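In one call it reports the precision, recall, F1 and support of every class, together with the overall averages.
The same per-class numbers can also be packed into a pandas DataFrame: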
import pandas as pd

prfs = metrics.precision_recall_fscore_support(y_test, y_pre)
score_data = pd.DataFrame(prfs, index=["precision", "recall", "fscore", "support"])
print(score_data)
The result is a table with one row per metric and one column per class.
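Another route to a table, sketched here as an alternative rather than as the approach used in this article, is to ask `classification_report` for a dictionary with `output_dict=True` and hand it to pandas. The labels are made up for illustration.

```python
import pandas as pd
from sklearn.metrics import classification_report

# Toy three-class labels, made up for illustration
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 0]

# output_dict=True returns nested dictionaries instead of a formatted string
report = classification_report(y_true, y_pred, output_dict=True)
# Transpose so that each class or average becomes a row
report_df = pd.DataFrame(report).transpose()
print(report_df)
```

The transposed frame has one row per class plus the summary rows, which is close to the layout of the printed report.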
Complete code:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import metrics
from sklearn.metrics import accuracy_score
# Confusion matrix and report utilities
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Load the handwritten digits dataset
mnist = datasets.load_digits()
# Standardize the features
mnist.data = StandardScaler().fit_transform(mnist.data)
# Hold out 30% of the samples as the test set
X_train, X_test, y_train, y_test = train_test_split(
    mnist.data, mnist.target, test_size=0.3, random_state=0)
# Logistic regression: create the estimator and fit it on the training data
model = LogisticRegression().fit(X_train, y_train)
y_pre = model.predict(X_test)

# Earlier steps, disabled here:
# acc = accuracy_score(y_test, y_pre)
# macro = metrics.precision_score(y_test, y_pre, average="macro")
# micro = metrics.precision_score(y_test, y_pre, average="micro")
# f1 = metrics.f1_score(y_test, y_pre, average="macro")
# f1_weight = metrics.f1_score(y_test, y_pre, average="weighted")
# fbeta = metrics.fbeta_score(y_test, y_pre, average="macro", beta=1)
# print(acc, macro, micro, f1, f1_weight, fbeta)

# Confusion matrix (plot_confusion_matrix was removed in scikit-learn 1.2)
cm = confusion_matrix(y_test, y_pre, labels=model.classes_)
disp = ConfusionMatrixDisplay(cm, display_labels=model.classes_)
# disp.plot()
# plt.show()
# print(cm)
# cr = classification_report(y_test, y_pre)
# print(cr)
# prfs = metrics.precision_recall_fscore_support(y_test, y_pre)
# score_data = pd.DataFrame(prfs, index=["precision", "recall", "fscore", "support"])
# print(score_data)

# Precision per class from the confusion matrix
precision = np.diag(cm) / np.sum(cm, axis=0)
# Recall per class
recall = np.diag(cm) / np.sum(cm, axis=1)
# F1 per class
f1_score = 2*precision*recall / (precision+recall)
# Number of true samples per class, and in total
support = np.sum(cm, axis=1)
support_all = np.sum(cm)
accuracy = np.sum(np.diag(cm)) / support_all
weight = support / support_all
# Macro precision, macro recall, macro F1
macro_avg = [precision.mean(), recall.mean(), f1_score.mean()]
# Support-weighted precision, recall, F1
weight_avg = [np.sum(weight*precision), np.sum(weight*recall), np.sum(weight*f1_score)]
metrics1 = pd.DataFrame(np.array([precision, recall, f1_score, support]).T,
                        columns=["precision", "recall", "f1_score", "support"])
metrics2 = pd.DataFrame([["", "", "", ""],
                         ["", "", accuracy, support_all],
                         np.hstack([macro_avg, support_all]),
                         np.hstack([weight_avg, support_all])],
                        columns=["precision", "recall", "f1_score", "support"])
metrics_total = pd.concat([metrics1, metrics2], ignore_index=False)
print(metrics_total)