【scikit-learn基础】--『监督学习』之 K-近邻分类
2024/1/4 14:02:53
本文主要是介绍【scikit-learn基础】--『监督学习』之 K-近邻分类,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
KNN
(K-近邻),全称K-Nearest Neighbors
,是一种常用的分类算法。KNN
算法的历史可以追溯到1957年,当时Cover和Hart提出了“最近邻分类”的概念。
但是,这个算法真正得到广泛认知和应用是在1992年,由Altman发表的一篇名为“K-Nearest Neighbors
”的文章。
近年来,随着大数据和机器学习的快速发展,KNN
算法因其简单且表现优秀,被广泛应用于各种数据分类问题中。
1. 算法概述
KNN
算法的基本原理是:在特征空间中,如果一个样本的最接近的k个邻居中大多数属于某一个类别,则该样本也属于这个类别。
换句话说,KNN
算法假设类别是由其邻居决定的。
那么,KNN
算法判断数据是否相似是关键,也就是数据之间的距离是如何计算的呢?
最常用的距离计算公式有:
- 曼哈顿距离:\(L_1(x_i,x_j)= \sum_{l=1}^{n} |x_i^{(l)}-x_j^{(l)}|\)
- 欧氏距离:\(L_2(x_i,x_j) = (\sum_{l=1}^{n} \; |x_i^{(l)}-x_j^{(l)}|^{2})^{\frac{1}{2}}\)
- 闵可夫斯基距离:\(L_p(x_i,x_j) = (\sum_{l=1}^{n} \; |x_i^{(l)}-x_j^{(l)}|^{2})^{\frac{1}{p}}\)
- 等等
使用不同的距离,就会得到不同的分类效果。
2. 创建样本数据
这次用scikit-learn
中的样本生成器make_classification
来生成分类用的样本数据。
import matplotlib.pyplot as plt from sklearn.datasets import make_classification # 分类数据的样本生成器 X, y = make_classification(n_samples=1000, n_classes=4, n_clusters_per_class=1) plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=25) plt.show()
关于样本生成器的详细内容,请参考:TODO
3. 模型训练
首先,分割训练集和测试集。
from sklearn.model_selection import train_test_split # 分割训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
这次按照8:2的比例来划分训练集和测试集。
然后用scikit-learn
中的KNeighborsClassifier
模型来训练:
from sklearn.neighbors import KNeighborsClassifier # 定义KNN模型(设置4个分类,因为样本数据是4个分类) reg = KNeighborsClassifier(n_neighbors=4) # 训练模型 reg.fit(X_train, y_train) # 在测试集上进行预测 y_pred = reg.predict(X_test)
KNeighborsClassifier
的主要参数包括:
-
n_neighbors:这是
kNN
算法中的k值,即选择最近的k个点。默认值为5。 -
weights:此参数默认为'
uniform
',也可以设置为'distance
',或者用户自定义的函数。其中,'uniform
'表示所有的邻近点的权重都是相等的,'distance
'表示距离近的点比距离远的点的影响大。 -
algorithm:此参数默认为'
auto
',也可以设置为'auto
','ball_tree
','kd_tree
',或'brute
'。这决定了在计算最近邻时使用的算法。 - leaf_size:此参数默认为30,也可以设置为一个整数,用于指定传递给构建叶子节点时使用的最小样本数。
-
p:此参数默认为2,也可以设置为一个值<=1。这决定了在计算
Minkowski
距离度量时使用的p值。 -
metric:此参数默认为'
minkowski
',也可以设置为'euclidean
','manhattan
'等。这决定了距离度量方法。 - metric_params:此参数默认为None,也可以是一个字典,包含了额外的关键字参数传递给距离度量函数。
- n_jobs:此参数默认为None,也可以是一个大于等于1的整数,表示可用于执行并行计算的CPU数量。
最后验证模型的训练效果:
# 比较测试集中有多少个分类预测正确 correct_pred = np.sum(y_pred == y_test) print("预测正确率:{}%".format(correct_pred/len(y_pred)*100)) # 运行结果 预测正确率:68.5%
模型使用了默认的参数,可以看出,模型正确率不高。
感兴趣的同学可以试试调整KNeighborsClassifier
的参数,看看是否可以提高模型的预测正确率。
4. 总结
KNN
算法被广泛应用于各种不同的应用场景,如图像识别、文本分类、垃圾邮件识别、客户流失预测等。
这些场景的一个共同特点是,需要对一个未知的样本进行快速的分类或预测。
KNN
算法主要优势在于:
-
简单直观:
KNN
算法的概念简单直观,容易理解和实现。 -
适用于小样本数据:
KNN
算法在小样本数据上的表现往往优于其他机器学习算法。 -
对数据预处理要求较低:
KNN
算法不需要对数据进行复杂的预处理,例如标准化、归一化等。
不过,KNN
算法也有不足之处:
-
计算量大:对于大型数据集,
KNN
算法可能需要大量的存储空间和计算时间,因为需要计算每个样本与所有已知样本的距离。 - 选择合适的K值困难:K值的选择对结果影响很大,选择不当可能会导致结果的不稳定。
-
对噪声数据敏感:如果数据集中存在噪声数据,
KNN
算法可能会受到较大影响。
这篇关于【scikit-learn基础】--『监督学习』之 K-近邻分类的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23增量更新怎么做?-icode9专业技术文章分享
- 2024-11-23压缩包加密方案有哪些?-icode9专业技术文章分享
- 2024-11-23用shell怎么写一个开机时自动同步远程仓库的代码?-icode9专业技术文章分享
- 2024-11-23webman可以同步自己的仓库吗?-icode9专业技术文章分享
- 2024-11-23在 Webman 中怎么判断是否有某命令进程正在运行?-icode9专业技术文章分享
- 2024-11-23如何重置new Swiper?-icode9专业技术文章分享
- 2024-11-23oss直传有什么好处?-icode9专业技术文章分享
- 2024-11-23如何将oss直传封装成一个组件在其他页面调用时都可以使用?-icode9专业技术文章分享
- 2024-11-23怎么使用laravel 11在代码里获取路由列表?-icode9专业技术文章分享
- 2024-11-22怎么实现ansible playbook 备份代码中命名包含时间戳功能?-icode9专业技术文章分享