从零训练一个Vision Transformer

文摘   2024-09-10 23:35   北京  

大语言模型书籍已经发布,欢迎关注、购买。

新书发布:大语言模型原理、训练及应用


本文涉及到的详细测试代码和测试步骤放置于:

https://github.com/xinyuwei-david/david-share.git下的:Multimodal-Models/Training-VIT,本文中不再赘述代码实现。

欢迎给repo点亮Star,您的点赞是作者持续创作的动力。

截止到目前,CV模型主要是基于卷积神经网络。而随着Transformer的兴起,VISION TRANSFORMER也逐渐被应用。

接下来,我们看一下主流的CV实现,以及它们的特点。

U-Net

  • 特点:编码器-解码器结构,跳跃连接。

  • 网络类型:卷积神经网络(CNN)。

  • 适用场景:图像分割,医学图像处理。

  • 优点:高效处理分割任务,保留细节。

  • 缺点:对大规模数据集的扩展性有限。

  • 使用情况:广泛用于医学图像分割。

  • 主流模型:原始U-Net,3D U-Net,Stable Diffusion。

R-CNN

  • 特点:选择性搜索生成候选区域。

  • 网络类型:基于CNN。

  • 适用场景:目标检测。

  • 优点:检测精度高。

  • 缺点:计算复杂度高,速度慢。

  • 使用情况:已被更快的模型(如Faster R-CNN)取代。

  • 主流模型:Fast R-CNN,Faster R-CNN。

GAN

  • 特点:生成器和判别器对抗训练。

  • 网络类型:框架,通常使用CNN。

  • 适用场景:图像生成,风格迁移。

  • 优点:生成高质量图像。

  • 缺点:训练不稳定,易于模式崩溃。

  • 使用情况:广泛用于生成任务。

  • 主流模型:DCGAN,StyleGAN。

RNN/LSTM

  • 特点:处理序列数据,记忆长时依赖。

  • 网络类型:循环神经网络。

  • 适用场景:时间序列预测,视频分析。

  • 优点:适合处理序列数据。

  • 缺点:训练困难,梯度消失。

  • 使用情况:在序列任务中常用。

  • 主流模型:LSTM,GRU。

GNN

  • 特点:处理图结构数据。

  • 网络类型:图神经网络。

  • 适用场景:社交网络分析,化学分子建模。

  • 优点:捕捉图结构信息。

  • 缺点:对大图的扩展性有限。

  • 使用情况:在图数据任务中使用。

  • 主流模型:GCN,GraphSAGE。

Capsule Networks

  • 特点:胶囊结构,捕捉空间层次关系。

  • 网络类型:基于CNN。

  • 适用场景:图像识别。

  • 优点:捕捉姿态变化。

  • 缺点:计算复杂度高。

  • 使用情况:研究阶段,未广泛应用。

  • 主流模型:Dynamic Routing。

Autoencoder

  • 特点:编码器-解码器结构。

  • 网络类型:可以基于CNN。

  • 适用场景:降维,特征学习。

  • 优点:无监督学习。

  • 缺点:生成质量有限。

  • 使用情况:用于特征提取和降维。

  • 主流模型:Variational Autoencoder (VAE)。

Vision Transformer (ViT)

  • 特点:基于自注意力机制,处理图像块。

  • 网络类型:Transformer。

  • 适用场景:图像分类。

  • 优点:捕捉全局信息。

  • 缺点:需要大量数据进行训练。

  • 使用情况:逐渐流行,尤其在大数据集上。

  • 主流模型:原始ViT,DeiT。


根据论文:《UNDERSTANDING THE EFFICACY OF U-NET & VISION TRANSFORMER FOR GROUNDWATER NUMERICAL MODELLING》

U-Net在效率上通常优于ViT,特别是在稀疏数据场景下。U-Net的架构较为简单,参数较少,因此在计算资源和时间上更为高效。ViT虽然在捕捉全局信息上有优势,但其自注意力机制计算复杂度较高,尤其在处理大规模数据时。

在论文的实验中,U-Net和U-Net结合ViT的模型在准确性和效率上都优于Fourier Neural Operator (FNO),特别是在数据稀疏的情况下。

在图像处理中,稀疏数据通常指的是图像中的信息不完整或分布不均匀。例如:

  1. 低分辨率图像:像素较少,细节缺失。

  2. 遮挡或缺失:部分区域被遮挡或数据缺失。

  3. 采样不均匀:某些区域的像素密度较低。

    在这些情况下,模型需要从有限的像素信息中推断出完整的图像内容。

 Vision Transformer出现之后,出现了新的分和流派:

  • Facebook AI 的DeiTData-efficient Image Transformers)。DeiT 模型是经过精炼的 ViT models,。DeiT 的作者还发布了训练效率更高的 ViT 模型,您可以将其直接插入ViTModel或 ViTForImageClassification。有 4 种变体可用(3 种不同大小):facebook/deit-tiny-patch16-224、 facebook/deit-small-patch16-224facebook/deit-base-patch16-224facebook/deit-base-patch16-384。请注意,应使用DeiTImageProcessor为模型准备图像。

  • 微软研究院的BEiTBERT pre-training of Image Transformers)。BEiT 模型使用受 BERT(蒙版图像建模)启发并基于 VQ-VAE 的自监督方法,其表现优于监督预训练vision transformers 

  • Facebook AI 的 DINO(一种 Vision Transformers 的自监督训练方法)。使用 DINO 方法训练的 Vision Transformers 表现出卷积模型所没有的非常有趣的特性。它们能够分割物体,而无需接受过这样的训练。DINO 检查点可以在 hub 上找到

  • Facebook AI 的MAE(蒙版自动编码器)。通过预训练 Vision Transformers 来重建大部分(75%)蒙版块的像素值(使用非对称编码器-解码器架构),作者表明这种简单的方法在微调后优于监督预训练。

下面这张图描述了Vision Transformer(ViT)的工作流程:

  1. 图像分块:输入图像被分割成固定大小的小块(patches)。

  2. 线性投影:每个图像块被展平并通过线性投影转换为向量。

  3. 位置嵌入:为每个图像块添加位置嵌入,以保留位置信息。

  4. CLS标记:在序列的开头添加一个可学习的CLS标记,用于分类任务。

  5. Transformer编码器:这些嵌入向量(包括CLS标记)被输入到Transformer编码器中,进行多层处理。每层包括多头注意力机制和前馈神经网络。

  6. MLP头:经过编码器处理后,CLS标记的输出被送入多层感知机(MLP)头,用于最终的分类决策。

    整个流程展示了如何使用Transformer架构直接处理图像块序列,实现图像分类任务。

这张图展示了Transformer编码器中的注意力层(Attention Layer)的结构,具体包括:

  1. 多头注意力机制(Multi-Head Attention):这是Transformer的核心组件。通过多个注意力头,模型可以在不同的子空间中关注输入序列的不同部分,从而捕捉更丰富的特征和关系。

  2. 归一化(Norm):在多头注意力机制之后进行归一化处理,帮助稳定和加速训练过程。

  3. 残差连接:注意力层中有残差连接,将输入直接加到输出上,促进信息流动并缓解梯度消失问题。

    这些组件共同作用,使得Transformer能够高效地处理和理解输入数据的复杂关系。··

残差连接(Residual Connections)是一种网络结构设计,用于缓解深层神经网络中的梯度消失问题。它通过在层与层之间添加直接的跳跃连接,将输入直接加到输出上。这种设计使得网络更容易训练,因为它允许梯度直接通过跳跃连接传播,从而保持信息流动。残差连接最初在ResNet中引入,并在许多现代深度学习模型中广泛使用。

参考:

https://huggingface.co/docs/transformers/model_doc/vit

https://arxiv.org/pdf/2307.04010

大魏分享
https://github.com/davidsajare/david-share.git
 最新文章