突破最强分类算法,高斯混合模型!!

文摘   2024-10-18 14:36   北京  

哈喽,我是cos大壮!~

咱们今儿来聊一个基础算法模型:高斯混合模型。

整篇文章从最开始的简单解释到最后的完整案例,给大家做一个详细的解释。

简单来说,高斯混合模型(Gaussian Mixture Model,简称 GMM)是一种用来处理「分类」问题的统计模型,它可以帮助我们将一堆数据点分成不同的类别。这种模型假设数据是由多个高斯分布(也叫正态分布)组成的,所以它名字里有“高斯混合”这个词。

老规矩如果大家伙觉得近期文章还不错!欢迎大家点个赞、转个发,文末赠送《机器学习学习小册》

文末可取本文PDF版本~

要理解高斯混合模型,我们先来看什么是高斯分布。你可以把高斯分布想象成一个钟形曲线,表示某个现象的概率分布,像身高、体重这样的数据通常符合高斯分布。大多数人的身高集中在平均值附近,少数人的身高偏高或偏低,形成一个钟形的概率分布。

高斯混合模型 核心思想

高斯混合模型的想法是:你手上的数据可能并不是来自一个单一的高斯分布,而是来自多个不同的高斯分布「混合」在一起的。我们需要做的就是把这些数据点分配到不同的高斯分布上,或者换句话说,把这些数据点分成几类。

举个例子

假设我们有一群不同种类的小动物,它们的体重分布有点复杂。比如,我们有兔子、猫、和狗,它们的体重在不同的范围内,但数据混在一起,我们不知道每个动物属于哪一类。我们能看到的只是每个动物的体重,但不知道它们是兔子、猫还是狗。

任务:用体重把这些动物分成不同的类别。

假设我们记录了很多动物的体重,这些数据可能会呈现三种不同的模式:一种模式对应较轻的兔子,另一种对应中等体重的猫,还有一种对应体重较大的狗。但是,我们不知道哪只动物是兔子、哪只动物是猫或狗。我们只知道这些体重数据混在一起。

这时候,高斯混合模型就派上用场了。它会假设这些体重数据是由三种不同的高斯分布混合起来的。我们的任务是利用这些数据,推测出每个动物属于哪个高斯分布(也就是哪一类:兔子、猫或狗)。

高斯混合模型怎么工作的?

  1. 初始化:首先,假设我们的数据点(动物的体重)来自几个不同的类别(兔子、猫、狗),并且为每个类别初始化一个高斯分布(其实就是初步猜测每个类别的平均体重和体重的波动范围)。

  2. 计算概率:然后,模型会计算每个数据点(每个动物的体重)分别属于不同类别的概率。例如,一个体重3公斤的动物,它更有可能是兔子,而不是狗。通过计算,模型能推断出每个动物分别属于兔子、猫、狗的概率。

  3. 更新参数:接下来,模型会根据这些概率来重新调整每个类别的高斯分布参数,找出更合适的平均值和标准差(波动范围),从而更好地解释数据。

  4. 重复迭代:通过不断重复这两个步骤(计算概率和更新参数),模型会逐渐收敛,最终把所有数据点归类到合适的类别。

经过模型的训练,最终我们得到三个类别,它们各自对应兔子、猫、和狗。每个动物被归类到不同的类别中,虽然我们一开始并不知道它们的真实类别,但通过模型的学习,我们可以比较准确地将它们分类。

你也可以把高斯混合模型想象成一个「盲人分水果」的过程。假设你有一篮子混在一起的苹果、梨和橙子(动物的体重数据),你看不见它们(不知道它们的类别),但你可以根据它们的大小和形状去猜测。高斯混合模型就是通过不停地猜测和调整,最终把苹果、梨和橙子分成三类。

总的来说,高斯混合模型就是在不知道类别的情况下,通过猜测数据点的概率分布,把它们分成几个类别。它假设数据来自多个高斯分布,并通过不断迭代,最终将这些数据点归类到不同的高斯分布中。

下面,咱们从高斯混合模型的原理和案例再来详细聊聊~

1. 高斯混合模型的公式推导

高斯分布公式

首先,我们需要了解单变量的高斯分布公式,它描述了在某个均值  和方差  下,某个数据点  出现的概率:

多元高斯分布公式为:

其中:

  •   -维的特征向量
  •  是均值向量
  •  是协方差矩阵

高斯混合模型

高斯混合模型(GMM) 是由多个高斯分布混合而成的。假设数据集是由  个高斯分布组成的混合模型,那么给定数据点 ,它的概率分布可以表示为每个分布的加权和:

其中:

  •  是第  个高斯分布的权重,且 
  •  表示第  个高斯分布
  •  是模型的所有参数:均值 ,方差 ,以及混合系数 

EM算法推导

由于我们不知道每个点属于哪个高斯分布,因此 GMM 采用 EM算法(期望最大化算法)来迭代估计参数。

1. E 步(期望步):计算后验概率,即每个点  属于第  个分布的概率:

其中  是点  属于第  个高斯分布的概率。

2. M 步(最大化步):根据 E 步的结果,更新 GMM 的参数:

通过多次迭代,EM算法可以让这些参数收敛到合适的值。

2. 案例实现

我们会从 Kaggle 下载一个数据集,使用 GMM 对数据进行分类。为了演示原理,我们使用一个简单的二维数据集。并且根据原理进行代码的编写。

数据集下载

选择一个简单的、适合分类的二维数据集,例如 Kaggle 上的 Iris 数据集  Mall Customers 数据集。我们将以 Mall Customers 为例,用顾客的年收入和消费得分来进行聚类分析。

数据集获取:点击名片,回复「数据集」即可!~

数据分析及可视化

我们要画出 4 个及以上的分析图,来逐步理解数据和模型效果。

Python 实现

下面是 Mall Customers 数据集 的完整 Python 实现:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# 读取数据集
data = pd.read_csv("Mall_Customers.csv")
X = data[['Annual Income (k$)''Spending Score (1-100)']].values

# 初始化参数
def initialize_params_fixed(X, K):
    n, d = X.shape
    pi = np.ones(K) / K  # 初始化每个混合成分的权重
    mu = X[np.random.choice(n, K, False), :]  # 随机选择K个初始均值
    sigma = np.array([np.eye(d) for _ in range(K)])  # 初始化协方差矩阵为单位矩阵
    return pi, mu, sigma

# 计算多元正态分布
def multivariate_gaussian(X, mu, sigma):
    return multivariate_normal(mean=mu, cov=sigma).pdf(X)

# E 步:计算每个点属于每个成分的责任值 (gamma)
def expectation_step_stable(X, pi, mu, sigma):
    N = X.shape[0]
    K = len(pi)
    gamma = np.zeros((N, K))
    
    for k in range(K):
        try:
            gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
        except np.linalg.LinAlgError:
            # 如果协方差矩阵是奇异矩阵,加入微小正则化项以确保正定性
            sigma[k] += np.eye(X.shape[1]) * 1e-6
            gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
    
    # 防止零除错误,保证数值稳定性
    gamma_sum = np.sum(gamma, axis=1, keepdims=True)
    gamma_sum[gamma_sum == 0] = 1e-10  # 防止除以零
    gamma = gamma / gamma_sum
    
    return gamma

# M 步:更新GMM的参数
def maximization_step(X, gamma):
    N, d = X.shape
    K = gamma.shape[1]
    
    Nk = np.sum(gamma, axis=0)  # 计算每个聚类的总责任值
    pi = Nk / N  # 更新混合系数
    mu = np.dot(gamma.T, X) / Nk[:, np.newaxis]  # 更新均值
    
    sigma = np.zeros((K, d, d))  # 更新协方差矩阵
    for k in range(K):
        X_centered = X - mu[k]
        gamma_diag = np.diag(gamma[:, k])
        sigma[k] = np.dot(X_centered.T, np.dot(gamma_diag, X_centered)) / Nk[k]
    
    return pi, mu, sigma

# 计算对数似然
def compute_log_likelihood(X, pi, mu, sigma):
    N = X.shape[0]
    K = len(pi)
    log_likelihood = 0
    
    for n in range(N):
        tmp = 0
        for k in range(K):
            tmp += pi[k] * multivariate_gaussian(X[n], mu[k], sigma[k])
        log_likelihood += np.log(tmp)
    
    return log_likelihood

# GMM 实现,包含数值稳定性修复
def gmm_fixed_stable(X, K, max_iter=100, tol=1e-6):
    pi, mu, sigma = initialize_params_fixed(X, K)
    log_likelihoods = []
    
    for i in range(max_iter):
        # E 步
        gamma = expectation_step_stable(X, pi, mu, sigma)
        
        # M 步
        pi, mu, sigma = maximization_step(X, gamma)
        
        # 添加小的正则化项,确保协方差矩阵为正定
        sigma += np.eye(sigma.shape[1]) * 1e-6
        
        # 计算对数似然
        log_likelihood = compute_log_likelihood(X, pi, mu, sigma)
        log_likelihoods.append(log_likelihood)
        
        # 检查是否收敛
        if i > 0 and abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
            break
    
    return pi, mu, sigma, log_likelihoods, gamma

# 数据可视化:原始数据分布
def plot_original_data(X):
    plt.scatter(X[:, 0], X[:, 1], c='blue', label='Data points', alpha=0.5)
    plt.title('Original Data Distribution')
    plt.xlabel('Annual Income (k$)')
    plt.ylabel('Spending Score (1-100)')
    plt.show()

# 分类结果展示
def plot_clusters(X, gamma, mu):
    K = gamma.shape[1]
    colors = ['r''g''b''y''m']
    
    for k in range(K):
        plt.scatter(X[:, 0], X[:, 1], c=gamma[:, k], cmap='viridis', label=f'Cluster {k+1}', alpha=0.6)
    
    plt.scatter(mu[:, 0], mu[:, 1], c='black', marker='x', s=100, label='Centroids')
    plt.title('GMM Clustering')
    plt.xlabel('Annual Income (k$)')
    plt.ylabel('Spending Score (1-100)')
    plt.legend()
    plt.show()

# 对数似然收敛图
def plot_log_likelihood(log_likelihoods):
    plt.plot(log_likelihoods)
    plt.title('Log Likelihood Convergence')
    plt.xlabel('Iterations')
    plt.ylabel('Log Likelihood')
    plt.show()

# 各类别概率分布图
def plot_probability_distributions(gamma):
    K = gamma.shape[1]
    for k in range(K):
        plt.hist(gamma[:, k], bins=20, alpha=0.5, label=f'Cluster {k+1}')
    
    plt.title('Probability Distributions for Each Cluster')
    plt.xlabel('Probability')
    plt.ylabel('Number of Points')
    plt.legend()
    plt.show()

# 运行 GMM 算法
K = 3  # 假设数据有 3 个聚类
pi, mu, sigma, log_likelihoods, gamma = gmm_fixed_stable(X, K)

# 绘制图形
plot_original_data(X)  # 原始数据分布图
plot_clusters(X, gamma, mu)  # 分类结果图
plot_log_likelihood(log_likelihoods)  # 对数似然收敛图
plot_probability_distributions(gamma)  # 各类别概率分布图

代码部分细节解释:

  • 我们首先初始化参数,包括每个高斯分布的均值、协方差矩阵和权重。
  • 通过 EM 算法迭代更新参数,直到对数似然值收敛。
  • 最后,使用可视化展示分类结果和模型收敛情况。

1. 原始数据分布图:展示客户年收入和消费得分的散点图,帮助我们直观理解数据分布情况。

2. 分类结果图:展示 GMM 分类后每个客户所属的类别,以及每个类别的均值点(质心)。

3. 对数似然收敛图:展示对数似然值的收敛过程,判断模型是否收敛。

3. 各类别概率分布图:展示不同类别的概率分布,帮助理解分类的置信度。

通过高斯混合模型(GMM)的推导与 Python 实现,咱们完成了从基础原理到实际应用的完整过程。有问题,大家可以评论区讨论~

最后

大家有问题可以直接在评论区留言即可~

喜欢本文的朋友可收藏、点赞、转发起来!

需要本文PDF的同学,扫码备注「基础算法」即可~ 
关注本号,带来更多算法干货实例,提升工作学习效率!
最后,给大家准备了《机器学习学习小册》PDF版本16大块的内容,124个问题总结
100个超强算法模型,大家如果觉得有用,可以点击查看~

推荐阅读

原创、超强、精华合集
100个超强机器学习算法模型汇总
机器学习全路线
机器学习各个算法的优缺点
7大方面,30个最强数据集
6大部分,20 个机器学习算法全面汇总
铁汁,都到这了,别忘记点赞呀~

深夜努力写Python
Python、机器学习算法
 最新文章