机器学习数据降维与可视化:t-SNE详解与实践【附代码】

科技   2024-11-16 12:01   英国  

大家好,我是章北海

最近在看了几篇数据降维相关文章,顺便总结记录一下。

在机器学习和数据挖掘领域,经常面临高维(很多特征或属性)数据的挑战。

高维数据不仅在存储和计算上带来困难,更重要的是,我们很难直观地理解高维空间中数据点的分布和结构。

因此,降维成为了一项重要的数据预处理任务。

什么是降维?顾名思义,就是将高维数据转换到低维空间 (通常是二维或三维) 中,同时尽量保持数据点之间的内在结构。

常见的降维方法有 PCA、、LDA、LLE、Isomap 等。

而今天我们要重点介绍的是 t-SNE (t-distributed Stochastic Neighbor Embedding)。

t-SNE 由 Laurens van der Maaten 和 Geoffrey Hinton 在 2008 年提出,特别适合将高维数据降维并可视化。与 PCA 等线性降维方法不同,t-SNE 是一种非线性降维算法。

它的核心思想是:在高维空间和低维空间中,都使用条件概率来表示数据点之间的相似性,然后最小化两个条件概率分布之间的 KL 散度,从而找到最优的低维嵌入。

t-SNE 的算法流程可以简要概括为:

  1. 在高维空间中计算数据点之间的相似性 (条件概率)
  2. 在低维空间中随机初始化数据点
  3. 计算低维空间中数据点的相似性
  4. 优化目标函数 (最小化 KL 散度),更新低维空间中数据点的位置
  5. 重复步骤 3-4,直到收敛

在 Python 中实现 t-SNE 非常方便,成熟的机器学习库有 Scikit-learn 和 OpenTSNE。

Scikit-learn 提供了 t-SNE 的标准实现:

from sklearn.manifold import TSNE
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

# 加载数据
iris = load_iris()
X = iris.data
y = iris.target

# t-SNE 降维
tsne = TSNE(n_components=2, random_state=42)
X_tsne = tsne.fit_transform(X)

# 可视化
plt.figure(figsize=(88))
colors = ['red''green''blue']
for i in range(len(colors)):
    plt.scatter(X_tsne[y == i, 0], X_tsne[y == i, 1], c=colors[i], label=iris.target_names[i])
plt.legend()
plt.show()

上述代码首先从 sklearn 加载经典的 iris 数据集,然后使用 TSNE 类将 4 维特征降到 2 维。最后,我们绘制散点图,不同类别的样本用不同颜色表示。可以看到,t-SNE 很好地将三类鸢尾花样本区分开来。

然而,sklearn 的 t-SNE 实现在计算效率上还有提升空间。

这时,OpenTSNE 库就派上用场了。OpenTSNE 对 t-SNE 算法做了诸多优化,如 Barnes-Hut 近似方法,并用 C++ 重写了关键步骤,这使得 OpenTSNE 在运行速度上大幅领先于 sklearn。

使用 OpenTSNE 进行降维和可视化的代码如下:

from openTSNE import TSNE
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target

# t-SNE 降维
tsne = TSNE(
    n_components=2,
    perplexity=30,
    metric="euclidean",
    n_jobs=8,
    random_state=42,
)
X_tsne = tsne.fit(X)

# 可视化
plt.figure(figsize=(1212))
colors = plt.cm.rainbow(np.linspace(0110))
for i in range(10):
    plt.scatter(X_tsne[y == i, 0], X_tsne[y == i, 1], color=colors[i], label=str(i))
plt.legend()
plt.show()

这里我们使用手写数字数据集,它有 784 维特征 (28x28 像素)。OpenTSNE 支持多种距离度量 (如欧氏距离、余弦距离等),并可以利用多核并行加速 (n_jobs 参数)。在可视化结果中,我们发现不同数字样本被清晰地分离开,体现了 t-SNE 强大的降维和可视化能力。

理论上 openTSNE 应该比sklearn的实现运行速度要快很多的。

但是我做了一个测试,,,结果,恰恰相反。

使用经典的 MNIST 手写数字数据集,它包含 60,000 个训练样本和 10,000 个测试样本,每个样本是一个 28x28 的灰度图像。

首先,加载所需的库和数据集:

from sklearn.manifold import TSNE as SKLTSNE
from openTSNE import TSNE as OPENTSNE
from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import time

mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"]

为了公平比较,我们选取前 10000 个样本,并在 sklearn 和 openTSNE 中使用相同的参数设置:

n_samples = 10000
X_subset = X[:n_samples].astype(np.float32)

def plot_tsne(X_tsne, y, title):
    plt.figure(figsize=(88))
    colors = plt.cm.rainbow(np.linspace(0110))
    for i in range(10):
        plt.scatter(X_tsne[y == str(i), 0], X_tsne[y == str(i), 1], color=colors[i], label=str(i))
    plt.legend()
    plt.title(title)
    plt.show()

# sklearn t-SNE
start_time = time.time()
tsne_skl = SKLTSNE(n_components=2, random_state=42)
X_tsne_skl = tsne_skl.fit_transform(X_subset)
skl_time = time.time() - start_time
print(f"sklearn t-SNE took {skl_time:.2f} seconds")
plot_tsne(X_tsne_skl, y[:n_samples], "sklearn t-SNE")

# openTSNE
start_time = time.time()
tsne_open = OPENTSNE(n_components=2, random_state=42, n_jobs=8)
X_tsne_open = tsne_open.fit(X_subset)
open_time = time.time() - start_time
print(f"openTSNE took {open_time:.2f} seconds")
plot_tsne(X_tsne_open, y[:n_samples], "openTSNE")

在这个实验中,我们记录了 sklearn 和 openTSNE 运行 t-SNE 的时间,并绘制了可视化结果。

在我的机器上,输出如下:

sklearn t-SNE took 11.96 seconds
openTSNE took 65.73 seconds

看到一个说法:如果你需要处理大规模数据集(如数十万个样本或更多),OpenTSNE可能是更好的选择,因为它的性能优化可以显著加快计算速度。

如果数据集较小,scikit-learn的TSNE实现可能已经足够快了。

可能是我的Mac问题吧

更大数据规模的测试,太耗时了,就懒得再测了,感兴趣的同学可以试试。

以上

如有帮助,敬请 【在看】

模型篇P1:机器学习基本概念

迄今最好的AI代码编辑器,编程只需狂按Tab

【大模型实战,完整代码】AI 数据分析、可视化项目

108页PDF小册子:搭建机器学习开发环境及Python基础

116页PDF小册子:机器学习中的概率论、统计学、线性代数

全网最全 Python、机器学习、AI、LLM 速查表(100 余张)

Obsidian AI写作神器:一键配置DeepSeek,写作效率飙升1000%!

基于 QAnything 的知识库问答系统:技术解析与应用实践【附代码】


机器学习算法与Python实战
长期跟踪关注统计学、数据挖掘、机器学习算法、深度学习、人工智能技术与行业发展动态,分享Python、机器学习等技术文章。回复机器学习有惊喜资料。
 最新文章