哈喽,我是cos大壮!~
上次咱们聊了 SVM 的内容,今儿想和大家讨论一下关于 SVR 的内容,即:支持向量回归~
SVR 是一种基于支持向量机(SVM)的回归模型,用来解决回归问题。它的目标是找到一个最优的回归线(或高维空间中的超平面),使得大多数数据点离这条线的距离都在一定的容忍范围内。
文末可取本文PDF版本~
首先,我们用一个简单的例子来解释。
例子:预测冰淇淋的销量
假设你是一个冰淇淋店老板,你想根据天气的温度来预测冰淇淋的销量。你有过去几天的记录,显示每天的气温和对应的冰淇淋销量。
假设数据是这样的:
25°C,卖出了200个冰淇淋 30°C,卖出了300个冰淇淋 35°C,卖出了400个冰淇淋 40°C,卖出了500个冰淇淋
目标:找到一个「回归线」
我们希望通过这些数据找到一个回归线,它能根据温度来预测销量。这条线应该尽量靠近所有的数据点。
传统的回归(比如线性回归)会尝试画一条线,让每个数据点和这条线之间的误差尽可能小。但在支持向量回归中,我们允许有些误差,只要这些误差在一个「容忍范围」内即可。
SVR的「容忍范围」
在SVR中,有一个概念叫做epsilon带。你可以想象这是一条回归线两边的一个带状区域,数据点只要落在这个区域内,都是可以接受的(即使它们不完全落在回归线上)。
例如,我们可以允许有±50个冰淇淋的误差,所以如果在30°C时,我们的回归模型预测了250到350个冰淇淋,这都可以被接受。这就是epsilon带的作用。
如何找到最优的回归线?
SVR的目标是找到一条线,使得大部分数据点都尽量落在epsilon带内,同时我们也希望这条线尽量「平滑」,即避免过度弯曲(太复杂的模型)。
具体来说,SVR要找到一个支持向量(离回归线最远但仍在epsilon带边界上的数据点),这些支持向量决定了回归线的位置。然后通过优化算法找到一个既符合数据规律、又尽量简单的模型。
总结几点
回归线:SVR试图找到一条线来预测连续变量(比如冰淇淋销量)。 epsilon带:允许一定范围的误差,数据点可以离回归线有一定距离,只要在这个带内都是可以接受的。 支持向量:决定这条回归线的关键点。
通过SVR,你可以建立一个模型来预测某个温度下的冰淇淋销量,而这个模型既不容易过度拟合(即太复杂),又能有效处理一定程度的数据噪声。
有了上面的认识,下面,我们通过具体的公式和案例再和大家详细聊聊~
原理和案例
需要我们首先从理论上解释其核心部分,然后再逐步实现,并通过数据可视化来展示它的性能。由于SVR基于支持向量机(SVM)的思想,我们将从线性回归的优化问题逐步推导到SVR。
1. SVR的公式推导
1. 问题定义:
SVR的目标是在给定的训练数据 ,其中 是输入, 是输出,找到一个函数 来近似这些数据。
2. 优化目标:
我们希望找到一个回归函数 ,使得它能预测输出,并且允许有一些小的误差。为了做到这一点,我们引入了ε-insensitive loss,这个损失函数的目标是让误差在 范围内忽略不计。
优化问题可以写为:
其中, 是为了使模型尽量平滑(即避免过度拟合),约束条件是为了控制误差在 范围内。
3. 引入松弛变量:
为了允许有些点不在ε带内,我们引入松弛变量 和 来表示超过ε带的偏差。新的优化问题变为:
其中, 是一个超参数,控制模型对误差的容忍度。
4. 对偶问题:
通过拉格朗日乘子法,可以将上面的优化问题转换为对偶形式。最后得到的决策函数为:
其中, 是核函数, 和 是拉格朗日乘子。
2. 手动实现SVR
我们将使用Kaggle中的「汽车燃油效率」数据集(或类似的数据集),从头实现SVR的训练过程,并进行数据可视化分析。
数据集获取:点击名片,回复「数据集」即可~
步骤:
1. 加载数据集
2. 数据预处理
3. 定义SVR的训练过程
4. 可视化分析
1. 加载与预处理数据
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 加载数据
data = pd.read_csv('auto-mpg.csv')
# 数据预处理
data = data[['mpg', 'horsepower', 'weight', 'acceleration']]
data = data.dropna() # 删除缺失值
# 将'horsepower'列转换为数值型,并处理无效数据
data['horsepower'] = pd.to_numeric(data['horsepower'], errors='coerce')
data = data.dropna() # 处理转换后可能出现的缺失值
X = data[['horsepower', 'weight', 'acceleration']].values
y = data['mpg'].values
# 标准化输入数据
X = (X - X.mean(axis=0)) / X.std(axis=0)
y = (y - y.mean()) / y.std()
2. 定义SVR的训练过程
为了简单,我们实现一个线性SVR,不使用现有的库来直接调用SVR算法。
class SVR:
def __init__(self, C=1.0, epsilon=0.1, lr=0.001, max_iter=1000):
self.C = C
self.epsilon = epsilon
self.lr = lr
self.max_iter = max_iter
def fit(self, X, y):
n_samples, n_features = X.shape
self.w = np.zeros(n_features)
self.b = 0
for _ in range(self.max_iter):
for i in range(n_samples):
if np.abs(y[i] - (np.dot(X[i], self.w) + self.b)) > self.epsilon:
if y[i] > np.dot(X[i], self.w) + self.b:
self.w += self.lr * (X[i] - self.C * self.w)
self.b += self.lr * 1
else:
self.w -= self.lr * (X[i] + self.C * self.w)
self.b -= self.lr * 1
def predict(self, X):
return np.dot(X, self.w) + self.b
3. 训练模型
# 训练SVR模型
model = SVR(C=1.0, epsilon=0.1, lr=0.001, max_iter=10000)
model.fit(X, y)
# 预测
y_pred = model.predict(X)
4. 数据可视化分析
我们将通过以下四个图表来分析数据和模型表现:
1. 原始数据的分布
2. 训练后的回归线与实际数据的对比
3. 残差分布
4. 预测值与实际值的对比
# 1. 原始数据的分布
plt.figure(figsize=(8,6))
plt.scatter(X[:,0], y, color='blue', label='Actual')
plt.title('Original Data Distribution')
plt.xlabel('Horsepower')
plt.ylabel('MPG')
plt.legend()
plt.show()
# 2. 回归线与实际数据对比
plt.figure(figsize=(8,6))
plt.scatter(X[:,0], y, color='blue', label='Actual')
plt.plot(X[:,0], y_pred, color='red', label='Predicted')
plt.title('SVR Fit')
plt.xlabel('Horsepower')
plt.ylabel('MPG')
plt.legend()
plt.show()
# 3. 残差分布
residuals = y - y_pred
plt.figure(figsize=(8,6))
plt.hist(residuals, bins=20, color='green')
plt.title('Residuals Distribution')
plt.xlabel('Residuals')
plt.ylabel('Frequency')
plt.show()
# 4. 预测值与实际值的对比
plt.figure(figsize=(8,6))
plt.scatter(y, y_pred, color='purple')
plt.title('Predicted vs Actual MPG')
plt.xlabel('Actual MPG')
plt.ylabel('Predicted MPG')
plt.show()
这个案例中,我们手动实现了一个简单的SVR模型,并且通过四个图形分析了模型的表现:
1. 原始数据分布 展示了我们正在处理的数据的基本特征。
2. 回归线与实际数据对比 展示了SVR模型的拟合效果。
3. 残差分布 让我们直观了解模型误差的分布情况。
4. 预测值与实际值的对比 帮助我们评估模型的预测效果。
上面的代码手动实现了 SVR,通过图形进行数据分析。在这个过程中,希望帮助大家深入理解了SVR的数学原理,并通过手动实现掌握了它的内在工作逻辑。
最后
大家有问题可以直接在评论区留言即可~
喜欢本文的朋友可以收藏、点赞、转发起来!
推荐阅读
原创、超强、精华合集 100个超强机器学习算法模型汇总 机器学习全路线 机器学习各个算法的优缺点 7大方面,30个最强数据集 6大部分,20 个机器学习算法全面汇总 铁汁,都到这了,别忘记点赞呀~