机器学习分类模型的性能衡量

科技   2024-09-27 13:41   广东  

 今天是生信星球陪你的第994天


   

公众号里的文章大多数需要编程基础,如果因为代码看不懂,而跟不上正文的节奏,可以来找我学习,相当于给自己一个新手保护期。我的课程都是循环开课,点进去咨询微信↓

生信分析直播课程(9月30日下一期)

生信新手保护学习小组(10月初下一期)

单细胞陪伴学习小组(10月初下一期

目录
  • 1.衡量模型的准确程度

  • 2.训练集与测试集的拆分

  • 3.模型的复杂程度

  • 4.可视化k值和预测准确率的关系

1.衡量模型的准确程度

准确率(Accuracy)即:预测正确的样本数量/样本
可以用构建模型的数据来计算准确率,但这个准确率不能代表模型泛化到其其他数据的准确率。
所以将数据拆分开,一部分用于拟合模型(训练集),另一部分用于衡量模型性能(测试集),是更好的选择。

2.训练集与测试集的拆分

将数据拆分开,用训练集构建模型,用测试集作为自变量输入,计算它的准确率。
from sklearn.model_selection import train_test_split 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3
                                                    random_state=21, stratify=y) 
knn = KNeighborsClassifier(n_neighbors=6
knn.fit(X_train, y_train) 
print(knn.score(X_test, y_test)) 
## 0.8733333333333333
test_size=0.3是测试集占所有样本总数的30%
random_state是设置了随机种子,让结果可重复
stratify=y是控制数据拆分的均匀程度,希望训练集和测试集的0/1比例相同,再详细解释一下就是:假如有30%是1,这样拆分后,训练集和测试集都有30%的1。

3.模型的复杂程度

k值越大,结果受周围噪声的影响就越小。
Larger k = less complex model = can cause underfitting  (欠拟合)
Smaller k = more complex model = can lead to overfitting  (过拟合)
学到这里,没自动理解到为什么。kimi的解释如下:
在KNN算法中,k值的选择对模型的性能有显著影响。k值代表在进行分类决策时考虑的最近邻居的数量。如果k值选择较小,模型可能会更复杂,因为它对训练数据中的噪声更敏感,这可能导致过拟合。相反,如果k值选择较大,模型可能会变得更简单,因为它会平滑决策边界,可能会忽略一些重要的模式,从而导致欠拟合。

4.可视化k值和预测准确率的关系

先用for循环计算当k等于1-25时,训练集和测试集的准确率
train_accuracies = {} 
test_accuracies = {} 
neighbors = np.arange(126# 整数1-25
for neighbor in neighbors: 
    knn = KNeighborsClassifier(n_neighbors=neighbor) 
    knn.fit(X_train, y_train) 
    train_accuracies[neighbor] = knn.score(X_train, y_train) 
    test_accuracies[neighbor] = knn.score(X_test, y_test)
我自行加了几句代码,查看计算结果
import pandas as pd
ac = pd.DataFrame([train_accuracies,test_accuracies]).transpose()
ac.columns = ['train_accuracies','test_accuracies']
ac.reset_index(inplace=True)
ac.head()
然后绘制折线图
plt.figure(figsize=(86)) 
plt.title("KNN: Varying Number of Neighbors"
plt.plot(neighbors, train_accuracies.values(), label="Training Accuracy"
plt.plot(neighbors, test_accuracies.values(), label="Testing Accuracy"
plt.legend() 
plt.xlabel("Number of Neighbors"
plt.ylabel("Accuracy"
plt.show() 

虽然训练集和测试集的准确率都计算并画出来了,但主要还是看测试集的指标表现。

由图上可见,随着K值的增加,测试集的准确率先螺旋上升,然后趋于稳定。


生信星球
一个零基础学生信的平台-- 原创结构化图文/教程,精选阶段性资料,带你少走弯路早入门,收获成就感,早成生信小能手~
 最新文章