导师:你怎么连PLS-DA都不会?奥,我也没教过你。

文摘   2024-11-17 09:01   英国  

PLS-DA(Partial Least Squares Discriminant Analysis)是一种常用的分类方法,适用于在变量之间存在共线性的高维数据中进行分类。在这篇文章中,我们将使用Python来完成PLS-DA分类分析,通过调用sklearn和matplotlib库,对经典的鸢尾花数据集进行建模、预测和可视化。希望这篇教程能帮助大家深入理解PLS-DA的分析流程。

一、什么是PLS-DA?

PLS-DA是一种基于偏最小二乘法的判别分析技术。不同于传统的PCA(主成分分析),PLS-DA将类别标签引入模型,通过最大化类别之间的差异来优化分类效果,尤其适用于生物信息学和化学数据中的分类任务。

二、准备工作

在进行分析前,我们需要安装scikit-learnmatplotlib库。您可以通过以下命令安装:

pip install scikit-learn matplotlib

三、导入数据和库

我们将使用sklearn库中的datasets模块导入经典的鸢尾花数据集。该数据集包含了三种不同鸢尾花(Setosa、Versicolor、Virginica)的花瓣和花萼长度、宽度,是分类分析的常用数据集。

# 导入所需库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

导入数据

# 加载iris数据集
iris = datasets.load_iris()
X = iris.data # 特征数据
y = iris.target # 类别标签
target_names = iris.target_names # 类别名称

四、数据预处理

为了提高模型的表现,我们通常需要对数据进行标准化处理,以确保特征的尺度一致。此外,我们将数据集划分为训练集和测试集(70%训练,30%测试)。

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

五、构建PLS-DA模型

我们使用PLSRegression类来实现PLS-DA。PLS-DA要求将类别标签转换为哑变量(dummy variables),以适应模型输入要求。

# 将类别标签y转换为哑变量(One-hot编码)
y_train_dummies = pd.get_dummies(y_train).values
y_test_dummies = pd.get_dummies(y_test).values

# 初始化PLS-DA模型并选择成分数
plsda = PLSRegression(n_components=2)
plsda.fit(X_train, y_train_dummies)


在这里,我们将成分数设置为2,以便于后续可视化。成分数的选择可以通过交叉验证来确定。


六、模型预测与评估

在模型训练完成后,我们可以使用测试集进行预测,并计算模型的准确率和混淆矩阵,以评估模型的分类效果。
# 使用测试集进行预测
y_pred = plsda.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# 计算模型的准确率
accuracy = accuracy_score(y_test, y_pred_classes)
print("模型准确率:", round(accuracy * 100, 2), "%")
## 模型准确率: 91.11 %
# 计算混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred_classes)
print("混淆矩阵:\n", conf_matrix)
## 混淆矩阵:
## [[18 0 0]
## [ 0 10 0]
## [ 0 4 13]]
此处的混淆矩阵可以直观显示每种鸢尾花的预测结果,而准确率则展示了模型在测试集上的总体分类效果。

七、结果可视化

我们可以通过将数据在前两个成分空间中进行投影,来可视化PLS-DA的分类效果。下图展示了三类鸢尾花在成分1和成分2上的分布情况,帮助我们直观了解分类效果。
# 提取PLS-DA的成分得分
X_train_scores = plsda.transform(X_train)

# 绘制PLS-DA分类结果
plt.figure(figsize=(10, 6))
for i, target_name in enumerate(target_names):
plt.scatter(X_train_scores[y_train == i, 0],
X_train_scores[y_train == i, 1],
label=target_name)

plt.xlabel("F 1")
plt.ylabel("F 2")
plt.title("PLS-DA plot")
plt.legend()
plt.grid(True)
plt.show()
在该图中,不同类别的鸢尾花在前两个成分上表现出较好的分离效果,可以看到PLS-DA较好地实现了分类目标。

八、总结

本文详细介绍了如何使用Python进行PLS-DA分析,从数据预处理、建模到结果可视化的完整流程。PLS-DA的优势在于可以处理多变量共线性和高维数据,适合生物信息、化学等领域的分类分析。希望大家能够从中学习到更多分类分析的实用技巧!

九、完整代码

为了方便大家复制粘贴运行,以下是本次分析的完整代码:
# 导入所需库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# 加载iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
target_names = iris.target_names

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 将类别标签转换为哑变量
y_train_dummies = pd.get_dummies(y_train).values
y_test_dummies = pd.get_dummies(y_test).values

# 初始化PLS-DA模型
plsda = PLSRegression(n_components=2)
plsda.fit(X_train, y_train_dummies)

# 使用测试集预测
y_pred = plsda.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# 模型评估
accuracy = accuracy_score(y_test, y_pred_classes)
print("模型准确率:", round(accuracy * 100, 2), "%")
conf_matrix = confusion_matrix(y_test, y_pred_classes)
print("混淆矩阵:\n", conf_matrix)

# 可视化结果
X_train_scores = plsda.transform(X_train)
plt.figure(figsize=(10, 6))
for i, target_name in enumerate(target_names):
plt.scatter(X_train_scores[y_train == i, 0],
X_train_scores[y_train == i, 1],
label=target_name)

plt.xlabel("F1")
plt.ylabel("F2")
plt.title("PLS-DA plot")
plt.legend()
plt.grid(True)
plt.show()
以上代码演示了如何使用Python实现PLS-DA的分类分析。希望大家通过这篇教程可以掌握PLS-DA的基本用法,并能够在自己的数据中实践。
点赞,收藏,转发,一键三连啊!

科研代码
专注R和Python的数据分析。
 最新文章