Stable Diffusion 3 论文及源码概览

文摘   2024-07-13 23:30   上海  

近期,最受开源社区欢迎的文生图模型 Stable Diffusion 的最新版本 Stable Diffusion 3 开放了源码和模型参数。开发者宣称,Stable Diffusion 3 使用了全新的模型结构和文本编码方法,能够生成更符合文本描述且高质量的图片。得知 Stable Diffusion 3 开源后,社区用户们纷纷上手测试,在网上分享了许多测试结果。而在本文中,我将面向之前已经熟悉 Stable Diffusion 的科研人员,快速讲解 Stable Diffusion 3 论文的主要内容及其在 Diffusers 中的源码。对于 Stable Diffusion 3 中的一些新技术,我并不会介绍其细节,而是会讲清其设计动机并指明进一步学习的参考文献。

内容索引

本文会从多个角度简单介绍 SD3,具体要介绍的方面如下所示。读者可以根据自己的需求,跳转到感兴趣的部分阅读。

流匹配原理简介

流匹配是一种定义图像生成目标的方法,它可以兼容当前扩散模型的训练目标。流匹配中一个有代表性的工作是整流 (rectified flow),它也正是 SD3 用到的训练目标。我们会在本文中通过简单的可视化示例学习流匹配的思想。

SD3 中的 DiT

我们会从一个简单的类 ViT 架构开始,学习 SD3 中的去噪网络 DiT 模型是怎么一步一步搭起来的。读者不需要提前学过 DiT,只需要了解 Transformer 的结构,并大概知道视觉任务里的 Transformer 会做哪些通用的修改(如图块化),即可学懂 SD3 里的 DiT。

SD3 模型与训练策略改进细节

除了将去噪网络从 U-Net 改成 DiT 外,SD3 还在模型结构与训练策略上做了很多小改进:

  • 改变训练时噪声采样方法
  • 将一维位置编码改成二维位置编码
  • 提升 VAE 隐空间通道数
  • 对注意力 QK 做归一化以确保高分辨率下训练稳定

本文会简单介绍这些改进。

大型消融实验

对于想训练大型文生图模型的开发者,SD3 论文提供了许多极有价值的大型消融实验结果。本文会简单分析论文中的两项实验结果:各训练目标在文生图任务中的表现、SD3 的参数扩增实验结果。

SD3 Diffusers 源码解读

本文会介绍如何配置 Diffusers 环境以用代码运行 SD3,并简单介绍相比于 SD,SD3 的采样代码和模型代码有哪些变动。

论文阅读

核心贡献

介绍 Stable Diffusion 3 (SD3) 的文章标题为 Scaling Rectified Flow Transformers for High-Resolution Image Synthesis。与其说它是一篇技术报告,更不如说它是一篇论文,因为它确实是按照撰写学术论文的一般思路,将正文的叙述重点放到了方法的核心创新点上,而没有过多叙述工程细节。正如其标题所示,这篇文章的内容很简明,就是用整流 (rectified flow) 生成模型、Transformer 神经网络做了模型参数扩增实验,以实现高质量文生图大模型。

由于这是一篇实验主导而非思考主导的文章,论文的开头没有太多有价值的内容。从我们读者学习论文的角度,文章的核心贡献如下:

从方法设计上:

  • 首次在大型文生图模型上使用了整流模型。
  • 用一种新颖的 Diffusion Transformer (DiT) 神经网络来更好地融合文本信息。
  • 使用了各种小设计来提升模型的能力。如使用二维位置编码来实现任意分辨率的图像生成。

从实验上:

  • 开展了一场大规模、系统性的实验,以验证哪种扩散模型/整流模型的学习目标最优。
  • 开展了扩增模型参数的实验 (scaling study),以证明提升参数量能提升模型的效果。

整流模型简介

由于 SD3 最后用了整流模型来建模图像生成,所以文章是从一种称为流匹配 (Flow Matching) 的角度而非更常见的扩散模型的角度来介绍各种训练目标。鉴于 SD3 并没有对其他论文中提出的整流模型做太多更改,我们在阅读本文时可以主要关注整流的想法及其与扩散模型的关系,后续再从其他论文中学习整流的具体原理。在此,我们来大致认识一下流匹配与整流的想法。

所谓图像生成,其实就是让神经网络模型学习一个图像数据集所表示的分布,之后从分布里随机采样。比如我们想让模型生成人脸图像,就是要让模型学习一个人脸图像集的分布。为了直观理解,我们可以用二维点来表示一张图像的数据。比如在下图中我们希望学习红点表示的分布,即我们希望随机生成点,生成的点都落在红点处,而不是落在灰点处。

我们很难表示出一个适合采样的复杂分布。因此,我们会把学习一个分布的问题转换成学习一个简单好采样的分布到复杂分布的映射。一般这个简单分布都是标准正态分布。如下图所示,我们可以用简单的算法采样在原点附近的来自标准正态分布的蓝点,我们要想办法得到蓝点到红点的映射方法。

学习这种映射依然是很困难的。而近年来包括扩散模型在内的几类生成模型用一种巧妙的方法来学习这种映射:从纯噪声(标准正态分布里的数据)到真实数据的映射很难表示,但从真实数据到纯噪声的逆映射很容易表示。所以,我们先人工定义从图像数据集到噪声的变换路线(红线),再让模型学习逆路线(蓝线)。让噪声数据沿着逆路线走,就实现了图像生成。

我们又可以用一种巧妙的方法间接学习图像生成路线。知道了预定义的数据到噪声的路线后,我们其实就知道了数据在路线上每一位置的速度(红箭头)。那么,我们可以以每一位置的反向速度(蓝箭头)为真值,学习噪声到真实数据的速度场。这样的学习目标被称为流匹配。

对于不同的扩散模型及流匹配模型,其本质区别在于图像到噪声的路线的定义方式。在扩散模型中,图像到噪声的路线是由一个复杂的公式表示的。而整流模型将图像到噪声的路线定义为了直线。比如根据论文的介绍,整流中 时刻数据 由真实图像 变换成纯噪声 的位置为:

而较先进的扩散模型 EDM 提出的路线公式为( 是一个形式较为复杂的变量):

由于整流最后学习出来的生成路线近乎是直线,这种模型在设计上就支持少步数生成。

虽然整流模型是这样宣传的,但实际上 SD3 还是默认用了 28 步来生成图像。单看这篇文章,原整流论文里的很多设计并没有用上。对整流感兴趣的话,可以去阅读原论文 Flow straight and fast: Learning to generate and transfer data with rectified flow

流匹配模型和扩散模型的另一个区别是,流匹配模型天然支持 image2image 任务。从纯噪声中生成图像只是流匹配模型的一个特例。

非均匀训练噪声采样

在学习这样一种生成模型时,会先随机采样一个时刻 ,根据公式获取此时刻对应位置在生成路线上的速度,再让神经网络学习这个速度。直观上看,刚开始和快到终点的路线很好学,而路线的中间处比较难学。因此,在采样时刻 时,SD3 使用了一种非均匀采样分布。

如下图所示,SD3 主要考虑了两种公式: mode(左)和 logit-norm (右)。二者的共同点是中间多,两边少。mode 相比 logit-norm,在开始和结束时概率不会过分接近 0。

网络整体架构

以上内容都是和训练相关的理论基础,下面我们来看多数用户更加熟悉的文生图架构。

从整体架构上来看,和之前的 SD 一样,SD3 主要基于隐扩散模型(latent diffusion model, LDM)。这套方法是一个两阶段的生成方法:先用一个 LDM 生成隐空间低分辨率的图像,再用一个自编码器把图像解码回真实图像。

扩散模型 LDM 会使用一个神经网络模型来对噪声图像去噪。为了实现文生图,该去噪网络会以输入文本为额外约束。相比之前多数扩散模型,SD3 的主要改进是把去噪模型的结构从 U-Net 变为了 DiT。

DiT 的论文为 Scalable Diffusion Models with Transformers。如果只是对 DiT 的结构感兴趣的话,可以去直接通过读 SD3 的源码来学习。读 DiT 论文时只需要着重学习 AdaLayerNormZero 模块。

提升自编码器通道数

在当时设计整套自编码器 + LDM 的生成架构时,SD 的开发者并没有仔细改进自编码器,用了一个能把图像下采样 8 倍,通道数变为 4 的隐空间图像。比如输入 的图像会被自编码器编码成 。而近期有些工作发现,这个自编码器不够好,提升隐空间的通道数能够提升自编码器的重建效果。因此,SD3 把隐空间图像的通道数从 改为了

多模态 DiT (MM-DiT)

SD3 的去噪模型是一个 Diffusion Transformer (DiT)。如果去噪模型只有带噪图像这一种输入的话,DiT 则会是一个结构非常简单的模型,和标准 ViT 一样:图像过图块化层 (Patching) 并与位置编码相加,得到序列化的数据。这些数据会像标准 Transformer 一样,经过若干个子模块,再过反图块层得到模型输出。DiT 的每个子模块 DiT-Block 和标准 Transformer 块一样,由 LayerNorm, Self-Attention, 一对一线性层 (Pointwise Feedforward, FF) 等模块构成。

图块化层会把 个像素打包成图块,反图块化层则会把图块还原回像素。

然而,扩散模型中的去噪网络一定得支持带约束生成。这是因为扩散模型约束于去噪时刻 。此外,作为文生图模型,SD3 还得支持文本约束。DiT 及本文的 MM-DiT 把模型设计的重点都放在了处理额外约束上。

我们先看一下模块是怎么处理较简单的时刻约束的。此处,如下图所示,SD3 的模块保留了 DiT 的设计,用自适应 LayerNorm (Adaptive LayerNorm, AdaLN) 来引入额外约束。具体来说,过了 LayerNorm 后,数据的均值、方差会根据时刻约束做调整。另外,过完 Attention 层或 FF 层后,数据也会乘上一个和约束相关的系数。

我们再来看文本约束的处理。文本约束以两种方式输入进模型:与时刻编码拼接、在注意力层中融合。具体数据关联细节可参见下图。如图所示,为了提高 SD3 的文本理解能力,描述文本 ("Caption") 经由三种编码器编码,得到两组数据。一组较短的数据会经由 MLP 与文本编码加到一起;另一组数据会经过线性层,输入进 Transformer 的主模块中。

将约束编码与时刻编码相加是一种很常见的做法。此前 U-Net 去噪网络中处理简单约束(如 ImageNet 类型约束)就是用这种方法。

SD3 的 DiT 的子模块结构图如下所示。我们可以分几部分来看它。先看时刻编码 的那些分支。和标准 DiT 子模块一样, 通过修改 LayerNorm 后数据的均值、方差及部分层后的数据大小来实现约束。再看输入的图像编码 和文本编码 。二者以相同的方式做了 DiT 里的 LayerNorm, FF 等操作。不过,相比此前多数基于 DiT 的模型,此模块用了一种特殊的融合注意力层。具体来说,在过注意力层之前, 对应的 会分别拼接到一起,而不是像之前的模型一样, 来自图像, 来自文本。过完注意力层,输出的数据会再次拆开,回到原本的独立分支里。由于 Transformer 同时处理了文本、图像的多模态信息,所以作者将模型取名为 MM-DiT (Multimodal DiT)。

论文里讲:「这个结构可以等价于两个模态各有一个 Transformer,但是在注意力操作时做了拼接,使得两种表示既可以在独自的空间里工作也可以考虑到另一个表示。」然而,我不太喜欢这种尝试去凭空解读神经网络中间表示的表述。仅从数据来源来看,过了一个注意力层后,图像信息和文本信息就混在了一起。你很难说,也很难测量,之后的 主要是图像信息, 主要是文本信息。只能说 都蕴含了多模态的信息。之前 SD U-Net 里的 可以认为是分别包含了图像信息和文本信息,因为之前的 保留了二维图像结构,而 仅由文本信息决定。

比例可变的位置编码

此前多数方法在使用类 ViT 架构时,都会把图像的图块从左上到右下编号,把二维图块拆成一维序列,再用这种一维位置编码来对待图块。

这样做有一个很大的坏处:生成的图像的分辨率是无法修改的。比如对于上图,假如采样时输入大小不是 ,而是 ,那么 号图块的下面就是 而不是 了,模型训练时学习到的图块之间的位置关系全部乱套。

解决此问题的方法很简单,只需要将一维的编码改为二维编码。这样 Transformer 就不会搞混二维图块间的关系了。

SD3 的 MM-DiT 一开始是在 固定分辨率上训练的。之后在高分辨率图像上训练时,开发者用了一些巧妙的位置编码设置技巧,让不同比例的高分辨率图像也能共享之前学到的这套位置编码。详细公式请参见原论文。

训练数据预处理

看完了模块设计,我们再来看一下 SD3 在训练中的一些额外设计。在大规模训练前,开发者用三个方式过滤了数据:

  1. 用了一个 NSFW 过滤器过滤图片,似乎主要是为了过滤色情内容。
  2. 用美学打分器过滤了美学分数太低的图片。
  3. 移除了看上去语义差不多的图片。

虽然开发者们自信满满地向大家介绍了这些数据过滤技术,但根据社区用户们的反馈,可能正是因为色情过滤器过分严格,导致 SD3 经常会生成奇怪的人体。

由于在训练 LDM 时,自编码器和文本编码器是不变的,因此可以提前处理好所有训练数据的图像编码和文本编码。当然,这是一项非常基础的工程技巧,不应该写在正文里的。

用 QK 归一化提升训练稳定度

按照之前高分辨率文生图模型的训练方法,SD3 会先在 的图片上训练,再在高分辨率图片上微调。然而,开发者发现,开始微调后,混合精度训练常常会训崩。根据之前工作的经验,这是由于注意力输入的熵会不受控制地增长。解决方法也很简单,只要在做注意力计算之前对 Q, K 做一次归一化就行,具体做计算的位置可以参考上文模块图中的 "RMSNorm"。不过,开发者也承认,这个技巧并不是一个长久之策,得具体问题具体分析。看来这种 DiT 模型在大规模训练时还是会碰到许多训练不稳定的问题,且这些问题没有一个通用解。

哪种扩散模型训练目标最适合文生图任务?

最后我们来看论文的实验结果部分。首先,为了寻找最好的扩散模型/流匹配模型,开发者开展了一场声势浩大的实验。实验涉及 61 种训练公式,其中的可变项有:

  • 对于普通扩散模型,考虑 - 或 -prediction,考虑线性或 cosine 噪声调度。
  • 对于整流,考虑不同的噪声调度。
  • 对于 EDM,考虑不同的噪声调度,且尽可能与整流的调度机制相近以保证可比较。

在训练时,除了训练目标公式可变外,优化算法、模型架构、数据集、采样器都不可变。所有模型在 ImageNet 和 CC12M 数据集上训练,在 COCO-2014 验证集上评估 FID 和 CLIP Score。根据评估结果,可以选出每个模型的最优停止训练的步数。基于每种目标下的最优模型,开发者对模型进行最后的排名。由于在最终评估时,仍有采样步数、是否使用 EMA 模型等可变采样配置,开发者在所有 24 种采样配置下评估了所有模型,并用一种算法来综合所有采样配置的结果,得到一个所有模型的最终排名。最终的排名结果如下面的表 1 所示。训练集上的一些指标如表 2 所示。

根据实验结果,我们可以得到一些直观的结论:整流领先于扩散模型。惊人的是,较新推出的 EDM 竟然没有战胜早期的 LDM ("eps/linear")。

当然,我个人认为,应该谨慎看待这份实验结果。一般来说,大家做图像生成会用一个统一的指标,比如 ImageNet 上的 FID。这篇论文相当于是新提出了一种昂贵的评价方法。这种评价方法是否合理,是否能得到公认还犹未可知。另外,想说明一个生成模型的拟合能力不错,用 ImageNet 上的 FID 指标就足够有说服力了,大家不会对一个简单的生成模型有太多要求。然而,对于大型文生图模型,大家更关心的是模型的生成效果,而 FID 和 CLIP Score 并不能直接反映文生图模型的质量。因此,光凭这份实验结果,我们并不能说整流一定比之前的扩散模型要好。

会关注这份实验结果的应该都是公司里的文生图开发者。我建议体量小的公司直接参考这份实验结果,无脑使用整流来代替之前的训练目标。而如果有能力做同等级的实验的话,则不应该错过改良后的扩散模型,如最新的 EDM2,说不定以后还会有更好的文生图训练目标。

参数扩增实验结果

现在多数生成模型都会做参数扩增实验,即验证模型表现随参数量增长而增长,确保模型在资源足够的情况下可以被训练成「大模型」。SD3 也做了类似的实验。开发者用参数 来控制 MM-DiT 的大小,Transformer 块的个数为 ,且所有特征的通道数与 成正比。开发者在 的数据上训练了所有模型 500k 步,每 50k 步在 CoCo 数据集上统计验证误差。最终所有评估指标如下图所示。可以说,所有指标都表明,模型的表现的确随参数量增长而增长。更多结果请参见论文。

Diffusers 源码阅读

测试脚本

我们来阅读一下 SD3 在最流行的扩散模型框架 Diffusers 中的源码。在读源码前,我们先来跑通官方的示例脚本。

由于使用协议的限制,SD3 的环境搭起来稍微有点麻烦。首先,我们要确保 Diffuers 和 Transformers 都用的是最新版本。

pip install --upgrade diffusers transformers

之后,我们要注册 HuggingFace 账号,再在 SD3 的模型网站 https://huggingface.co/stabilityai/stable-diffusion-3-medium 里确认同意某些使用协议。之后,我们要设置 Access Token。具体操作如下所示,先点右上角的 "settings",再点左边的 "Access Tokens",创建一个新 token。将这个 token 复制保存在本地后,点击 token 右上角选项里的 "Edit Permission",在权限里开启 "... public gated repos ..."。

最后,我们用命令行登录 HuggingFace 并使用 SD3。先用下面的命令安装 HuggingFace 命令行版。

pip install -U "huggingface_hub[cli]"

再输入 huggingface-cli login,命令行会提示输入 token 信息。把刚刚保存好的 token 粘贴进去,即可完成登录。

huggingface-cli login

Enter your token (input will not be visible): 在这里粘贴 token

做完准备后,我们就可以执行下面的测试脚本了。注意,该脚本会自动下载模型,我们需要保证当前环境能够访问 HuggingFace。执行完毕后,生成的 大小的图片会保存在 tmp.png 里。

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
    "A cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]

image.save('tmp.png')

我得到的图片如下所示。看起来 SD3 理解文本的能力还是挺强的。

模型组件

接下来我们来快速浏览一下 SD3 流水线 StableDiffusion3Pipeline 的源码。在 IDE 里使用源码跳转功能可以在 diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 里找到该类的源码。

通过流水线的 __init__ 方法,我们能知道 SD3 的所有组件。组件包括自编码器 vae, MM-DiT Transformer, 流匹配噪声调度器 scheduler,以及三个文本编码器。每个编码器由一个 tokenizer 和一个 text encoder 组成.

def __init__(
    self,
    transformer: SD3Transformer2DModel,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer_2: CLIPTokenizer,
    text_encoder_3: T5EncoderModel,
    tokenizer_3: T5TokenizerFast,
)
:

vae 的用法和之前 SD 的一模一样,编码时用 vae.encode 并乘 vae.config.scaling_factor,解码时除以 vae.config.scaling_factor 并用 vae.decode

文本编码器的用法可以参见 encode_prompt 方法。文本会分别过各个编码器的 tokenizer 和 text encoder,得到三种文本编码,并按照论文中的描述拼接成两种约束信息。这部分代码十分繁杂,多数代码都是在处理数据形状,没有太多有价值的内容。

def encode_prompt(
        self,
        prompt,
        prompt_2,
        prompt_3,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        negative_prompt_2,
        negative_prompt_3,
        ...

    )
:

    ...

    return prompt_embeds, negative_prompt_embeds,
     pooled_prompt_embeds, negative_pooled_prompt_embeds

采样流水线

我们再来通过阅读流水线的 __call__ 方法了解 SD3 采样的过程。由于 SD3 并没有修改 LDM 的这套生成框架,其采样流水线和 SD 几乎完全一致。SD3 和 SD 的 __call__ 方法的主要区别是,生成文本编码时会生成两种编码。

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = self.encode_prompt(...)

在调用去噪网络时,那个较小的文本编码 pooled_prompt_embeds 会作为一个额外参数输入。

noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    joint_attention_kwargs=self.joint_attention_kwargs,
    return_dict=False,
)[0]

MM-DiT 去噪模型

相比之下,SD3 的去噪网络 MM-DiT 的改动较大。我们来看一下对应的 SD3Transformer2DModel 类,它位于文件 diffusers\models\transformers\transformer_sd3.py

类的构造函数里有几个值得关注的模块:二维位置编码类 PatchEmbed、组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings、主模块类 JointTransformerBlock

def __init__(...):
    ...
    self.pos_embed = PatchEmbed(...)
    self.time_text_embed = CombinedTimestepTextProjEmbeddings(...)
    ...
    self.transformer_blocks = nn.ModuleList(
          [
              JointTransformerBlock(..)
              for i in range(self.config.num_layers)
          ]
      )

类的前向传播函数 forward 里都是比较常规的操作。数据会依次经过前处理、若干个 Transformer 块、后处理。所有实现细节都封装在各个模块类里。

def forward(...):
    hidden_states = self.pos_embed(hidden_states)
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
    for index_block, block in enumerate(self.transformer_blocks):
       encoder_hidden_states, hidden_states = block(...)
    
    encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
  ...

接下来我们来看这几个较为重要的子模块。PatchEmbed 类的实现写在 diffusers/models/embeddings.py 里。这个类的实现写得非常清晰。PatchEmbed 类本身用于维护位置编码宽高、特征长度这些信息,计算位置编码的关键代码在 get_2d_sincos_pos_embed 中。get_2d_sincos_pos_embed 会生成 (0, 0), (1, 0), ... 这样的二维坐标网格,再调用 get_2d_sincos_pos_embed_from_grid 生成二维位置编码。get_2d_sincos_pos_embed_from_grid 会调用两次一维位置编码函数 get_1d_sincos_pos_embed_from_grid,也就是 Transformer 里那种标准位置编码生成函数,来分别生成两个方向的编码,最后拼接成二维位置编码。

class PatchEmbed(nn.Module):
    ...
    def forward(self, latent):
        ...
        pos_embed = get_2d_sincos_pos_embed(...)

def get_2d_sincos_pos_embed(...):
    grid_h = np.arange(...)
    grid_w = np.arange(...)
    grid = np.meshgrid(grid_w, grid_h)
    ...
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

def get_2d_sincos_pos_embed_from_grid(...):
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

组合时刻编码和文本编码模块 CombinedTimestepTextProjEmbeddings 的代码非常短。它实际上就是用通常的 Timesteps 类获取时刻编码,用一个 text_embedder 模块再次处理文本编码,最后把两个编码加起来。text_embedder 是一个线性层、激活函数、线性层构成的简单模块。

class CombinedTimestepTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        pooled_projections = self.text_embedder(pooled_projection)

        conditioning = timesteps_emb + pooled_projections

        return conditioning

class PixArtAlphaTextProjection(nn.Module):
    def __init__(...):
        ...

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

MM-DiT 的主要模块 JointTransformerBlockdiffusers/models/attention.py 文件里。这个类的代码写得比较乱。它主要负责处理 LayerNorm 及数据的尺度变换操作,具体的注意力计算由注意力处理器 JointAttnProcessor2_0 负责。两处 LayerNorm 的实现方式竟然是不一样的。

我们先简单看一下构造函数里初始化了哪些模块。代码中,norm1, ff, norm2 等模块都是普通 Transformer 块中的模块。而加了 _context 的模块则表示处理文本分支 的模块,如 norm1_context, ff_contextcontext_pre_only 表示做完了注意力计算后,还要不要给文本分支加上 LayerNorm 和 FeedForward。如前文所述,具体的注意力计算由 JointAttnProcessor2_0 负责。

class JointTransformerBlock(nn.Module):

    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        
        processor = JointAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None     

我们再来看 forward 方法。在前向传播时,图像分支和文本分支会分别过 norm1,再一起过注意力操作,再分别过 norm2ff。大概的代码如下所示,我把较复杂的 context 分支的代码略过了。

这份代码写得很不漂亮,按理说模块里两个 LayerNorm + 尺度变换 (即 Adaptive LayerNorm) 的操作是一样的,应该用同样的代码来处理。但是这个模块里 norm1AdaLayerNormZero 类,norm2LayerNorm 类。norm1 会自动做完 AdaLayerNorm 的运算,并把相关变量返回。而在 norm2 处,代码会先执行普通的 LayerNorm,再根据之前的变量手动调整数据的尺度。我们心里知道这份代码是在实现论文里那张结构图就好,没必要去仔细阅读。

def forward(
    self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
)
:

    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

    if self.context_pre_only:
        ...

    # Attention.
    attn_output, context_attn_output = self.attn(
        hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
    )

    # Process attention outputs for the `hidden_states`.
    attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = hidden_states + attn_output

    norm_hidden_states = self.norm2(hidden_states)
    norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    ff_output = self.ff(norm_hidden_states)
    ff_output = gate_mlp.unsqueeze(1) * ff_output

    hidden_states = hidden_states + ff_output
    if self.context_pre_only:
        ...

    return encoder_hidden_states, hidden_states

融合注意力的实现方法很简单。和普通的注意力计算相比,这种注意力就是把另一条数据分支 encoder_hidden_states 也做了 QKV 的线性变换,并在做注意力运算前与原来的 QKV 拼接起来。做完注意力运算后,两个数据又会拆分回去。

class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""


    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    )
 -> torch.FloatTensor:

    ...

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        ...

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

总结

在这篇文章中,我们学习了 SD3 论文及源码中的主要内容。相比于 SD,SD3 做了两项较大的改进:用整流代替原来的 DDPM 中的训练目标;将去噪模型从 U-Net 变成了能更好地处理多模态信息的 MM-DiT。SD3 还在模型结构和训练目标上做了许多小改进,如调整训练噪声采样分布、使用二维位置编码。SD3 论文展示了多项大型消融实验的结果,证明当前的 SD3 是以最优配置训练得到的。SD3 可以在 Diffusers 中使用。当然,由于 SD3 的使用协议较为严格,我们需要做一些配置,才能在代码中使用 SD3。SD3 的采样流水线基本没变,原来 SD 的多数编辑方法能够无缝迁移过来。而 SD3 的去噪模型变动较大,和 U-Net 相关的编辑方法则无法直接用过来。在学习源码时,主要值得学习的是新 MM-DiT 模型中每个 Transformer 层的实现细节。

尽管 SD3 并没有提出新的流匹配方法,但其实验结果表明流匹配模型可能更适合文生图任务。作为研究者,受此启发,我们或许需要关注一下整流等流匹配模型,知道它们的思想,分析它们与原扩散模型训练目标的异同,以拓宽自己的视野。


天才程序员周弈帆
NTU MMLab 在读博士生,ACM金牌选手的个人博客。主要分享深度学习、算法教程。放眼全世界,几乎没有比我讲得更易懂、亲民的人,不信你去读读看。