哈喽,我是cos大壮!~
咱们今儿来聊一个基础算法模型:高斯混合模型。
整篇文章从最开始的简单解释到最后的完整案例,给大家做一个详细的解释。
简单来说,高斯混合模型(Gaussian Mixture Model,简称 GMM)是一种用来处理「分类」问题的统计模型,它可以帮助我们将一堆数据点分成不同的类别。这种模型假设数据是由多个高斯分布(也叫正态分布)组成的,所以它名字里有“高斯混合”这个词。
文末可取本文PDF版本~
要理解高斯混合模型,我们先来看什么是高斯分布。你可以把高斯分布想象成一个钟形曲线,表示某个现象的概率分布,像身高、体重这样的数据通常符合高斯分布。大多数人的身高集中在平均值附近,少数人的身高偏高或偏低,形成一个钟形的概率分布。
高斯混合模型 核心思想
高斯混合模型的想法是:你手上的数据可能并不是来自一个单一的高斯分布,而是来自多个不同的高斯分布「混合」在一起的。我们需要做的就是把这些数据点分配到不同的高斯分布上,或者换句话说,把这些数据点分成几类。
举个例子
假设我们有一群不同种类的小动物,它们的体重分布有点复杂。比如,我们有兔子、猫、和狗,它们的体重在不同的范围内,但数据混在一起,我们不知道每个动物属于哪一类。我们能看到的只是每个动物的体重,但不知道它们是兔子、猫还是狗。
任务:用体重把这些动物分成不同的类别。
假设我们记录了很多动物的体重,这些数据可能会呈现三种不同的模式:一种模式对应较轻的兔子,另一种对应中等体重的猫,还有一种对应体重较大的狗。但是,我们不知道哪只动物是兔子、哪只动物是猫或狗。我们只知道这些体重数据混在一起。
这时候,高斯混合模型就派上用场了。它会假设这些体重数据是由三种不同的高斯分布混合起来的。我们的任务是利用这些数据,推测出每个动物属于哪个高斯分布(也就是哪一类:兔子、猫或狗)。
高斯混合模型怎么工作的?
初始化:首先,假设我们的数据点(动物的体重)来自几个不同的类别(兔子、猫、狗),并且为每个类别初始化一个高斯分布(其实就是初步猜测每个类别的平均体重和体重的波动范围)。
计算概率:然后,模型会计算每个数据点(每个动物的体重)分别属于不同类别的概率。例如,一个体重3公斤的动物,它更有可能是兔子,而不是狗。通过计算,模型能推断出每个动物分别属于兔子、猫、狗的概率。
更新参数:接下来,模型会根据这些概率来重新调整每个类别的高斯分布参数,找出更合适的平均值和标准差(波动范围),从而更好地解释数据。
重复迭代:通过不断重复这两个步骤(计算概率和更新参数),模型会逐渐收敛,最终把所有数据点归类到合适的类别。
经过模型的训练,最终我们得到三个类别,它们各自对应兔子、猫、和狗。每个动物被归类到不同的类别中,虽然我们一开始并不知道它们的真实类别,但通过模型的学习,我们可以比较准确地将它们分类。
你也可以把高斯混合模型想象成一个「盲人分水果」的过程。假设你有一篮子混在一起的苹果、梨和橙子(动物的体重数据),你看不见它们(不知道它们的类别),但你可以根据它们的大小和形状去猜测。高斯混合模型就是通过不停地猜测和调整,最终把苹果、梨和橙子分成三类。
总的来说,高斯混合模型就是在不知道类别的情况下,通过猜测数据点的概率分布,把它们分成几个类别。它假设数据来自多个高斯分布,并通过不断迭代,最终将这些数据点归类到不同的高斯分布中。
下面,咱们从高斯混合模型的原理和案例再来详细聊聊~
1. 高斯混合模型的公式推导
高斯分布公式
首先,我们需要了解单变量的高斯分布公式,它描述了在某个均值 和方差 下,某个数据点 出现的概率:
多元高斯分布公式为:
其中:
是 -维的特征向量 是均值向量 是协方差矩阵
高斯混合模型
高斯混合模型(GMM) 是由多个高斯分布混合而成的。假设数据集是由 个高斯分布组成的混合模型,那么给定数据点 ,它的概率分布可以表示为每个分布的加权和:
其中:
是第 个高斯分布的权重,且 表示第 个高斯分布 是模型的所有参数:均值 ,方差 ,以及混合系数
EM算法推导
由于我们不知道每个点属于哪个高斯分布,因此 GMM 采用 EM算法(期望最大化算法)来迭代估计参数。
1. E 步(期望步):计算后验概率,即每个点 属于第 个分布的概率:
其中 是点 属于第 个高斯分布的概率。
2. M 步(最大化步):根据 E 步的结果,更新 GMM 的参数:
通过多次迭代,EM算法可以让这些参数收敛到合适的值。
2. 案例实现
我们会从 Kaggle 下载一个数据集,使用 GMM 对数据进行分类。为了演示原理,我们使用一个简单的二维数据集。并且根据原理进行代码的编写。
数据集下载
选择一个简单的、适合分类的二维数据集,例如 Kaggle 上的 Iris 数据集 或 Mall Customers 数据集。我们将以 Mall Customers 为例,用顾客的年收入和消费得分来进行聚类分析。
数据集获取:点击名片,回复「数据集」即可!~
数据分析及可视化
我们要画出 4 个及以上的分析图,来逐步理解数据和模型效果。
Python 实现
下面是 Mall Customers 数据集 的完整 Python 实现:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
# 读取数据集
data = pd.read_csv("Mall_Customers.csv")
X = data[['Annual Income (k$)', 'Spending Score (1-100)']].values
# 初始化参数
def initialize_params_fixed(X, K):
n, d = X.shape
pi = np.ones(K) / K # 初始化每个混合成分的权重
mu = X[np.random.choice(n, K, False), :] # 随机选择K个初始均值
sigma = np.array([np.eye(d) for _ in range(K)]) # 初始化协方差矩阵为单位矩阵
return pi, mu, sigma
# 计算多元正态分布
def multivariate_gaussian(X, mu, sigma):
return multivariate_normal(mean=mu, cov=sigma).pdf(X)
# E 步:计算每个点属于每个成分的责任值 (gamma)
def expectation_step_stable(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
gamma = np.zeros((N, K))
for k in range(K):
try:
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
except np.linalg.LinAlgError:
# 如果协方差矩阵是奇异矩阵,加入微小正则化项以确保正定性
sigma[k] += np.eye(X.shape[1]) * 1e-6
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
# 防止零除错误,保证数值稳定性
gamma_sum = np.sum(gamma, axis=1, keepdims=True)
gamma_sum[gamma_sum == 0] = 1e-10 # 防止除以零
gamma = gamma / gamma_sum
return gamma
# M 步:更新GMM的参数
def maximization_step(X, gamma):
N, d = X.shape
K = gamma.shape[1]
Nk = np.sum(gamma, axis=0) # 计算每个聚类的总责任值
pi = Nk / N # 更新混合系数
mu = np.dot(gamma.T, X) / Nk[:, np.newaxis] # 更新均值
sigma = np.zeros((K, d, d)) # 更新协方差矩阵
for k in range(K):
X_centered = X - mu[k]
gamma_diag = np.diag(gamma[:, k])
sigma[k] = np.dot(X_centered.T, np.dot(gamma_diag, X_centered)) / Nk[k]
return pi, mu, sigma
# 计算对数似然
def compute_log_likelihood(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
log_likelihood = 0
for n in range(N):
tmp = 0
for k in range(K):
tmp += pi[k] * multivariate_gaussian(X[n], mu[k], sigma[k])
log_likelihood += np.log(tmp)
return log_likelihood
# GMM 实现,包含数值稳定性修复
def gmm_fixed_stable(X, K, max_iter=100, tol=1e-6):
pi, mu, sigma = initialize_params_fixed(X, K)
log_likelihoods = []
for i in range(max_iter):
# E 步
gamma = expectation_step_stable(X, pi, mu, sigma)
# M 步
pi, mu, sigma = maximization_step(X, gamma)
# 添加小的正则化项,确保协方差矩阵为正定
sigma += np.eye(sigma.shape[1]) * 1e-6
# 计算对数似然
log_likelihood = compute_log_likelihood(X, pi, mu, sigma)
log_likelihoods.append(log_likelihood)
# 检查是否收敛
if i > 0 and abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
break
return pi, mu, sigma, log_likelihoods, gamma
# 数据可视化:原始数据分布
def plot_original_data(X):
plt.scatter(X[:, 0], X[:, 1], c='blue', label='Data points', alpha=0.5)
plt.title('Original Data Distribution')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.show()
# 分类结果展示
def plot_clusters(X, gamma, mu):
K = gamma.shape[1]
colors = ['r', 'g', 'b', 'y', 'm']
for k in range(K):
plt.scatter(X[:, 0], X[:, 1], c=gamma[:, k], cmap='viridis', label=f'Cluster {k+1}', alpha=0.6)
plt.scatter(mu[:, 0], mu[:, 1], c='black', marker='x', s=100, label='Centroids')
plt.title('GMM Clustering')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.legend()
plt.show()
# 对数似然收敛图
def plot_log_likelihood(log_likelihoods):
plt.plot(log_likelihoods)
plt.title('Log Likelihood Convergence')
plt.xlabel('Iterations')
plt.ylabel('Log Likelihood')
plt.show()
# 各类别概率分布图
def plot_probability_distributions(gamma):
K = gamma.shape[1]
for k in range(K):
plt.hist(gamma[:, k], bins=20, alpha=0.5, label=f'Cluster {k+1}')
plt.title('Probability Distributions for Each Cluster')
plt.xlabel('Probability')
plt.ylabel('Number of Points')
plt.legend()
plt.show()
# 运行 GMM 算法
K = 3 # 假设数据有 3 个聚类
pi, mu, sigma, log_likelihoods, gamma = gmm_fixed_stable(X, K)
# 绘制图形
plot_original_data(X) # 原始数据分布图
plot_clusters(X, gamma, mu) # 分类结果图
plot_log_likelihood(log_likelihoods) # 对数似然收敛图
plot_probability_distributions(gamma) # 各类别概率分布图
代码部分细节解释:
我们首先初始化参数,包括每个高斯分布的均值、协方差矩阵和权重。 通过 EM 算法迭代更新参数,直到对数似然值收敛。 最后,使用可视化展示分类结果和模型收敛情况。
1. 原始数据分布图:展示客户年收入和消费得分的散点图,帮助我们直观理解数据分布情况。
2. 分类结果图:展示 GMM 分类后每个客户所属的类别,以及每个类别的均值点(质心)。
3. 对数似然收敛图:展示对数似然值的收敛过程,判断模型是否收敛。
3. 各类别概率分布图:展示不同类别的概率分布,帮助理解分类的置信度。
通过高斯混合模型(GMM)的推导与 Python 实现,咱们完成了从基础原理到实际应用的完整过程。有问题,大家可以评论区讨论~
最后
大家有问题可以直接在评论区留言即可~
喜欢本文的朋友可以收藏、点赞、转发起来!
推荐阅读
原创、超强、精华合集 100个超强机器学习算法模型汇总 机器学习全路线 机器学习各个算法的优缺点 7大方面,30个最强数据集 6大部分,20 个机器学习算法全面汇总 铁汁,都到这了,别忘记点赞呀~