K最近邻(KNN)模型是一种简单且有效的机器学习算法,主要用于分类和回归任务。KNN的基本思想是:一个样本的类别或数值取决于其最邻近的k个邻居的类别或数值。具体来说,KNN通过计算待分类点与训练数据集中所有点的距离,选择距离最近的k个点作为邻居,然后根据这些邻居的类别或数值来预测该点的类别或数值。
KNN模型的工作原理
KNN算法的工作流程如下:
距离度量:KNN使用距离度量(如欧几里得距离、曼哈顿距离等)来确定数据点之间的相似性。最常见的距离度量是欧几里得距离。
选择K值:K是一个超参数,表示在进行决策时考虑的最近邻居的数量。K的选择对模型的性能有很大影响,一般通过网格搜索等方法来确定最佳的K值。
决策:对于分类任务,KNN算法通过多数投票法来预测新样本的类别;对于回归任务,则计算K个最近邻居的目标值的平均值作为新样本的目标值。
KNN模型的优缺点
优点:
简单易懂:KNN算法概念简单,容易理解和实现。
无需训练阶段:KNN没有显式的训练阶段,只需存储训练数据集,因此在数据集规模和特征数量相同的条件下,建模训练速度较快。
适用于小数据集:对于小数据集,KNN通常表现良好。
缺点:
计算量大:对于大规模数据集,计算每个测试样本与训练集中所有样本的距离非常耗时。
内存消耗大:需要存储所有训练数据以进行距离计算。
对异常值敏感:近邻的选择对异常值非常敏感,可能导致模型不稳定。
应用场景和实际案例
KNN算法在实际应用中有广泛的应用,例如:
文本分类:用于将文本数据分类到不同的类别中。
图像识别:通过比较图像特征来识别图像内容。
推荐系统:根据用户的购买或浏览历史推荐相关商品或内容。
今天我们仍以熟悉的示例数据集为例,演示一下python中K最近邻(KNN)模型的基本操作以及ROC曲线、混淆矩阵评价。
#加载程序包(openpyxl和pandas等)
#使用pandas读取示例数据xlsx文件
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
dataknn = pd.read_excel(r'C:\Users\L\Desktop\示例数据.xlsx')
# 查看前几行数据
print(dataknn.head())
# 分离特征和目标变量
X = dataknn[['指标1', '指标2', '指标3','指标4','指标5','指标6']]
y = dataknn['结局']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建并训练模型
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
# 预测和评估模型
predictions = knn.predict(X_test)
accuracy = knn.score(X_test, y_test)
print(f"Accuracy: {accuracy}")
print("预估值:\n",predictions)
acc = sum(predictions == y_test) / predictions.shape[0]
print("预测的准确率ACC: %.2f%%" % (acc*100))
#混淆矩阵评估模型
#导入第三方模块
from sklearn import metrics
# 混淆矩阵
print("混淆矩阵四格表输出如下:")
print(metrics.confusion_matrix(y_test, predictions, labels = [0, 1]))
Accuracy = metrics._scorer.accuracy_score(y_test, predictions)
Sensitivity = metrics._scorer.recall_score(y_test, predictions)
Specificity = metrics._scorer.recall_score(y_test, predictions, pos_label=0)
print("KNN模型混淆矩阵评价结果如下:")
print('模型准确率为%.2f%%' %(Accuracy*100))
print('正例覆盖率为%.2f%%' %(Sensitivity*100))
print('负例覆盖率为%.2f%%' %(Specificity*100))
# 使用Seaborn的heatmap绘制混淆矩阵
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
sns.heatmap(metrics.confusion_matrix(y_test, predictions), annot=True, fmt='d')
plt.title('Confusion Matrix')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()
#进行ROC曲线绘制计算准备
# у得分为模型预测正例的概率
y_score =knn.predict_proba(X_test)[:,1]
#计算不同阈值下,fpr和tpr的组合值,其中fpr表示1-Specificity,tpr表示sensitivity
fpr,tpr,threshold =metrics.roc_curve(y_test,y_score)
# 计算AUC的值
roc_auc = metrics.auc(fpr,tpr)
print("KNN模型预测测试集数据ROC曲线的AUC:",roc_auc)
KNN模型预测测试集数据ROC曲线的AUC: 0.9356444444444444
#绘制ROC曲线
import matplotlib.pyplot as plt
import seaborn as sns
# 绘制面积图
plt.stackplot(fpr, tpr, color='steelblue', alpha = 0.5,edgecolor = 'black')
# 添加边际线
plt.plot(fpr, tpr, color='black',lw = 1)
# 添加对角线
plt.plot([0,1],[0,1],color ='red',linestyle ='--')
# 添加文本信息
plt.text(0.5,0.3,'Roc curve(area =%.2f)'% roc_auc)
# 添加轴标签
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
# 显示图形
plt.show()
医学统计数据分析分享交流SPSS、R语言、Python、ArcGis、Geoda、GraphPad、数据分析图表制作等心得。承接数据分析,论文修回,医学统计,空间分析,问卷分析业务。若有投稿和数据分析代做需求,可以直接联系我,谢谢!