理论与数学
扩散模型(diffusion model)通过一系列时间步骤 T (x₀,xₜ) 逐渐降低图像中的信息。在每个步骤中,都会添加少量高斯噪声(Gaussian noise),最终将图像转换为纯随机噪声,类似于正态分布(normal distribution)的样本,这称为前向过程(forward process)。从 xₜ₋₁ 到 xₜ 的过渡遵循这种噪声添加机制。
为了扭转这一局面,需要训练神经网络逐步消除噪声。经过训练后,模型可以从正态分布中提取的随机噪声开始。它会迭代地对输入进行去噪,每次都会消除一些噪声,直到最终结果是与原始分布相似的清晰图像。
这种方法在概念上与变分自动编码器 (variational autoencoders,VAE)相似。在 VAE 中,图像被编码为高斯分布的均值和方差,然后解码器通过从该分布中采样来重建图像。同样,扩散模型的去噪过程将随机噪声转换回连贯图像,类似于 VAE 中的重建阶段。
扩散是指分子从高浓度区域向低浓度区域移动。从统计学意义上讲,扩散过程是一种随机马尔可夫过程(stochastic Markov process),其特征是连续的样本路径。随机(Stochastic)意味着存在随机性,而马尔可夫(Markov)则表示未来状态仅取决于当前状态——了解过去的状态不会增加更多信息。连续意味着过程平稳发展,没有突然跳跃。
在统计学中,扩散描述了在同一域内将复杂分布转换为更简单的分布(通常是先验分布)。如果满足某些条件,将过渡核反复应用于任何分布的样本最终都会产生来自这种更简单的先验的样本。在扩散模型的情况下,输入图像表示复杂分布,该分布逐渐转变为简单的正态分布。
这个函数背后发生了什么?
我们假设系数为 ⍺ 和 ꞵ,两者相互独立。我们将从左侧给出的分布开始,对图像进行采样并绘制直方图。然后我们将应用一次过渡步骤(⍺=0.5,ꞵ=0.1),这将导致值发生变化,从而导致直方图发生变化。
通过应用该方程一次,我们观察到从原始状态到更随机状态的逐渐过渡。当 ⍺ = 0 和 ꞵ = 1 时,我们立即获得高斯分布。然而,目标是通过增量变化逐步实现这种转变,这定义了扩散过程。如果 ⍺ > 1,方差将不受控制地增加,从而阻止收敛到所需的分布。因此,原始值中的一些衰减是必不可少的。合适的选择是⍺ = 0.999,略低于 1。
接下来,必须逐渐引入噪声项 ꞵ。较高的 ꞵ 值会导致分布发生突然变化,而这正是我们要避免的。相反,转换应该缓慢进行,以确保随着原始分布的退化,它越来越接近正态分布。最终,我们得出了最佳值 ⍺ = √0.99 和 ꞵ = √0.01。
从数学上讲,将项 xₜ₋₁ 和 xₜ₋₂ 代入方程式可以揭示该过程中的重复模式,从而强化转变的渐进性。
对于所有时间步骤,我们最终得出一个最终方程,其中只有第一项包含 x₀。当 T 变得足够大时,这个第一项趋近于零,因为乘积中的所有值都小于 1。其余项是均值为 0 但方差不同的高斯分布。
由于这些项是独立的,因此可以将它们组合成单个高斯分布,其中总体平均值仍为 0,方差等于各个方差之和。具体来说,最后一项的方差为 β,倒数第二项的方差为 β(1-β),依此类推。这形成了一个几何级数 (GP),其中第一项为 β,公比为 1-β。因此,T 项之和可以表示为:β(1-(1-β)ₜ) / (1-(1-β)) ~ β/β = 1。
我们本质上是将原始分布结构被破坏的程度与我们引入的噪声量联系起来。扩散过程被离散化为 1D 分布的有限步骤。这种方法可以扩展到 aw×h,其中前向过程的输出为 aw×h,每个像素类似于高斯分布的样本,平均值为 0,方差为 1。
在实践中,我们不是在每个步骤中使用恒定的噪声方差,而是应用一个时间表。作者提出了一个线性时间表,随着时间的推移逐渐增加噪声方差。这种策略是合乎逻辑的,因为在逆向过程开始时,模型需要学习进行较大的调整。当它接近清晰的图像时,需要进行更小、更精细的调整。
此计划可确保方差从输入分布平滑地扩展到目标高斯分布。在噪声方差固定的情况下,分布方差的减少在早期非常显著,在大约 500 步后达到接近 1。但是,当遵循建议的计划时,方差的减少在大部分时间步骤中更加渐进和一致,从而实现更平稳的过渡。
为了获得 t=1000 时 X 的值,我们需要应用 1000 次转换来遍历整个马尔可夫链,这是效率不高的。
我们需要做这整套数学运算只是为了给输入图像添加噪声,逆过程也是一个具有高斯转移概率的马尔可夫链。我们不能直接计算它,因为我们需要计算整个数据分布,但是我们可以通过分布 P 来近似它,它被表述为高斯分布,其均值和方差是我们感兴趣的学习参数。
如果我们不能计算逆分布,我们如何训练一个近似它的模型?
在 VAE 中我们也遇到过类似的情况,给定 Z 我们不知道 X 的真实分布,但我们学习通过神经网络来近似它。所以我们想要学习 P(X | Z),以便我们能够生成尽可能接近训练数据分布的图像。
假设 P 服从高斯分布,第一项成为生成的输出与真实图像之间的重建损失。第二项表示我们的先验(均值为零、方差为单位的高斯分布)与编码器预测的分布之间的 KL 散度。这个概念可以扩展到扩散过程。
在这种情况下,我们不是直接过渡到Zₜ ,而是逐渐通过一系列潜在变量从 x₁ 移动到 xₜ。这里,q(z ∣ x) 被 q(xₜ ∣ xₜ₋₁) 取代,后者是固定的,不是学习的。和以前一样,目标是最大化观察到的数据的可能性。
期望中的项,q 是正向过程,P 是反向过程的近似值。由于链是马尔可夫的,q(xₜ ∣ xₜ₋₁) = q(xₜ ∣ xₜ₋₁, x₀)。
让我们看看分母,我们可以应用贝叶斯定理得到这个:
其余项与下限方程相加,我们可以将该方程分解为三个对数项之和。请记住,所有这些都是在 Q 的期望之下。
第一项类似于 KL 散度先验项,但在这里因为我们使用的是扩散,所以 q 是固定的,并且根据理论,最终的 q(xₜ ∣ x₀) 实际上将非常接近正态分布,所以这是无参数的,我们不会费心优化它。第二项是在给定 x₁ 的情况下对输入 x₀ 的重构,最后一项是数量之和,这些数量只不过是 KL 散度。并且由于我们想要最大化下限,因此我们希望最小化所有这些尺度项,该公式的优点在于最后一项涉及两个相同形式的量,这只是要求近似去噪过渡分布非常接近以 x₀ 为条件的真实去噪过渡分布。
我们又回到了同样的问题,因为我们不知道反向分布,所以我们该如何继续前进呢?
现在,我们用 q(xₜ₋₁ ∣ xₜ) 来代替 q(xₜ₋₁ ∣ xₜ),这在概念上更容易计算。通过应用贝叶斯定理,我们可以将其分解。此过程中的每个项都是高斯的,我们已经推导出必要的表达式。第一个项表示前向过程,尽管它以 x₀ 为条件,但它不受影响,因为前向过程遵循马尔可夫链。
另外两个项可以使用之前建立的递归来表示,这使我们能够在任何时间步骤 t 从 x₀ 过渡到噪声图像。由于这些项是高斯的,因此它们可以表示为指数形式。主要目标是计算 xₜ₋₁。为了实现这一点,我们可以将表达式重写为完全平方,从而使我们能够推导出高斯分布的方程。通过这种方法,我们将 x ² ₜ₋₁ 项、xₜ₋₁ 项和与 xₜ₋₁ 无关的所有内容分开。
到目前为止,我们已经用指数形式表示了正态分布,并使用代数求和对其进行了简化。现在,我们将重点关注最后一个项,它与 xₜ₋₁ 无关。经过进一步简化,我们发现该项分解为表达式的平方乘以 2xₜ₋₁。
这样,我们就可以把整个逆分布方程写成高斯分布,其中均值为均值,方差为方差。对于最后一项,我们需要计算 q(xₜ₋₁ | xₜ₋₁, x₀)。
我们可以看到,平均值是 xₜ 和 x₀ 的加权平均值。如果我们计算 x₀ 的权重,那么我们将看到它在较高时间步骤时的值非常低,并且随着我们接近反向过程的结束,该权重会增加。从图中我们可以看到权重长时间为 0,但是当我们绘制对数值时,我们可以看到随着时间的推移权重确实会增加。
对于逆分布过程,我们必须对其进行近似,因为对于生成,我们实际上不会有 x₀,但因为我们知道我们的逆分布是高斯分布,所以我们也可以将近似值作为高斯分布。我们需要做的就是学习均值和方差。作者做的第一件事是将方差固定为与地面实况去噪步骤完全相同。
似然函数,所有项都是 q 的期望,使其成为 KL 散度,当我们使用高斯的 KL 散度公式时,因为这里两个分布具有完全相同的方差,它最终将是均值之差的平方除以两倍方差,也就是这个量,所以我们的目标是最大化似然,我们现在需要减少差异,因为我们的基本事实去噪步骤具有均值。
虽然论文中的损失有所不同,在某些噪声方面有所不同,但这里我们还有其他东西。为什么?
回顾→在似然项中,我们忽略了第一个项,而专注于最后一个求和项。我们发现它可以写成地面真实噪声与某些模型使用 xₜ 作为输入生成的噪声预测之间的缩放平方差,实际上我们也提供了时间步长。作者实际上完全忽略了缩放,并通过实验发现仅在噪声平方差上训练模型就足够了。
通过让模型近似 x₀ 的噪声以获得 x₁,似然函数中的第二项被包裹在这个损失之下。
训练
在训练期间,我们首先从数据集中抽取图像并均匀地选择一个时间步长 t。接下来,我们从正态分布中抽取随机噪声。使用 xₜ 的 x₀ 和 ϵ 方程,我们在扩散过程中获得时间步长 t 处图像的噪声版本。然后,我们使用原始图像、采样噪声、时间步长和噪声计划来计算累积乘积项。这个噪声图像通过神经网络,我们使用损失函数训练网络以最小化预测噪声和实际噪声之间的差异。通过多步训练,我们覆盖了所有时间步长并有效地优化了求和项的每个组成部分。
对于图像生成,我们遵循相反的过程,通过从我们的神经网络学习到的去噪步骤分布 P 中迭代采样。为了生成图像,我们首先从正态分布中随机抽取一个样本作为步骤 t 的初始图像。然后将其传递给我们训练过的模型来预测噪声。需要澄清的是,我们的近似去噪分布由时间步骤 t 的噪声图像的平均值和预测的噪声定义。
一旦我们有了预测噪声,我们就可以使用均值图像和方差从分布中采样一个图像,方差固定为与前向过程相同。这个采样图像成为我们的 xₜ₋₁。我们继续重复这个过程,直到得到原始图像 x₀。唯一的变化是,要从 x1 到 x₀,我们直接返回均值图像。
PyTorch 实现 DDPM
扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。
在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。
通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。
训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。
从头开始实现
对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。
为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。
作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。
为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。
线性噪声调度器(Linear Noise Scheduler)
import torch
class LinearNoiseScheduler:
def __init__(self, num_timesteps, beta_start, beta_end):
self.num_timesteps = num_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
self.alphas = 1. - self.betas
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。
def add_noise(self, original, noise, t):
original_shape = original.shape
batch_size = original_shape[0]
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
for _ in range(len(original_shape) - 1):
sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
for _ in range(len(original_shape) - 1):
sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
# Apply and Return Forward process equation
return (sqrt_alpha_cum_prod.to(original.device) * original
+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
该add_noise()函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。
def sample_prev_timestep(self, xt, noise_pred, t):
x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
x0 = torch.clamp(x0, -1., 1.)
mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
if t == 0:
return mean, x0
else:
variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
variance = variance * self.betas.to(xt.device)[t]
sigma = variance ** 0.5
z = torch.randn(xt.shape).to(xt.device)
return mean + sigma * z, x0
调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。
对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。
这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。
无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。
对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。
时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。
UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。
大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。
在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。
为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。
另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。
时间嵌入(Time Embedding)
import torch
import torch.nn as nn
def get_time_embedding(time_steps, temb_dim):
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
# factor = 10000^(2i/d_model)
factor = 10000 ** ((torch.arange(
start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
)
# pos / factor
# timesteps B -> B, 1 -> B, temb_dim
t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
return t_emb
第一个函数为给定的时间步长get_time_embedding生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。
time_steps:时间步长值的张量(形状:[B]其中B是批次大小)。每个值代表批次元素的一个离散时间步长。
temb_dim:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。
确保这temb_dim是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。
Down Block
class DownBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim,
down_sample=True, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.down_sample = down_sample
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for i in range(num_layers)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
4, 2, 1) if self.down_sample else nn.Identity()
def forward(self, x, t_emb):
out = x
for i in range(self.num_layers):
# Resnet block of Unet
resnet_input = out
out = self.resnet_conv_first[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
# Attention block of Unet
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
out = self.down_sample_conv(out)
return out
DownBlock 类结合了ResNet 块、自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。
参数:
in_channels:输入通道数。
out_channels:输出通道数。
t_emb_dim:时间嵌入的维度。
down_sample:布尔值,确定是否在块末尾应用下采样。
num_heads:多头注意力层中的注意力头的数量。
num_layers:此块中的 ResNet + 注意力层的数量。
ResNet块:
resnet_conv_first:ResNet 块的第一个卷积层。
t_emb_layers:时间嵌入投影层。
resnet_conv_second:ResNet 块的第二个卷积层。
residual_input_conv:用于残差连接的 1x1 卷积。
自注意力模块:
attention_norms:在注意力机制之前对规范化层进行分组。
attentions:多头注意力层。
下采样:
down_sample_conv:应用卷积来减少空间维度(如果down_sample=True)。
Forward Pass 方法定义了如何x通过块处理输入张量:out初始化为输入x。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。
在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb到线性层(投影到out_channels),并将此投影时间嵌入添加到out(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input)添加到第二个卷积的输出。
在自注意力模块中,我们将空间维度扁平化为一个维度(h * w)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接和下采样。
Mid Block
class MidBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers+1)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers + 1)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers+1)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers+1)
]
)
def forward(self, x, t_emb):
out = x
# First resnet block
resnet_input = out
out = self.resnet_conv_first[0](out)
out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
out = self.resnet_conv_second[0](out)
out = out + self.residual_input_conv[0](resnet_input)
for i in range(self.num_layers):
# Attention Block
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
# Resnet Block
resnet_input = out
out = self.resnet_conv_first[i+1](out)
out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i+1](out)
out = out + self.residual_input_conv[i+1](resnet_input)
return out
该类MidBlock是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块和自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:
时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
层迭代:在注意力和ResNet 块之间交替,按num_layers这些组合的顺序处理输入。
Up Block
class UpBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.up_sample = up_sample
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
self.attention_norms = nn.ModuleList(
[
nn.GroupNorm(8, out_channels)
for _ in range(num_layers)
]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)
]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
4, 2, 1) \
if self.up_sample else nn.Identity()
def forward(self, x, out_down, t_emb):
x = self.up_sample_conv(x)
x = torch.cat([x, out_down], dim=1)
out = x
for i in range(self.num_layers):
resnet_input = out
out = self.resnet_conv_first[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
return out
该类UpBlock是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样、跳过连接、ResNet 块和自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。
上采样:通过转置卷积(ConvTranspose2d)实现,以增加特征图的空间分辨率。
跳过连接:允许解码器重用编码器的详细特征,帮助重建。
ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
自我注意力:捕获远程空间依赖关系以保留全局上下文。
时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。
UNet Architecture
class Unet(nn.Module):
def __init__(self, model_config):
super().__init__()
im_channels = model_config['im_channels']
self.down_channels = model_config['down_channels']
self.mid_channels = model_config['mid_channels']
self.t_emb_dim = model_config['time_emb_dim']
self.down_sample = model_config['down_sample']
self.num_down_layers = model_config['num_down_layers']
self.num_mid_layers = model_config['num_mid_layers']
self.num_up_layers = model_config['num_up_layers']
assert self.mid_channels[0] == self.down_channels[-1]
assert self.mid_channels[-1] == self.down_channels[-2]
assert len(self.down_sample) == len(self.down_channels) - 1
# Initial projection from sinusoidal time embedding
self.t_proj = nn.Sequential(
nn.Linear(self.t_emb_dim, self.t_emb_dim),
nn.SiLU(),
nn.Linear(self.t_emb_dim, self.t_emb_dim)
)
self.up_sample = list(reversed(self.down_sample))
self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
self.downs = nn.ModuleList([])
for i in range(len(self.down_channels)-1):
self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
down_sample=self.down_sample[i], num_layers=self.num_down_layers))
self.mids = nn.ModuleList([])
for i in range(len(self.mid_channels)-1):
self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
num_layers=self.num_mid_layers))
self.ups = nn.ModuleList([])
for i in reversed(range(len(self.down_channels)-1)):
self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))
self.norm_out = nn.GroupNorm(8, 16)
self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
def forward(self, x, t):
# Shapes assuming downblocks are [C1, C2, C3, C4]
# Shapes assuming midblocks are [C4, C4, C3]
# Shapes assuming downsamples are [True, True, False]
# B x C x H x W
out = self.conv_in(x)
# B x C1 x H x W
# t_emb -> B x t_emb_dim
t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
t_emb = self.t_proj(t_emb)
down_outs = []
for idx, down in enumerate(self.downs):
down_outs.append(out)
out = down(out, t_emb)
# down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
# out B x C4 x H/4 x W/4
for mid in self.mids:
out = mid(out, t_emb)
# out B x C3 x H/4 x W/4
for up in self.ups:
down_out = down_outs.pop()
out = up(out, down_out, t_emb)
# out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
out = self.norm_out(out)
out = nn.SiLU()(out)
out = self.conv_out(out)
# out B x C x H x W
return out
该类是U-Net 架构Unet的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样、中级处理和上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。
时间嵌入:实现时间动态。
跳过连接:通过连接将细粒度的空间细节集成到解码器中。
灵活的架构:允许通过model_config不同的深度、分辨率和功能丰富度进行定制。
规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
输出一致性:确保输出图像保留原始的空间尺寸和通道数。
Training
import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def train(args):
with open(args.config_path, 'r') as file:
try:
config = yaml.safe_load(file)
except yaml.YAMLError as exc:
print(exc)
print(config)
diffusion_config = config['diffusion_params']
dataset_config = config['dataset_params']
model_config = config['model_params']
train_config = config['train_params']
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
beta_start=diffusion_config['beta_start'],
beta_end=diffusion_config['beta_end'])
# Create the dataset
mnist = MnistDataset('train', im_path=dataset_config['im_path'])
mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)
# Instantiate the model
model = Unet(model_config).to(device)
model.train()
# Create output directories
if not os.path.exists(train_config['task_name']):
os.mkdir(train_config['task_name'])
# Load checkpoint if found
if os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):
print('Loading checkpoint as found one')
model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
train_config['ckpt_name']), map_location=device))
# Specify training parameters
num_epochs = train_config['num_epochs']
optimizer = Adam(model.parameters(), lr=train_config['lr'])
criterion = torch.nn.MSELoss()
# Run training
for epoch_idx in range(num_epochs):
losses = []
for im in tqdm(mnist_loader):
optimizer.zero_grad()
im = im.float().to(device)
# Sample random noise
noise = torch.randn_like(im).to(device)
# Sample timestep
t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
# Add noise to images according to timestep
noisy_im = scheduler.add_noise(im, noise, t)
noise_pred = model(noisy_im, t)
loss = criterion(noise_pred, noise)
losses.append(loss.item())
loss.backward()
optimizer.step()
print('Finished epoch:{} | Loss : {:.4f}'.format(
epoch_idx + 1,
np.mean(losses),
))
torch.save(model.state_dict(), os.path.join(train_config['task_name'],
train_config['ckpt_name']))
print('Done Training ...')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Arguments for ddpm training')
parser.add_argument('--config', dest='config_path',
default='config/default.yaml', type=str)
args = parser.parse_args()
train(args)
加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。
设置组件:
初始化噪声调度器,用于在不同的时间步添加噪声。
创建一个MNIST 数据集加载器。
实例化U-Net模型。
检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。
训练循环:每个时期:
遍历数据集,根据采样的时间步长向图像添加噪声。
使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
使用反向传播更新模型参数并保存模型检查点。
优化:使用 Adam 优化器和 MSE 损失函数来训练模型。
完成:打印 epoch 损失并在每个 epoch 结束时保存模型。
Sampling
import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def sample(model, scheduler, train_config, model_config, diffusion_config):
xt = torch.randn((train_config['num_samples'],
model_config['im_channels'],
model_config['im_size'],
model_config['im_size'])).to(device)
for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
# Get prediction of noise
noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
# Use scheduler to get x0 and xt-1
xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
# Save x0
ims = torch.clamp(xt, -1., 1.).detach().cpu()
ims = (ims + 1) / 2
grid = make_grid(ims, nrow=train_config['num_grid_rows'])
img = torchvision.transforms.ToPILImage()(grid)
if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
os.mkdir(os.path.join(train_config['task_name'], 'samples'))
img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
img.close()
def infer(args):
# Read the config file #
with open(args.config_path, 'r') as file:
try:
config = yaml.safe_load(file)
except yaml.YAMLError as exc:
print(exc)
print(config)
diffusion_config = config['diffusion_params']
model_config = config['model_params']
train_config = config['train_params']
# Load model with checkpoint
model = Unet(model_config).to(device)
model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
train_config['ckpt_name']), map_location=device))
model.eval()
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
beta_start=diffusion_config['beta_start'],
beta_end=diffusion_config['beta_end'])
with torch.no_grad():
sample(model, scheduler, train_config, model_config, diffusion_config)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
parser.add_argument('--config', dest='config_path',
default='config/default.yaml', type=str)
args = parser.parse_args()
infer(args)
加载配置:从 YAML 文件读取模型、扩散和训练参数。
模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。
采样过程:
从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
在每个时间步:
使用模型预测噪音。
使用调度程序计算去噪图像(x0)并更新当前噪声图像(xt)。
将中间去噪图像作为 PNG 文件保存在输出目录中。
推理:执行采样过程并保存结果而不改变模型。
Config (Default.yaml)
dataset_params:
im_path: 'data/train/images'
diffusion_params:
num_timesteps : 1000
beta_start : 0.0001
beta_end : 0.02
model_params:
im_channels : 1
im_size : 28
down_channels : [32, 64, 128, 256]
mid_channels : [256, 256, 128]
down_sample : [True, True, False]
time_emb_dim : 128
num_down_layers : 2
num_mid_layers : 2
num_up_layers : 2
num_heads : 4
train_params:
task_name: 'default'
batch_size: 64
num_epochs: 40
num_samples : 100
num_grid_rows : 10
lr: 0.0001
ckpt_name: 'ddpm_ckpt.pth'
该配置文件提供了扩散模型的训练和推理的设置。
数据集参数im_path:指定训练图像的路径( )。
扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_start和beta_end)。
模型参数:
定义模型架构,包括:
输入图像通道(im_channels)和大小(im_size)。
下采样、中间处理和上采样的通道数。
每一级是否发生下采样(down_sample)。
各种块的嵌入尺寸和层数。
训练参数:
指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
包括采样设置,例如用于可视化的样本数量和网格行数。
Dataset (MNIST)
import glob
import os
import torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
class MnistDataset(Dataset):
self.split = split
self.im_ext = im_ext
self.images, self.labels = self.load_images(im_path)
def load_images(self, im_path):
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
ims = []
labels = []
for d_name in tqdm(os.listdir(im_path)):
for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
ims.append(fname)
labels.append(int(d_name))
print('Found {} images for split {}'.format(len(ims), self.split))
return ims, labels
def __len__(self):
return len(self.images)
def __getitem__(self, index):
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
return im_tensor
初始化:采用分割名称、图像文件扩展名(im_ext)和图像路径(im_path)。调用load_images以加载图像路径及其相应的标签。
图像加载:load_images遍历 处的目录结构im_path,假设子目录已标记(例如,数字类别的0、1、...)。收集图像文件路径并根据文件夹名称分配标签。
数据集长度:__len__返回图像的总数。
数据检索:__getitem__通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。