审稿人竟然说SVM只能用来二元分类?绷不住了

文摘   2024-08-13 08:19   爱尔兰  
支持向量机(SVM)在其最基本形式中确实是用于二分类问题的。原理上,SVM 是通过寻找一个能够最大化类间间隔的超平面来分离两类数据点,因此基础的 SVM 模型只针对二分类任务。这个基本事实是为什么有些人会说 SVM 只能用于二分类的原因。但是,各种算法早已赋予了SVM判别多元分类的能力,每个人的知识都应该与时俱进!

1. SVM 如何处理二分类问题

在二分类问题中,SVM 寻找一个线性决策边界(超平面)来最大化两类数据点之间的间隔(margin)。支持向量是离决策边界最近的数据点,SVM 通过这些支持向量来确定最优的决策边界。所有数据点根据其在决策边界的哪一侧被分类为相应的类。

2. SVM 如何处理多分类问题

虽然基础的 SVM 只能处理二分类问题,但在实践中,SVM 可以通过扩展技术来处理多分类问题。最常用的两种方法是:
  1. 一对多(One-vs-Rest, OvR):
  2. 方法将多分类问题拆解为多个二分类问题。假设有 n个类别,OvR 会创建 n个二分类模型,每个模型将一个类别作为正类,其他类别作为负类。最终,使用所有模型的预测结果选择得分最高的类别作为最终分类结果。
  3. 一对一(One-vs-One, OvO):
  • OvO 方法将多分类问题拆解为所有类别之间的二分类组合。每个模型处理两个类别之间的分类问题。最终的分类结果由多数投票决定。

3. 实际应用中的 SVM 多分类

scikit-learn 中的 SVM 实现已经内置了对多分类问题的支持。用户可以直接应用 SVM 模型进行多分类,而无需手动构建 OvR 或 OvO 模型。例如:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# 1. 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2] # 取前两个特征(为了便于可视化)
y = iris.target

# 2. 数据集分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 3. 创建 SVM 模型,并选择线性核函数(支持多分类)

model = SVC(kernel='linear', decision_function_shape='ovr')

# ovr 是默认参数

model.fit(X_train, y_train)


# 4. 预测与评估
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')
## Accuracy: 0.80
# 5. 可视化决策边界
def plot_decision_boundary(model, X, y):
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
np.arange(y_min, y_max, 0.01))

Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.Paired)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o', cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('SVM Decision Boundary for Multi-Class Classification')
plt.show()

# 6. 绘制训练集的决策边界
plot_decision_boundary(model, X_train, y_train)

# 7. 绘制测试集的决策边界
plot_decision_boundary(model, X_test, y_test)

4. 总结

尽管基础的 SVM 是一个二分类模型,但通过一对多(OvR)或一对一(OvO)的方法,SVM 可以扩展用于多分类问题。现代机器学习库(如 scikit-learn)通常已经内置了这些多分类扩展,使得用户可以直接应用 SVM 进行多分类任务。
因此,虽然 SVM 本质上是一个二分类工具,但在实际应用中,它完全可以用于多分类问题。


感谢关注

科研代码
专注R和Python的数据分析。
 最新文章