审稿人:贝叶斯分类模型不是“画圈圈”,跟判别模型其实不一样!

文摘   2024-12-11 17:31   爱尔兰  

很多初学者在接触分类和判别模型时,会产生一个常见的误解:认为建模就是在数据分布上“画圈圈”。按照这个逻辑呢,大家发现,分类模型和判别模型最终都可以画出决策边界或分布圈,便错误地以为二者相同。但实际上,这两类模型在思想和方法上有着根本区别。

分类模型着眼于整个数据的生成机制,试图回答“这个数据是如何生成的”。它通过学习数据的联合分布来实现分类。判别模型则更专注于如何区分不同类别的数据,直接学习输入和输出之间的条件关系,目标是找到最优的分类边界。

在我们之前的教程中,我们讲解过许多判别模型,比如 PLS-DA (审稿人:用R做的PLS-DA,嗨,你小子还真是个天才!。今天,我们将更新内容,转向分类模型的世界,从一个简单但非常经典的模型——朴素贝叶斯(Naive Bayes) 开始讲解。

朴素贝叶斯分类模型:R语言实现

1. 加载并探索数据集

我们选择R内置的 iris 数据集进行演示。

# 加载数据集
data(iris)

# 查看数据集的基本信息
str(iris)

## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

summary(iris)

## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Min. :4.300 Min. :2.000 Min. :1.000 Min. :0.100
## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.600 1st Qu.:0.300
## Median :5.800 Median :3.000 Median :4.350 Median :1.300
## Mean :5.843 Mean :3.057 Mean :3.758 Mean :1.199
## 3rd Qu.:6.400 3rd Qu.:3.300 3rd Qu.:5.100 3rd Qu.:1.800
## Max. :7.900 Max. :4.400 Max. :6.900 Max. :2.500
## Species
## setosa :50
## versicolor:50
## virginica :50
##
##
##

# 绘制数据分布
pairs(iris[, -5], col = iris$Species, main = "Iris 数据集的特征分布")

2. 划分训练集和验证集

为了评估模型性能,我们需要将数据分为训练集和验证集。

set.seed(123) # 设置随机种子保证结果可复现
library(caTools)

# 划分数据
split <- sample.split(iris$Species, SplitRatio = 0.7)
train_data <- subset(iris, split == TRUE)
test_data <- subset(iris, split == FALSE)

# 检查数据划分结果
table(train_data$Species)

##
## setosa versicolor virginica
## 35 35 35

table(test_data$Species)

##
## setosa versicolor virginica
## 15 15 15

3. 构建朴素贝叶斯模型

我们使用 e1071 包中的 naiveBayes 函数构建模型。

# 加载包
library(e1071)

# 构建朴素贝叶斯模型
nb_model <- naiveBayes(Species ~ ., data = train_data)

# 查看模型
print(nb_model)

##
## Naive Bayes Classifier for Discrete Predictors
##
## Call:
## naiveBayes.default(x = X, y = Y, laplace = laplace)
##
## A-priori probabilities:
## Y
## setosa versicolor virginica
## 0.3333333 0.3333333 0.3333333
##
## Conditional probabilities:
## Sepal.Length
## Y [,1] [,2]
## setosa 4.940000 0.3541352
## versicolor 5.920000 0.5166635
## virginica 6.634286 0.5422952
##
## Sepal.Width
## Y [,1] [,2]
## setosa 3.405714 0.3685766
## versicolor 2.777143 0.3144423
## virginica 2.925714 0.2831990
##
## Petal.Length
## Y [,1] [,2]
## setosa 1.445714 0.1930298
## versicolor 4.217143 0.4462166
## virginica 5.565714 0.5075563
##
## Petal.Width
## Y [,1] [,2]
## setosa 0.2428571 0.1092372
## versicolor 1.3114286 0.1827429
## virginica 2.0428571 0.2714728

4. 评价模型性能:更多指标

除了混淆矩阵和准确率,我们还可以计算更多评价指标,例如精确率(Precision)、召回率(Recall)、F1分数等。

# 加载包
library(caret)

# 模型预测
predictions <- predict(nb_model, test_data)

# 生成混淆矩阵并计算指标
confusion_matrix <- confusionMatrix(predictions, test_data$Species)

# 查看指标
print(confusion_matrix)


最后,通过可视化工具展示分类结果与决策边界,帮助理解模型的实际表现。

# 加载绘图包
library(ggplot2)

# 1. PCA 降维
pca <- prcomp(iris[, -5], center = TRUE, scale. = TRUE)

# 2. 提取训练数据的 PCA 坐标
pca_train <- data.frame(pca$x[1:nrow(train_data), 1:2], Species = train_data$Species)
colnames(pca_train) <- c("PC1", "PC2", "Species")

# 3. 提取测试数据的 PCA 坐标
pca_test <- predict(pca, test_data[, -5])
pca_test <- data.frame(PC1 = pca_test[, 1], PC2 = pca_test[, 2],
Actual = test_data$Species, Predicted = predictions)

# 4. 绘制分类结果
ggplot() +
# 绘制训练数据点
geom_point(data = pca_train, aes(x = PC1, y = PC2, color = Species), size = 3, alpha = 0.7) +
# 绘制测试数据点
geom_point(data = pca_test, aes(x = PC1, y = PC2, shape = Predicted), size = 4, alpha = 0.8) +
labs(title = "分类结果与实际类别对比",
x = "主成分1 (PC1)",
y = "主成分2 (PC2)") +
theme_minimal() +
scale_color_brewer(palette = "Set1") +
scale_shape_manual(values = c(16, 17, 18)) # 为不同分类结果设置不同的形状

## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 12 2
## virginica 0 3 13
##
## Overall Statistics
##
## Accuracy : 0.8889
## 95% CI : (0.7595, 0.9629)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 1.408e-14
##
## Kappa : 0.8333
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.8000 0.8667
## Specificity 1.0000 0.9333 0.9000
## Pos Pred Value 1.0000 0.8571 0.8125
## Neg Pred Value 1.0000 0.9032 0.9310
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.2667 0.2889
## Detection Prevalence 0.3333 0.3111 0.3556
## Balanced Accuracy 1.0000 0.8667 0.8833
# 提取每类精确率、召回率和 F1 分数
precision <- confusion_matrix$byClass[, "Precision"]
recall <- confusion_matrix$byClass[, "Recall"]
f1_score <- confusion_matrix$byClass[, "F1"]

# 打印结果
cat("精确率 (Precision): ", precision, "\n")
## 精确率 (Precision):  1 0.8571429 0.8125
cat("召回率 (Recall): ", recall, "\n")
## 召回率 (Recall):  1 0.8 0.8666667
cat("F1分数 (F1 Score): ", f1_score, "\n")
## F1分数 (F1 Score):  1 0.8275862 0.8387097


5. 分类结果的可视化

最后,通过可视化工具展示分类结果与决策边界,帮助理解模型的实际表现。

# 加载绘图包
library(ggplot2)

# 1. PCA 降维
pca <- prcomp(iris[, -5], center = TRUE, scale. = TRUE)

# 2. 提取训练数据的 PCA 坐标
pca_train <- data.frame(pca$x[1:nrow(train_data), 1:2], Species = train_data$Species)
colnames(pca_train) <- c("PC1", "PC2", "Species")

# 3. 提取测试数据的 PCA 坐标
pca_test <- predict(pca, test_data[, -5])
pca_test <- data.frame(PC1 = pca_test[, 1], PC2 = pca_test[, 2],
Actual = test_data$Species, Predicted = predictions)

# 4. 绘制分类结果
ggplot() +
# 绘制训练数据点
geom_point(data = pca_train, aes(x = PC1, y = PC2, color = Species), size = 3, alpha = 0.7) +
# 绘制测试数据点
geom_point(data = pca_test, aes(x = PC1, y = PC2, shape = Predicted), size = 4, alpha = 0.8) +
labs(title = "分类结果与实际类别对比",
x = "主成分1 (PC1)",
y = "主成分2 (PC2)") +
theme_minimal() +
scale_color_brewer(palette = "Set1") +
scale_shape_manual(values = c(16, 17, 18)) # 为不同分类结果设置不同的形状



通过上述可视化,我们可以直观观察分类模型的效果,并对分类错误的情况有更清晰的认识。

总结

今天以朴素贝叶斯模型为例,我们完整演示了分类模型从数据加载到模型评价的全过程。这样的分析套路可以引申到你需要分析的任何分类模型,因为即使模型的种类不同,但解决问题的思路是一致的。欢迎使用你自己的数据来练习。后续,我们将进一步探索其他分类模型,提升对分类任务的全面理解!

感谢关注,你的支持是我不懈的动力!

科研代码
专注R和Python的数据分析。
 最新文章