最近在做蛋白从头设计相关研究,看了超多文献(后面也想慢慢把它们分享出来嘿嘿嘿)!发现现在最新的模型其实多数都是基于扩散模型(Diffusion Model),既然如此,依据咱们要知道的尿性,那必须给它掰开了揉碎了鼓捣清楚对不啦!哈哈哈哈哈哈哈哈!开工!!!
BTW:以前介绍模型或算法的时候,我都是先介绍原理,想着代码后续再出,但是这样好像不太有助于大家使用哈!那这次!包括以后!咱们就不一样啦!咱们原理代码一起来!所以,需要代码的小伙伴们!咱们直接刷刷刷往后滑,就可以看到心心念念的代码啦!
本文包含内容如下所示:
浅浅引入一下下
“扩散”是什么意思呢?
扩散模型的基本原理
扩散模型的直观理解
扩散模型 vs. GAN 和 VAE
扩散模型背后的数学原理
前向过程
反向过程
损失函数
训练与采样算法
神经网络
U-Net 架构概述
U-Net 在扩散模型中的应用
U-Net 的优势
代码实现
非条件生成
条件生成
数学推导补充(不看也罢!)
均值方差推导
损失函数推导
参考资料
浅浅引入一下下
扩散模型(Diffusion Model),这个名字大家是不是有点耳熟!现在生成式 AI 很火对不对!它就是一种生成模型,现在已经在条件和非条件的图像、音频、视频生成中均取得了显著成果(目前看来还是图像生成居多),包括 OpenAI 的 GLIDE 和 DALL-E 2(现在已经是 DALL-E 3 啦!)、海德堡大学的 Latent Diffusion 以及 Google Brain 的 ImageGen 等都是基于扩散模型。
实际上,生成式建模中的扩散思想早在2015年就由 Sohl-Dickstein 等人在《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》中提出,但是当时这一方法并未引起广泛关注。直到2019年Song 等人和2020年 Ho 等人对这一方法进行了重要改进,才使得扩散模型在生成式模型领域掀起了一股新潮流。
我们先举个小小的例子让大家感受一下扩散模型的用途!
从上面这两张图,我们可以明显地看出,左边的图是相对比较模糊的(也就是有一些我们所说的噪音),右边的图就相当于一个去噪(denoise)的过程,把这个图片还原了是不是!这其实就是扩散模型的一个用法,把图像分布还原!
那咱们今天呢,就主要基于原始论文 ——(Denoising Diffusion Probabilistic Models (arxiv.org)),对去噪扩散概率模型(Denoising Diffusion Probabilistic Model,DDPM)进行深入浅出地探讨!
原文链接:https://arxiv.org/pdf/2006.11239
小小知识要知道 —— DDPM 和 扩散模型是一个东西吗?
扩散模型(Diffusion Model):
广义概念:扩散模型是一类生成模型的统称,基于扩散过程生成数据。这类模型的核心是通过逐步加噪和去噪的过程生成数据样本。 多种实现形式:扩散模型可以有不同的实现方式和变体,包括不同的噪声添加策略、去噪过程、损失函数等。DDPM 就是扩散模型的一个具体实现。 DDPM(Denoising Diffusion Probabilistic Model):
具体实现:DDPM 是扩散模型的一种具体实现形式,它采用了特定的噪声添加和去噪策略,并且定义了明确的概率模型来指导生成过程。 关键贡献:DDPM 引入了特定的去噪算法和训练方法,使得扩散模型在生成质量和稳定性方面得到了显著提升。它在生成图像、音频、视频等任务中取得了很好的效果。 总的来说呢,扩散模型是一个广泛的类别,而 DDPM 是这个类别中的一个具体模型。所有的 DDPM 都是扩散模型,但并不是所有的扩散模型都是 DDPM。DDPM 通常被认为是扩散模型家族中具有代表性和影响力的成员。
小小知识要知道 —— 条件/非条件生成是什么意思嘞?
“条件生成”,指生成的数据(如图像、音频或视频)是基于某些特定的条件或输入的。这些条件可以是类别标签、文本描述、图像的某个部分或其他信息。
条件生成(Conditional Generation):模型在生成数据时需要根据给定的条件来决定生成的内容。例如,给模型输入一个类别标签“猫”,它就会生成一张猫的图片。另一个例子是文本到图像的生成模型,如 DALL-E,它根据输入的文本描述生成相应的图像。 非条件生成(Unconditional Generation):与之相对,非条件生成指的是模型在没有任何外部条件的情况下生成数据。模型只是简单地从其学习到的数据分布中采样,生成样本。这种生成方式不依赖于任何特定的输入条件。 因此,"条件生成" 在这里表示生成数据是基于某种特定的条件,而"非条件生成" 则表示生成数据时不依赖于特定的条件。
“扩散”是什么意思呢?
我们以清晰的图片变成模糊的画这个小例子,帮助大家理解扩散模型中的“扩散”到底是什么样的一个过程!
扩散过程(添加噪声)
想象:你有一张清晰的照片。现在,想象你在这张照片上涂抹一些透明的油漆,使得照片变得越来越模糊。每次你涂抹一点油漆,照片就变得更模糊,直到最后你完全看不清楚原来的照片是什么样的了。 解释:在扩散模型中,这个涂抹油漆的过程就像是“扩散”,逐步增加噪声,让清晰的图像变得模糊不清。
想象:现在,你有一张完全模糊的图片,像是一片油漆涂满的纸。你的任务是用一种特殊的清洁剂,逐步去除油漆,直到你能看到最初的清晰照片。 解释:去噪过程就是“逆扩散”,你在逐步去除油漆,恢复原始的清晰照片。模型学习如何从模糊的状态逐渐恢复出清晰的图像。
扩散模型的基本原理
这是原文的原理图,看起来似乎略微复杂,不过咱们不慌!这里咱们先稍微笼统地把它介绍一下!这样咱们心里可以有个大致的框架!然后再一点一点深入了解并应用!
扩散模型其实就是通过逐步引导噪声数据变得更加结构化,最终生成高质量的数据样本。这类模型的基础是将数据样本逐步加上噪声,直到它们完全随机化,然后通过逆过程(去噪过程)逐步还原为原始数据。这种方法的核心思想是让模型学习如何逐步从噪声生成数据。
扩散模型主要有两个阶段:
前向扩散阶段(forward process): 这一过程通过逐步添加噪声,将原始数据转化为纯噪声。
反向扩散阶段(reverse process): 在逆向过程中,模型学习如何去掉噪声,将噪声还原为原始数据。这个过程通常通过一种称为“反向扩散”的算法实现。
扩散模型在图像生成、图像修复、图像超分辨率等领域表现出色,生成的图像质量通常优于传统的生成对抗网络(Generative adversarial network,GAN,偷咪咪说一句,它其实已经有一丢丢过时啦)。由于其生成过程是逐步完成的,扩散模型在生成样本时的控制和稳定性方面也表现得更好。
扩散模型的直观理解
为了更好地理解扩散模型,我查阅了超多资料!很多文章中常常直接就是大量的数学公式,推导到最后还是不清楚它的实际作用。或者直接开始各种比较,结果还是让人一脸懵逼。其实呀,扩散模型没有那么复杂!咱们简单几句话就可以相对清晰地帮助大家理解它的!
扩散模型的目的是什么嘞?
学习如何从纯噪声生成图片。
扩散模型是怎么做的呢?
训练一个网络(这里用的是 U-Net),输入一系列添加了噪声的图片,学习去预测这些图片中的噪声。
前向过程在做什么捏?
逐步向真实图片添加噪声,最终得到一张纯噪声的图片。对于训练集中的每一张图片,都可以生成一系列噪声程度不同的加噪图片。在训练时,这些(不同程度的噪声图片 + 生成它们所用的噪声)就是实际的训练样本。
反向过程又在做什么捏?
在模型训练完成后,通过采样和生成图片来实现目标。
扩散模型 vs. GAN 和 VAE
很多文章都对它们进行了对比,直接看的话比较抽象,但是确实值得比较!因为扩散模型在图像合成方面已经击败了 GAN,且与 VAE 的理论关系确实非常紧密,所以咱们也来比一比!
生成对抗网络(GAN,generative adversarial network):
工作原理:GAN 通过两个网络 —— 生成器(generator)和判别器(discriminator)—— 相互对抗,生成器试图生成逼真的数据以骗过判别器,而判别器则试图区分真实数据和生成数据,看看到底哪个是真实存在的,哪个是被“凭空捏造”出来的!其实 GAN 的训练就是两个模型互相学习的过程(非常和谐!),不过“对抗”这个词叭,就显得火药味十足!哈哈哈哈哈哈哈哈哈!
优势:GAN 以快速生成逼真图像而闻名,尤其在计算速度和生成样本的多样性上表现出色。
挑战:GAN 的训练过程不稳定,容易出现模式崩溃(mode collapse)现象,即生成的样本缺乏多样性。
变分自编码器(VAE,variational autoencoder):
工作原理:VAE 通过编码器-解码器(encoder-decoder)架构学习数据的概率分布,从中采样生成新样本。它可以将采样后的概率分布映射到训练集的概率分布,生成隐变量(z),而且隐变量是既含有数据信息又含有噪声的,所以 VAE 除了可以还原输入的样本数据外,还可以用于生成新的数据。它采用了一种随机的方法,在生成数据时会引入一定的随机性。
优势:VAE 生成的样本通常多样性较好,且模型易于训练和解释。
挑战:与 GAN 相比,VAE 生成的图像通常较模糊,质量不如 GAN 或扩散模型高。
扩散模型(diffusion model):
工作原理:扩散模型通过逐步添加和去除噪声来生成数据。在训练过程中,模型学习如何将噪声逐步还原为有意义的信息。其实扩散模型的灵感来源于非平衡热力学,模型首先定义了扩散步骤的马尔科夫链,接着将随机噪声缓慢地添加到数据中,并通过学习逆向扩散过程,从噪声中构建出所需的数据样本。
优势:扩散模型生成的样本质量极高,尤其在图像细节和稳定性上表现优异。它能够避免 GAN 的训练不稳定问题,同时生成的图像往往比 VAE 更清晰。
挑战:扩散模型的生成过程较为复杂且时间较长,因为它是逐步生成的。
总来来看,扩散模型现在正处于一个百花齐放的状态,就像是 GAN 刚提出来的时候,真真的一山更比一山高呐!
扩散模型背后的数学原理
个人认为,如果你的研究目的不是改进扩散模型本身,其实没必要花费过多时间去钻研它的数学原理。对于多数小伙伴而言,还是建议先快速理解扩散模型的整体思想,掌握其核心的训练算法和采样算法,并确保能够运行通代码,到这种程度,就可以开始利用扩散模型做自己的任务啦!
先放个图图,看不懂不要谎!咱们接下来就介绍具体流程和公式推导过程,其实咱们的最终目标就是推导出训练过程中的损失函数。
前向过程
扩散模型的前向过程,其实就是不断地向原始输入数据中添加噪声(注意这个噪声一定要满足高斯分布,也就是我们常说的正态分布),直至原本清晰的图像变成一个类似纯噪声的数据(也就是趋于高斯分布)。我们可以看下面这张图:
每一时刻都要添加噪声,后一时刻的图像都由前一时刻的图像添加噪声得到; 最后的图像会变成纯噪声; 每一时刻的添加的噪声强度均不同,目前有线性调度器、余弦调度器等; 这一过程构建了我们训练所用到的标签(噪音),后续会用到哟
这里我们有个小问题想问一下大家,咱们不是每一时刻都要添加噪音嘛,那每个时刻添加的噪音是一样的吗?
咱们看上面的图,显然是不一样的哈,随着时刻的增加,噪声占比会越来越大,所以添加噪声的强度也会越来越大。
那为什么捏!咱们来解释一下!
我们希望每一步添加噪音后,可以使图像的分布变得更加随机,也就是更加扩散,换句话说,我们希望每一步都能保持相同的扩散幅度。在一开始,只需稍微增加一点点噪音,图像就会稍微扩散一些,但随着时间推移,图像变得越来越不规律,如果我们依然像最初那样仅仅加一点点噪音,变化幅度会逐渐减小,甚至几乎没有变化。所以呀,添加噪音的过程是需要随着时间的推移逐步增加的!
接下来我们开始推导如何从初始图像直接得到第 时刻的图像!
前向过程的定义
我们把上面的介绍从数学角度描述一下:扩散模型的前向过程其实就是将一个原始样本 逐步转化为噪声样本 的过程,通过逐步添加噪声得到一系列样本 ,这个过程通常用马尔科夫链来建模。
原始样本: 是我们需要生成的原始数据样本(比如图像)。 噪声:每一步我们添加的噪声 服从标准正态分布 。 噪声系数: 是一个定义好的噪声增加系数, 随时间步 递增,通常 的值在 (0, 1) 之间(我们可以把它理解为每个时刻噪声的权重,它的初始化值往往都比较小,越往后越大)。
每一步的前向过程公式 —— 如何得到
我们可以这么理解,后一时刻的图像和哪个时刻最相关呢?那肯定是前一时刻对不对!所以 是不是和 是最相关的!来!咱们想想它俩之间有什么关系呢?咱们是不是在 中添加了一些噪音得到了 ,所以说它俩其实就是差了一个噪音对不对!来!那现在!我们就可以写第一个重要的式子啦!
在第 步,我们将噪声 以系数 加到前一步的样本 中,生成当前的样本 :
这个公式咱们可以理解为:(请往下滑一丢丢看公式的解析,在那为什么要要加根号 呢?这个部分的下面,不知道为啥它格式错乱还怎么也调整不回去呜呜呜呜呜呜!)
咱们通俗易懂解释一下!
和 我们可以理解为 和 的权重,就相当于咱们要衡量一下,当前时刻的 是噪音对它影响更大呢,还是前一时刻的状态对它影响更大嘞!后续我们把 称为 ,也就是 ,感觉其实主要是为了方便运算。那我们现在看一下,之前咱们不是说 随时间 递增嘛,所以 就是随时间递减。也就是说, 的权重就会越来越小, 的权重就会越来越大!这是不是就符合咱们之前所说的随着时刻越来越往后,添加的噪音就需要越来越大!
那为什么要要加根号 呢?
其实原始论文是从概率论角度去介绍的,但咱们把它当成权重会更便于理解!有兴趣的小伙伴可以去看看原始论文!有点抽象!太难了!!!
确保了前一步的样本在生成下一步样本时保持一定权重,而不是完全被噪声覆盖, 是添加的噪声部分,随着时间步 的增加,噪声的影响逐步增大。 代表 是服从高斯分布的。
从原始样本 直接推导到任意时间步
现在我们知道了如何得到 ,但我们实际拿到手的数据是什么呢?是原始样本 !那我们该怎么办呢?我们可以通过马尔科夫链的性质,直接从原始样本 推导出任意时间步 的样本 ,而不必逐步递推每一步。这里直接推导出(不理解没关系,这里咱们先写出来,下面咱们会一步一步推导):
这里 是前 步的系数 的累乘:
表示第 t 步样本中保留的原始样本的部分,随着 增加,这部分逐渐减小。 表示第 t 步样本中添加的噪声部分,随着 增加,这部分逐渐增加,直到完全被噪声支配。
马尔科夫链的逐步生成过程
为了更好地理解累积公式 ,咱们从逐步生成的过程开始推导!
上面我们知道了如何从 得到 ,那我们肯定就可以通过 得到 ,……,那我们也一定能通过 得到 ,然后再以此类推得到 !所以其实整个过程都是围绕咱们上面的那个重要公式()!它可以写成(因为**把 称为 ,也就是 **):
那,咱们开始!
首先,我写个第一步第二步的式子帮助大家理解!
第一步 :
在第一步,直接将 加权后叠加一些噪声。
第二步 :
接下来,重头戏来啦!
在第 步时:
将 再次代入,得到:
继续展开:
这里我们可以看到噪音项 和 前面都乘了一个东西,我们知道如果两个独立的高斯分布相加之后,结果仍然会是一个高斯分布,并且新分布的均值和方差可以根据原分布的均值和方差来计算。那现在,假设我们有两个独立的高斯分布 和 ,它们的加法结果 仍然服从一个新的高斯分布,其均值是两个分布的均值之和,方差是两个分布的方差之和。因为在这种情况下均值为零,合并后的噪声分布为:
所以我们前面的 推导公式中,噪声项 和 都服从正态分布,且两者相互独立,因此它们可以合并为一个新的高斯分布:
如果我们不断递归展开,最终可以写出第 步的样本 的通项公式:
其中, 表示所有时间步的系数累乘,它是随时间 变化的衰减因子。 是总的噪声项,从标准正态分布中采样的噪声,表示随着时间的累积,所有噪声项逐渐叠加在一起的结果。在最早的时间步(即 时), 的系数最小,随着时间步的增加,这个系数逐渐增加。这个式子表明,在前向过程中,数据 被逐渐转化为噪声。
为了更好地理解加噪声过程的影响,我们可以讨论如下的分布:
这里, 是一个小于 1 的常数,表示在每一步加噪声时添加的噪声强度。在扩散模型的论文中, 通常从 到 线性增长。随着 的增大, 逐渐减小,导致 更快地趋于 0,最终 将接近于一个标准正态分布,这也正是扩散模型所期望的。
通过上述推导,我们可以看到前向过程是如何通过逐步添加噪声,将原始数据样本 转变为逐步模糊的样本 。随着步数 的增加,噪声逐渐主导样本,直到最终生成一个纯噪声的样本 。这个过程为模型学习逆向去噪过程提供了训练样本!
接下来,咱们看反向过程!
反向过程
上面咱们介绍的前向过程是加噪的过程,那咱们的实际目的是什么嘞!肯定不是得到一个噪音哈哈哈哈哈哈!其实咱们更关心的是它的逆向过程!我们要把上面得到的样本 还原成 ,也就是一个去噪的过程!那我们该怎么做嘞!请继续看下去!加油!
反向过程的目标是从纯噪声逐步生成真实数据样本,具体来说,是通过学到的模型从加噪的样本逐步去噪,最终还原出原始数据样本。也就是从最终生成的噪声样本 开始,逐步去噪,最终生成原始样本 。反向过程可以被看作是前向过程中每一步的逆操作,即我们需要学会从加噪样本 中减去噪声,得到上一时间步的样本 **。
前向过程中,我们可以从 直接得到 ,那我们现在可以直接从 得到原始样本 吗?显然是不可能滴!咱要是能做到,那就真成神啦哈哈哈哈哈!虽然不能直接得到,但咱们可以退而求其次嘛,咱们慢慢推,总能求出来的(虽然只是估计哈哈哈哈哈哈)!从 **得到 **,一步一步得到 。
那么我们就需要找到每一步去噪声的逆操作,并通过神经网络来学习这个逆操作。
理论上,当 足够小时,每一步加噪声的逆操作也将满足正态分布:
其中, 和 分别是该正态分布的均值和方差。
为了描述所有去噪声操作,我们需要根据当前时刻 和当前图像 ,拟合当前时刻的加噪声逆操作的正态分布,即拟合当前的均值 和方差 。
注意注意:以下步骤会省略一丢丢推导细节,主要为了方便大家理解,有兴趣深入了解的小伙伴可以往下滑滑滑,可以看到有个小节叫做数学推导补充(不看也罢!),去看它!哈哈哈哈哈哈哈哈!
直接计算加噪声逆操作分布不太现实,那我们可以曲线救国一下!通过贝叶斯公式(Cool!)计算给定 时的去噪分布:
其中,左边的分布 是加噪声操作的逆操作,满足正态分布 ,均值和方差待求。右边的分布 和 是加噪声的分布,可以通过前面提到的公式()得到:
通过代入这些公式,我们可以计算去噪分布的均值和方差:
其中 是从标准正态分布中采样的噪声,来自于公式 。
去噪分布的方差为:
损失函数
我们希望让模型的输出尽可能接近理论计算得到的去噪分布。为了达到这个目标,最直观的方法是让加噪声逆操作和去噪声操作这两个正态分布尽可能接近。
那怎么让这两个正态分布尽可能接近呢?那必然是让它俩的均值和方差都尽可能接近嘛!
那这里我们要注意, 是加噪声的方差是一个常量。那么,加噪声逆操作的方差 也是一个常量,不与输入 相关。那现在方差是固定的,我们就只需要让均值尽可能接近即可。
怎么办嘞!咱们看上面的均值公式,模型在拟合均值的时候, 是已知的(图像是一步一步倒着去噪的嘛),那公式里唯一不确定的其实只有 ,那我们是不是可以想到,咱们直接预测噪声 (其中 是模型的可学习参数),使得它与生成 的噪声 之间的均方误差最小化就好啦!
这样的话,最终的损失函数就可以写为:
在反向过程中,模型通过学习 个去噪操作来拟合对应的 个加噪声逆操作。每步加噪声逆操作符合正态分布,并且在给定输入 时,该正态分布的均值和方差可以通过解析表达式计算出来。最终,模型的学习目标就是让其输出的分布与理论计算的分布一致。经过一系列化简,我们就可以把问题转换为拟合生成 时用到的随机噪声 。
训练与采样算法
现在我们理解了前向过程和反向过程,那我们就可以知道模型训练和采样的算法啦!
这张图图里是扩散模型的训练与采样算法流程,我们来略微细致地介绍一下!
先看训练阶段:
我们可以看出,训练阶段是使用噪声预测模型来预测图像中的噪声,然后最小化目标噪声与预测噪声之间的损失。具体过程如下:
从训练集中抽取样本 ;然后随机选择一个时间步 从 中;再然后随机生成一个噪声 ;再再然后计算加噪声图像 ,根据公式 ;再再再然后将 和时间步 作为输入传递给神经网络 ,预测噪声 ;最后计算损失函数,即预测噪声与实际噪声之间的均方误差,使用梯度下降算法优化,以最小化损失函数。
再看采样阶段:
在采样阶段,噪声预测模型以噪声图像和当前时间步作为输入,生成一个预测的噪声图像。通过从噪声图像中减去预测噪声(注意不要遗漏系数项),可以得到去噪后的图像。然后然后,注意注意,接下来还需在该图像上加上一个从标准正态分布中采样的噪声,这样才能够得到最终的去噪图像。
那这里为什么要再加上一个采样出来的噪声呢?
在逆向过程中,加入噪声的目的是为了引入随机性,使得每次采样都略有不同,从而生成多样化的样本。这种随机性使得模型能够产生具有不同细节的图像,增加了生成样本的多样性和逼真性。
具体过程如下:
首先,从标准正态分布中采样生成一个噪声图像 ,想要生成不同的图像,只需要更换这个噪声即可;
接下来就是模型的反向过程:
从时间步 开始,逐步计算每个时刻 的去噪声操作,直到时间步 。每一步的操作包括以下几个部分:
计算均值
()
计算方差
方差的计算公式有两种选择,根据 的分布来决定:
当 是特定数据(例如特定图像)时:
当 (即 是标准正态分布)时:,这种情况不需要考虑前一步的方差,因为噪声本身就是标准正态分布。
采样
生成下一个时刻 的图像,使用均值 和方差 :
这里有个特殊情况,就是最后一步去噪。在最后一步()的去噪过程中,会有 ,而 理论上 是从 1 开始的,在 时其实是没有定义的,但我们可以令 ,这样分子 就会变为 0,那我们就不需要再计算方差项啦!其实也就是最后一步的时候不需要再给它加噪音啦!(其实很多咱们可以在代码里看,感觉会更好理解一点!)
到这里,咱们对扩散模型的数学原理有了一定的了解(可算是费劲巴哈搞懂了呜呜呜呜呜呜),那接下来,咱们终于可以开始进行代码实现啦!!!感激涕零!!!!!!
注意注意:以上步骤会省略一丢丢推导细节,主要为了方便大家理解,有兴趣深入了解的小伙伴可以往下滑滑滑,可以看到有个小节叫做数学推导补充(不看也罢!),去看它!哈哈哈哈哈哈哈哈!
神经网络
在扩散模型中,每一步的输入是带有一定噪声的图像(从完全随机噪声开始逐步还原),网络的任务是从这个带噪声的图像中预测出当前步骤的噪声。U-Net 被用来作为网络结构,以利用它的多尺度特征提取和残差连接来有效地处理这些带噪声的图像。
U-Net 架构概述
U-Net 由两部分组成:
编码器(Encoder):这部分逐步下采样输入图像,通过一系列卷积层和池化层来提取特征,形成高层次的抽象表示。 解码器(Decoder):这部分逐步上采样特征图,通过一系列反卷积层和上采样操作恢复图像的分辨率,同时融合编码器对应层的特征图以保留图像细节。
U-Net 的一个关键特点是编码器和解码器之间的跳跃连接(skip connections)。这些连接直接将编码器中的中间特征图传递到解码器的相应层,以帮助解码器更好地恢复原始图像中的细节。
U-Net 在扩散模型中的应用
在扩散模型中,U-Net 的作用主要体现在去噪过程,这个过程是从带噪声的图像逐步生成清晰图像的核心。咱们简单介绍一下 U-Net 在扩散模型中的具体步骤:
输入准备
带噪声的图像:在每个时间步 ,扩散模型输入的是一个已经添加了噪声的图像 。 时间步信息:时间步 通常会作为附加信息输入到网络中,以帮助网络识别噪声的程度和去噪的目标。
下采样:U-Net 从输入的带噪声图像 开始,通过一系列卷积层和池化操作逐步下采样(减少图像的空间分辨率)。这一步骤提取了图像的高层次特征,并压缩了图像的信息。
特征抽取:在编码阶段和解码阶段之间,U-Net 的瓶颈层进一步处理并压缩特征。这是网络的核心部分,负责提取图像的高级语义信息。
上采样:解码阶段通过上采样操作逐步恢复图像的空间分辨率。这一过程中使用了转置卷积(或上采样层)来恢复图像的尺寸,并将高层特征映射回更高分辨率的图像。
特征融合:U-Net 通过跳跃连接将编码阶段的特征图直接连接到解码阶段的相应层。这些连接允许网络在解码阶段利用编码阶段提取的细节信息,从而保持图像的细节和边界信息。 特征重建:跳跃连接帮助网络在恢复图像的过程中保留了重要的细节信息,使得去噪过程更为精确。
预测噪声:最终,U-Net 输出一个与输入图像同尺寸的噪声预测图。这个预测图表示了当前图像中噪声的估计值。通过比较这个预测值和实际的噪声,模型可以逐步减去噪声,逼近原始图像。
损失函数:在训练过程中,模型通过比较预测的噪声与真实噪声的差异来优化。
逐步去噪:在生成过程中,U-Net 被用来逐步从完全噪声的图像生成清晰图像。每一步都应用去噪过程,直到最终生成一个高质量的图像。
U-Net 在扩散模型中的作用是通过其对称的编码-解码架构和跳跃连接有效地去噪图像。它不仅能够提取和恢复图像的高层次特征,还能保持细节信息,确保生成的图像在细节和质量上都达到很高的水平。
U-Net 的优势
有效的细节保留:跳跃连接确保了在逐步恢复图像时,保留了输入图像中的丰富细节。 适应不同尺度的特征:U-Net 的多层结构允许网络同时处理不同尺度的特征,使其能够从全局和局部两个层次上理解图像内容。
在扩散模型的背景下,U-Net 的这种灵活性和高效性使其成为了去噪任务的理想选择,大大提升了模型生成高质量图像的能力(个人感觉可以使用更高级的架构,比如 Transformer 之类的,万一模型性能有更大提升呢🌚)。
代码实现
非条件生成
非条件生成(Unconditional Generation)实际上就是随机生成,不需要对数据进行任何标签指定。在这种生成模式下,模型在不依赖任何输入条件的情况下进行生成,完全根据内部的噪声和自身学习到的分布来随机生成数据。整个生成过程分为前向和后向两部分:
前向过程:前向过程中,模型从纯噪声开始,通过多次迭代逐步对噪声进行处理,最终生成目标数据。由于不需要条件输入,采样时也不会涉及标签。
训练过程:在训练阶段,通常使用 U-Net 网络结构作为核心模型,处理逐步添加噪声的数据, 到 的标签从前向扩散中得到。在模型的输入中嵌入时间步的编码(),这种编码类似于 Transformer 模型中的位置编码(Positional Encoding),能够帮助模型更好地理解和处理不同时间步的信息,使得模型训练更加高效和稳定。
这种方法的核心在于让模型能够仅凭内部的时间步编码和噪声的去噪过程来还原数据,无需外部条件的参与。这种架构使得非条件生成在随机性和多样性上具有极高的表现。
条件生成
条件生成(Conditional Generation)与非条件生成类似,同样使用前向过程对数据进行采样。然而,在条件生成中,除了前向过程的噪声处理外,我们还引入了额外的类别标签编码。这些标签编码作为模型的输入,结合时间步的编码一同用于训练。
在条件生成中,除了噪声与时间步信息外,我们额外嵌入类别标签的编码,使模型能够生成与特定条件相关的数据。这一机制通过 Classifier-Free Guidance(CFG)控制生成的图像中条件和非条件信息的比例。
**CFG (Classifier-Free Guidance)**:用于引导标签控制,通过调整 CFG 可以在生成结果中增加或减少条件生成的影响力。
CFG 参数决定了条件生成与非条件生成之间的比值:CFG 越大,生成图像中条件生成的比例越高。生成的图像可以表示为:生成图像 = (1 - alpha) * 条件生成 + alpha * 非条件生成
,其中 alpha 与 CFG 相关,调节条件和非条件生成的平衡。
此外,为了实现更稳定的模型训练,条件生成过程中采用了指数移动平均(Exponential Moving Average, EMA)策略。EMA 通过对上一代模型参数与当前模型参数进行指数加权平均,减少了离群点对模型参数更新的影响,进而实现更平滑和稳定的梯度更新。这种策略有助于提升生成的稳定性和模型的整体性能。
接下来,咱们正式进入代码实现!
首先,导入需要的库或模块。
import math # 导入 Python 标准库中的 math 模块,提供基本的数学函数和常量
from inspect import isfunction # 从 inspect 模块中导入 isfunction 函数,用于检查对象是否为函数
from functools import partial # 从 functools 模块中导入 partial 函数,用于创建部分参数已固定的新函数
import matplotlib.pyplot as plt # 导入 Matplotlib 库中的 pyplot 模块,用于数据可视化,并命名为 plt
from tqdm.auto import tqdm # 从 tqdm 模块中导入 tqdm 函数,自动选择适当的进度条显示方式
from einops import rearrange # 从 einops 库中导入 rearrange 函数,用于简化张量维度变换操作
import torch # 导入 PyTorch 库,用于深度学习和张量计算
from torch import nn, einsum # 从 PyTorch 库中导入 nn 模块和 einsum 函数,nn 模块包含神经网络构建的基本组件,einsum 用于简化复杂的张量操作
import torch.nn.functional as F # 导入 PyTorch 的 nn.functional 模块,并命名为 F,包含常用的神经网络功能
import numpy as np # 导入 numpy 模块用于数据处理
from PIL import Image # 导入 PIL(Python Imaging Library)的 Image 模块,用于图像处理和操作
import requests # 导入 requests 库用于发送 HTTP 请求,如获取网页数据
from datasets import load_dataset # 从 Hugging Face Datasets 中导入加载数据集的方法
# from torchvision import datasets # 这是另一种下载数据集的方式
from torchvision import transforms # 导入 torchvision 中的数据变换工具
from torch.utils.data import DataLoader # 导入 PyTorch 数据加载器
from pathlib import Path # 用于处理文件路径
from torch.optim import Adam # 用于在训练过程中更新神经网络的权重参数
from torchvision.utils import save_image # 用于将 Tensor 格式的图像保存为图像文件(例如 PNG)
import matplotlib.animation as animation # 用于创建和管理动画,比如生成动态图或将一系列图像保存为GIF文件
from IPython.display import Image # 在Jupyter Notebook中嵌入并显示不同类型的多媒体内容
# Jupyter Notebook 魔法命令,确保 Matplotlib 绘制的图像内联显示在 Notebook 中
%matplotlib inline
我们先定义一些辅助性的函数,帮助我们进行后续的网络构建。
# exists(x) 函数用于检查模型参数或超参数是否被正确初始化。
def exists(x):
return x is not None
# default(val, d) 函数用于提供默认值,尤其在扩散模型的配置阶段。如果某个配置参数(如学习率、噪声调度参数等)未被指定,则使用默认值。对于动态生成的默认值(例如函数的返回值),可以根据需要调用函数 d()。
def default(val, d):
if exists(val): # 如果 val 不是 None
return val
return d() if isfunction(d) else d # 如果 d 是函数,调用 d();否则返回 d
# Residual 类
# 在扩散模型中,残差连接常用于构建深度神经网络,以帮助模型保留原始输入信息并稳定训练过程。残差连接可以减缓梯度消失问题,并使模型更容易学习复杂的特征映射。
# Residual 类中的 forward 方法将输入 x 与通过函数 fn 处理后的结果相加。这在扩散模型的各个阶段(如噪声添加、去噪等)都可以帮助网络更好地学习数据的结构和模式。
class Residual(nn.Module):
def __init__(self, fn):
super().__init__() # 调用父类 nn.Module 的初始化方法
self.fn = fn # 将传入的函数 fn 存储在实例变量 self.fn 中
def forward(self, x, *args, **kwargs):
# 计算 fn(x, *args, **kwargs) 的结果,并加上输入 x
# 在扩散模型中,这种操作可以帮助模型保留输入信息
return self.fn(x, *args, **kwargs) + x
# 上采样函数
# 在扩散模型的生成阶段,用于将特征图的尺寸扩大
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1) # 定义一个转置卷积层
# 参数解释:
# - dim: 输入和输出的通道数
# - 4: 卷积核的大小
# - 2: 卷积步幅,放大特征图尺寸
# - 1: 填充,保持输出特征图的空间尺寸
# 下采样函数
# 在扩散模型的编码阶段,用于将特征图的尺寸缩小
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1) # 定义一个卷积层
# 参数解释:
# - dim: 输入和输出的通道数
# - 4: 卷积核的大小
# - 2: 卷积步幅,缩小特征图尺寸
# - 1: 填充,保持输出特征图的空间尺寸
计算位置编码。
# SinusoidalPositionEmbeddings 类用于计算位置编码
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
"""
初始化位置编码层。
参数:
dim (int): 位置编码的维度。该维度应为偶数,因为我们将生成正弦和余弦两部分编码,每部分的维度都是 dim / 2。
"""
super().__init__()
self.dim = dim # 记录位置编码的维度
def forward(self, time):
"""
计算位置编码。
参数:
time (Tensor): 输入的时间步张量,其形状为 (N,) 或 (N,),其中 N 是时间步的数量。
返回:
Tensor: 计算得到的位置编码,其形状为 (N, dim),N 是时间步的数量,dim 是位置编码的维度。
"""
device = time.device # 获取输入张量所在的设备(如CPU或GPU)
half_dim = self.dim // 2 # 计算位置编码维度的一半,我们将为每个位置计算正弦和余弦编码,因此总维度是 dim。
# 计算频率缩放因子,这里使用对数缩放因子来调整不同频率的正弦和余弦函数的周期。`math.log(10000) / (half_dim - 1)`计算对数缩放因子,确保频率从较低到较高范围。
embeddings = math.log(10000) / (half_dim - 1)
# 计算每个频率的缩放因子,并应用到 time 上。embeddings 变成了一个形状为 (half_dim,) 的张量,其中每个值对应不同频率的缩放因子。
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
# 将时间步与频率缩放因子相乘,计算位置编码的核心值。time[:, None] 的形状是 (N, 1),embeddings[None, :] 的形状是 (1, half_dim),它们的广播机制使得得到形状为 (N, half_dim) 的张量。
embeddings = time[:, None] * embeddings[None, :]
# 对每个位置,计算正弦和余弦值,并将它们沿着最后一个维度拼接,得到最终的位置编码。结果的形状为 (N, dim),其中 dim 是正弦和余弦编码的总维度。
# embeddings.sin() 计算每个频率的正弦值
# embeddings.cos() 计算每个频率的余弦值
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
U-Net 的 Block 实现,可以用 ResNet 或 ConvNeXT。
# Block 类: 实现了一个标准的卷积模块,包含卷积层、归一化层和激活函数。可选地,还可以对输出进行缩放和偏移。
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
"""
初始化Block模块。
参数:
dim (int): 输入通道的数量。
dim_out (int): 输出通道的数量。
groups (int): GroupNorm中的分组数目。
"""
super().__init__()
# 卷积层,将输入的通道数 dim 转换为 dim_out
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
# GroupNorm归一化层,分组数为groups
self.norm = nn.GroupNorm(groups, dim_out)
# SiLU激活函数
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
"""
前向传播函数。
参数:
x (Tensor): 输入张量,其形状应为 (N, dim, H, W)。
scale_shift (tuple or None): 如果提供,则为 (scale, shift),对输出进行缩放和偏移。
返回:
Tensor: 输出张量,其形状与输入张量相同。
"""
x = self.proj(x) # 卷积操作
x = self.norm(x) # 归一化操作
if exists(scale_shift):
scale, shift = scale_shift
# 对输出进行缩放和偏移
x = x * (scale + 1) + shift
x = self.act(x) # 激活函数
return x
# ResnetBlock 类: 实现了一个带有残差连接的ResNet风格块。它包括两个 Block 层,并且可以选择是否使用时间嵌入(通过MLP层)。残差连接帮助训练更深的网络,并防止梯度消失。
class ResnetBlock(nn.Module):
"""
实现了一个ResNet风格的块。
参考文献:
https://arxiv.org/abs/1512.03385 —— 《Deep Residual Learning for Image Recognition》
"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
"""
初始化ResnetBlock模块。
参数:
dim (int): 输入通道的数量。
dim_out (int): 输出通道的数量。
time_emb_dim (int or None): 如果提供,则为时间嵌入的维度。
groups (int): GroupNorm中的分组数目。
"""
super().__init__()
# 如果提供了时间嵌入维度,则创建一个MLP层
self.mlp = (
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
if exists(time_emb_dim)
else None
)
# 两个Block层
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
# 残差卷积层,用于调整维度。如果输入通道和输出通道不同,则应用卷积;否则,使用身份映射
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
"""
前向传播函数。
参数:
x (Tensor): 输入张量,其形状为 (N, dim, H, W)。
time_emb (Tensor or None): 时间嵌入,如果提供,则形状应为 (N, time_emb_dim)。
返回:
Tensor: 输出张量,其形状与输入张量相同。
"""
h = self.block1(x) # 第一个Block层
if exists(self.mlp) and exists(time_emb):
# 如果时间嵌入存在,则将其传入MLP,得到条件嵌入
time_emb = self.mlp(time_emb) # 得到形状为 (N, dim_out)
# 对时间嵌入进行形状调整,以便进行广播
time_emb = rearrange(time_emb, "b c -> b c 1 1") # 形状调整为 (N, dim_out, 1, 1)
# 将条件嵌入加到第一个Block的输出上
h = time_emb + h
h = self.block2(h) # 第二个Block层
# 添加残差连接:将输入 x 的卷积结果加到输出上
return h + self.res_conv(x)
# ConvNextBlock 类: 实现了ConvNext风格的块,使用Depthwise卷积、激活函数、归一化层和残差连接。它也可以使用时间嵌入来增强特征。这个块设计用于更现代的卷积网络架构(可以参考class ResnetBlock进行理解)。
class ConvNextBlock(nn.Module):
"""
实现了ConvNext风格的块。
参考文献:
https://arxiv.org/abs/2201.03545 —— A ConvNet for the 2020s
"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
"""
初始化ConvNextBlock模块。
参数:
dim (int): 输入通道的数量。
dim_out (int): 输出通道的数量。
time_emb_dim (int or None): 如果提供,则为时间嵌入的维度。
mult (int): 输出通道的倍数,用于中间卷积层。
norm (bool): 是否应用归一化层。
"""
super().__init__()
# 如果提供了时间嵌入维度,则创建一个MLP层
self.mlp = (
nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
if exists(time_emb_dim)
else None
)
# Depthwise卷积,处理输入通道的空间特征
self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
# 主网络结构
self.net = nn.Sequential(
# GroupNorm归一化层(可选)
nn.GroupNorm(1, dim) if norm else nn.Identity(),
# 卷积层,扩大通道数
nn.Conv2d(dim, dim_out * mult, 3, padding=1),
nn.GELU(), # 激活函数
# 归一化层
nn.GroupNorm(1, dim_out * mult),
# 卷积层,减少通道数到目标输出
nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
)
# 残差卷积层,用于调整维度。如果输入通道和输出通道不同,则应用卷积;否则,使用身份映射
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
"""
前向传播函数。
参数:
x (Tensor): 输入张量,其形状为 (N, dim, H, W)。
time_emb (Tensor or None): 时间嵌入,如果提供,则形状应为 (N, time_emb_dim)。
返回:
Tensor: 输出张量,其形状与输入张量相同。
"""
h = self.ds_conv(x) # Depthwise卷积操作
if exists(self.mlp) and exists(time_emb):
# 如果时间嵌入存在,则将其传入MLP,得到条件嵌入
assert exists(time_emb), "time embedding must be passed in"
condition = self.mlp(time_emb) # 得到形状为 (N, dim)
# 对条件嵌入进行形状调整,以便进行广播
condition = rearrange(condition, "b c -> b c 1 1") # 形状调整为 (N, dim, 1, 1)
# 将条件嵌入加到Depthwise卷积的输出上
h = h + condition
h = self.net(h) # 主网络结构
# 添加残差连接:将输入 x 的卷积结果加到输出上
return h + self.res_conv(x)
两种注意力模块(Attention module),一个是常规的 multi-head self-attention,一个是 linear attention variant。
# Attention 类
# 实现了标准的自注意力机制,使用 Conv2d 计算查询、键和值,通过计算注意力权重来获得加权平均值。
# 主要步骤包括:计算注意力得分,数值稳定性处理,计算注意力权重,计算加权平均值,和调整输出形状。
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
"""
初始化标准的自注意力机制模块。
参数:
dim (int): 输入的通道数量。
heads (int): 注意力头的数量。
dim_head (int): 每个注意力头的维度。
"""
super().__init__()
# 缩放因子,用于缩放注意力分数
self.scale = dim_head**-0.5
# 注意力头的数量
self.heads = heads
# 计算隐藏维度,等于每个头的维度乘以头的数量
hidden_dim = dim_head * heads
# qkv卷积层,将输入通道数转换为隐藏维度的三倍
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# 输出卷积层,将隐藏维度映射回原始的输入通道数
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
"""
前向传播函数。
参数:
x (Tensor): 输入张量,其形状为 (b, c, h, w)。
返回:
Tensor: 输出张量,其形状为 (b, dim, h, w)。
"""
b, c, h, w = x.shape
# 通过卷积层计算q、k、v,qkv为三个张量,分别表示查询、键、值
qkv = self.to_qkv(x).chunk(3, dim=1) # qkv的形状为 (b, hidden_dim, h, w)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
) # q, k, v的形状为 (b, heads, dim_head, h*w)
q = q * self.scale # 缩放查询张量
# 计算注意力得分
sim = einsum("b h d i, b h d j -> b h i j", q, k) # sim的形状为 (b, heads, h*w, h*w)
# 数值稳定性处理,减去每行最大值,防止溢出
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
# 计算注意力权重
attn = sim.softmax(dim=-1) # attn的形状为 (b, heads, h*w, h*w)
# 计算加权平均值
out = einsum("b h i j, b h d j -> b h i d", attn, v) # out的形状为 (b, heads, h*w, dim_head)
# 调整输出形状
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) # out的形状为 (b, hidden_dim, h, w)
return self.to_out(out) # 通过输出卷积层,将隐藏维度映射回输入通道数
# LinearAttention 类(和 class Attention 几乎一致)
# 实现了一种更高效的线性注意力机制,避免了标准注意力机制中的 O(n^2) 复杂度。
# 主要步骤包括:计算注意力权重,计算上下文(键和值的乘积),以及计算加权平均值。通过线性变换处理注意力,以提高效率。
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
"""
初始化线性注意力机制模块。
参数:
dim (int): 输入的通道数量。
heads (int): 注意力头的数量。
dim_head (int): 每个注意力头的维度。
"""
super().__init__()
# 缩放因子,用于缩放注意力分数
self.scale = dim_head**-0.5
# 注意力头的数量
self.heads = heads
# 计算隐藏维度,等于每个头的维度乘以头的数量
hidden_dim = dim_head * heads
# qkv卷积层,将输入通道数转换为隐藏维度的三倍
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# 输出层,包括卷积层和GroupNorm归一化层
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1), # 将隐藏维度映射回输入通道数
nn.GroupNorm(1, dim) # 归一化层
)
def forward(self, x):
"""
前向传播函数。
参数:
x (Tensor): 输入张量,其形状为 (b, c, h, w)。
返回:
Tensor: 输出张量,其形状为 (b, dim, h, w)。
"""
b, c, h, w = x.shape
# 通过卷积层计算q、k、v,qkv为三个张量,分别表示查询、键、值
qkv = self.to_qkv(x).chunk(3, dim=1) # qkv的形状为 (b, hidden_dim, h, w)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
) # q, k, v的形状为 (b, heads, dim_head, h*w)
# 计算加权平均值(线性注意力)
q = q.softmax(dim=-2) # 归一化q
k = k.softmax(dim=-1) # 归一化k
q = q * self.scale # 缩放查询张量
# 计算上下文(键和值的乘积)
context = torch.einsum("b h d n, b h e n -> b h d e", k, v) # context的形状为 (b, heads, dim_head, dim_head)
# 计算注意力加权输出
out = torch.einsum("b h d e, b h d n -> b h e n", context, q) # out的形状为 (b, heads, dim_head, h*w)
# 调整输出形状
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) # out的形状为 (b, hidden_dim, h, w)
return self.to_out(out) # 通过输出层,得到最终的输出
在 DDPM 中,作者在 U-Net 的卷积和注意力层中使用了 Group Normalization(GN)作为正则化技术。归一化有助于减轻内部协变量偏移,使得每一层的输入保持在合理的范围内,从而加速训练并提高模型的性能。为了在注意力层之前应用归一化,我们定义了一个PreNorm
类。归一化的应用时机在 Transformer 中依然存在一定的争议——是应该在注意力机制之前还是之后进行。
# PreNorm类用于在应用某个函数之前对输入进行归一化
class PreNorm(nn.Module):
def __init__(self, dim, fn):
"""
初始化PreNorm类
参数:
- dim (int): 输入特征图的通道数
- fn (nn.Module): 应用在输入上的函数(如卷积层或注意力层)
"""
super().__init__()
self.fn = fn # 存储要应用的函数(例如卷积层或注意力层)
self.norm = nn.GroupNorm(1, dim) # 定义GroupNorm层,使用1个组进行归一化,通道数为dim
def forward(self, x):
"""
前向传播方法,将输入x先经过归一化层,然后应用存储的函数fn
参数:
- x (Tensor): 输入特征图,形状为 (batch_size, dim, height, width)
返回:
- Tensor: 经归一化和函数fn处理后的特征图
"""
x = self.norm(x) # 对输入进行Group Normalization,标准化特征图,这样可以使得输入数据的均值和方差变得一致,有助于训练的稳定性。
return self.fn(x) # 归一化后的特征图再传递给self.fn进行进一步处理(例如卷积、注意力等),然后返回处理后的结果。
接下来,我们构建 U-Net 网络结构。
这个结构其实就是将噪声图像和步骤 t
作为输入,通过一系列的下采样、瓶颈、上采样阶段,逐步提取并恢复图像特征,最后预测每个图像上添加的噪声。网络结合了卷积、注意力机制和归一化等技术,以实现高效的特征处理和噪声预测。我们来对它进行一个略微详细的介绍!
输入:
一批噪声图像,形状为 (batch_size, num_channels, h, w)
。对应每张图像的步骤 t
,形状为(batch_size, 1)
。
输出:
预测的噪声图像,形状为 (batch_size, num_channels, h, w)
。
网络结构:
初始卷积层:
输入图像通过一个卷积层进行初步处理。 同时计算步骤 t
对应的嵌入(embedding)。
2个 ResNet/ConvNeXT 块: 用于特征提取和处理。 Group Normalization: 归一化处理,以稳定训练过程并加快收敛速度。 Attention: 通过注意力机制增强重要特征的表示。 Residual Connection: 残差连接,帮助避免梯度消失并加快训练。 Downsample Operation: 下采样操作,逐渐减少特征图的尺寸以捕捉更高层次的特征。 该阶段包括多个下采样步骤,每个步骤包含:
应用一个带注意力的 ResNet 或 ConvNeXT 块,用于处理经过下采样后的特征,并进一步提取重要信息。
2个 ResNet/ConvNeXT 块: 用于进一步特征提取和处理。 Group Normalization: 继续进行归一化处理,以保持训练稳定性。 Attention: 通过注意力机制增强特征表示。 Residual Connection: 残差连接,以避免训练中的问题。 Upsample Operation: 上采样操作,逐渐恢复特征图的尺寸以生成最终输出。 该阶段包括多个上采样步骤,每个步骤包含:
通过一个最终的 ResNet/ConvNeXT 块和一个卷积层,输出最终的预测结果,即图像上的噪声。
# U-Net 网络结构
# 注意注意,咱们这里的代码并没有类别标签编码的实现,只是实现了时间步编码,如果有需要,我们可以把类别标签作为额外的输入传递给网络。这可以通过类似的 self.time_mlp 模块进行处理,或者直接作为额外的输入信息嵌入到生成过程中。
class Unet(nn.Module):
def __init__(
self,
dim, # 网络的基础维度,例如图像的尺寸,通常为28
init_dim=None, # 初始卷积的输出维度,默认为None,会被设置为dim // 3 * 2
out_dim=None, # 网络输出的维度,默认为None,最终取channels(输出通道数)
dim_mults=(1,2,4,8), # 每个阶段的维度缩放因子
channels=3, # 输入图像的通道数,默认为3
with_time_emb=True, # 是否使用时间嵌入
resnet_block_groups=8, # 如果使用ResnetBlock,则groups参数为resnet_block_groups
use_convnext=True, # 是否使用ConvNextBlock,True则使用ConvNextBlock,False则使用ResnetBlock
convnext_mult=2, # 如果使用ConvNextBlock,mult参数为convnext_mult
):
super().__init__()
# 确定初始维度
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2) # 设置init_dim,默认值为dim的三分之一乘以2
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) # 初始卷积层,用于将输入通道转换为init_dim
# 确定每个阶段的维度
dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # 创建维度列表,从init_dim开始,逐步乘以dim_mults中的因子
in_out = list(zip(dims[:-1], dims[1:])) # 创建每对阶段的输入和输出维度的列表
# 根据use_convnext的值选择块类型
if use_convnext:
block_klass = partial(ConvNextBlock, mult=convnext_mult) # 使用ConvNextBlock,并设置mult参数
else:
block_klass = partial(ResnetBlock, groups=resnet_block_groups) # 使用ResnetBlock,并设置groups参数
# 时间嵌入
if with_time_emb:
time_dim = dim * 4 # 时间嵌入的维度
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim), # 使用正弦位置嵌入
nn.Linear(dim, time_dim), # 线性层将dim映射到time_dim
nn.GELU(), # GELU激活函数
nn.Linear(time_dim, time_dim), # 另一个线性层保持time_dim
)
else:
time_dim = None
self.time_mlp = None
# 定义下采样层
self.downs = nn.ModuleList([]) # 初始化下采样网络的空列表
self.ups = nn.ModuleList([]) # 初始化上采样网络的空列表
num_resolutions = len(in_out) # 计算阶段的数量
# 添加下采样层
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1) # 判断是否为最后一对维度
self.downs.append(
nn.ModuleList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim), # 添加块
block_klass(dim_out, dim_out, time_emb_dim=time_dim), # 添加块
Residual(PreNorm(dim_out, LinearAttention(dim_out))), # 添加注意力层
Downsample(dim_out) if not is_last else nn.Identity(), # 添加下采样层或身份层
]
)
)
# 定义瓶颈层
mid_dim = dims[-1] # 最后阶段的维度
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # 中间阶段的第一个块
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) # 中间阶段的注意力层
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # 中间阶段的第二个块
# 添加上采样层
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) # 判断是否为最后一对维度
self.ups.append(
nn.ModuleList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), # 添加块
block_klass(dim_in, dim_in, time_emb_dim=time_dim), # 添加块
Residual(PreNorm(dim_in, LinearAttention(dim_in))), # 添加注意力层
Upsample(dim_in) if not is_last else nn.Identity(), # 添加上采样层或身份层
]
)
)
# 定义最终卷积层
out_dim = default(out_dim, channels) # 设置输出维度
self.final_conv = nn.Sequential(
block_klass(dim, dim), # 最后一层块
nn.Conv2d(dim, out_dim, 1) # 最终卷积层,用于将输出维度映射到目标通道数
)
def forward(self, x, time):
"""
前向传播方法
参数:
- x (Tensor): 输入特征图,形状为 (batch_size, channels, height, width)
- time (Tensor): 时间嵌入,形状为 (batch_size, dim)
返回:
- Tensor: 网络的输出特征图
"""
x = self.init_conv(x) # 通过初始卷积层处理输入
t = self.time_mlp(time) if exists(self.time_mlp) else None # 计算时间嵌入(如果存在)
h = [] # 存储每个下采样阶段的输出
# 下采样阶段
for block1, block2, attn, downsample in self.downs:
x = block1(x, t) # 通过第一个块
x = block2(x, t) # 通过第二个块
x = attn(x) # 通过注意力层
h.append(x) # 存储输出
x = downsample(x) # 通过下采样层
# 瓶颈阶段
x = self.mid_block1(x, t) # 通过中间阶段的第一个块
x = self.mid_attn(x) # 通过中间阶段的注意力层
x = self.mid_block2(x, t) # 通过中间阶段的第二个块
# 上采样阶段
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1) # 将当前特征图与上一个阶段的特征图拼接
x = block1(x, t) # 通过第一个块
x = block2(x, t) # 通过第二个块
x = attn(x) # 通过注意力层
x = upsample(x) # 通过上采样层
return self.final_conv(x) # 通过最终卷积层得到输出
接下来,我们定义前向扩散过程。
在定义前向扩散过程时,通常会选择不同的 β 调度策略来控制噪声的添加。以下是常见的四种β调度策略,其中 DDPM 采用了线性调度(Linear Schedule),而后续研究指出余弦调度(Cosine Schedule)可能会取得更好的效果。下面是对这些调度策略的简单定义,我们可以根据需求选择适合的调度策略。
# 余弦调度(Cosine Schedule)
def cosine_beta_schedule(timesteps, s=0.008):
"""
余弦调度(Cosine Schedule),参考 https://arxiv.org/abs/2102.09672
该调度方法通过余弦函数调节β值,使得噪声添加过程更加平滑。
参数:
- timesteps: 总的时间步数
- s: 调整因子,用于平滑处理(默认为0.008)
返回:
- 调度后的β值,范围在0.0001到0.9999之间
"""
steps = timesteps + 1 # 因为需要包括0步,所以总步数为timesteps + 1
x = torch.linspace(0, timesteps, steps) # 创建从0到timesteps的线性空间
# 计算累积的alphas值,余弦函数的平方
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0] # 归一化,使得第一个值为1
# 计算β值为1减去相邻的alphas_cumprod的比值
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
# 将β值裁剪到0.0001到0.9999的范围内,避免极端值
return torch.clip(betas, 0.0001, 0.9999)
# 线性调度(Linear Schedule)
def linear_beta_schedule(timesteps):
"""
线性调度(Linear Schedule),β值在每个时间步均匀变化。
参数:
- timesteps: 总的时间步数
返回:
- 线性变化的β值
"""
beta_start = 0.0001 # 初始β值
beta_end = 0.02 # 结束β值
# 从beta_start到beta_end线性插值生成β值
return torch.linspace(beta_start, beta_end, timesteps)
# 平方调度(Quadratic Schedule)
def quadratic_beta_schedule(timesteps):
"""
平方调度(Quadratic Schedule),β值按平方根函数变化。
参数:
- timesteps: 总的时间步数
返回:
- 按平方函数变化的β值
"""
beta_start = 0.0001 # 初始β值
beta_end = 0.02 # 结束β值
# 从beta_start的平方根到beta_end的平方根线性插值,然后平方
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
# Sigmoid调度(Sigmoid Schedule)
def sigmoid_beta_schedule(timesteps):
"""
Sigmoid调度(Sigmoid Schedule),β值通过sigmoid函数调节。
参数:
- timesteps: 总的时间步数
返回:
- 通过sigmoid函数生成的β值
"""
beta_start = 0.0001 # 初始β值
beta_end = 0.02 # 结束β值
# 从-6到6线性插值生成β值,然后通过sigmoid函数转化
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
# 对比四种β调度策略(cosine、linear、quadratic、sigmoid)的不同变化趋势,直观了解每种调度策略的特点和对噪声添加过程的影响。
x = np.linspace(1, 1001, 1000) # 生成一个从1到1001的等间距数组,共1000个点,该数组用于表示时间步,作为横轴数据
timesteps = 1000 # 设置总的时间步数为1000
fig, ax = plt.subplots() # 创建一个新的绘图实例,包括一个图形(fig)和一个子图(ax)
# 绘制cosine调度的β值曲线,使用定义的cosine_beta_schedule函数
# 这里将β值除以50缩小,以便和其他曲线更好地对比
ax.plot(x, (cosine_beta_schedule(timesteps, s=0.008) / 50).numpy(), label='cosine')
# 绘制linear调度的β值曲线,使用定义的linear_beta_schedule函数
ax.plot(x, linear_beta_schedule(timesteps).numpy(), label='linear')
# 绘制quadratic调度的β值曲线,使用定义的quadratic_beta_schedule函数
ax.plot(x, quadratic_beta_schedule(timesteps).numpy(), label='quadratic')
# 绘制sigmoid调度的β值曲线,使用定义的sigmoid_beta_schedule函数
ax.plot(x, sigmoid_beta_schedule(timesteps).numpy(), label='sigmoid')
plt.legend() # 显示图例,为每条曲线加上标签(cosine、linear、quadratic、sigmoid)
plt.show() # 显示绘制的图形
在 DDPM 中,我们选用第二种线性调度策略(Linear Schedule)来定义 β 值,并将时间步数 设置为 200。以下是各参数的定义和计算过程,我们将在每个时间步 下预先计算这些参数,以便后续使用:
β:扩散过程中添加的噪声强度,线性变化。 α:噪声的保留量,定义为 。 累积 α () (alphas_cumprod):从时间步1到时间步 的所有α的累积乘积,定义为 。 前一步的累积 α () (alphas_cumprod_prev):时间步 时的累积 α 值,即 。 (sqrt_recip_alphas):累积 α 的倒数平方根,用于调整样本的标准化过程。 (sqrt_alphas_cumprod):累积 α 的平方根,用于生成过程的重建。 (sqrt_one_minus_alphas_cumprod):用于表示噪声部分的平方根,即 。 后验方差 (posterior_variance):用来定义逆过程中的噪声方差,计算公式为 。
通过将 设置为 200,我们可以提前计算每个时间步下的这些参数,从而为扩散模型的训练和采样过程奠定基础。
timesteps = 200 # 设置时间步数 T 为 200
# 定义 β 的线性调度
betas = linear_beta_schedule(timesteps=timesteps) # 调用函数生成线性变化的 β 值数组
# 计算 α
alphas = 1. - betas # 根据 β 计算 α,即 α = 1 - β
alphas_cumprod = torch.cumprod(alphas, axis=0) # 计算 α 的累积乘积,即 \bar{α_t} = \prod_{s=1}^{t} α_s
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# 将 alphas_cumprod 左侧补 1 得到 \bar{α_{t-1}}, 用于计算逆过程中的后验方差
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # 计算 α 的倒数平方根,主要用于去噪过程的标准化
# 计算扩散 q(x_t | x_{t-1}) 和相关参数
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) # 计算 \sqrt{\bar{α_t}},用于样本的重建过程
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# 计算 \sqrt{1 - \bar{α_t}},用于表示噪声部分在扩散中的权重
# 计算后验分布 q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# 计算后验方差,用于逆扩散过程中的噪声强度控制:β * (1 - \bar{α_{t-1}}) / (1 - \bar{α_t})
# extract 函数用于从指定参数数组中提取特定时间步 t 对应的参数值
def extract(a, t, x_shape):
"""
从给定的参数数组 a 中提取指定时间步 t 对应的值,并将其重塑为输入 x 的形状。
参数:
- a: 参数数组,例如 sqrt_alphas_cumprod。
- t: 时间步张量,包含时间步索引。
- x_shape: 输入数据的形状,用于调整输出张量的形状。
返回:
- 重塑后的提取值张量,适应输入数据的形状。
"""
batch_size = t.shape[0] # 获取 batch size,即 t 中包含的时间步数
out = a.gather(-1, t.cpu()) # 使用 gather 从 a 中提取 t 对应的值
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# 重塑输出为 (batch_size, 1, ..., 1) 的形状,以适应输入 x 的形状并将张量移动到 t 所在的设备上
# # 测试 extract 函数
# sqrt_alphas_cumprod = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 伪造的 sqrt_alphas_cumprod 数组
# x_start = torch.ones([1, 3, 8, 8]) # 生成形状为 (1, 3, 8, 8) 的张量,模拟输入数据
# out = extract(a=sqrt_alphas_cumprod, t=torch.tensor([5]), x_shape=x_start.shape)
# # 测试 extract 函数,提取 t = 5 时对应的参数值
# print(out.shape) # 输出提取后的形状
咱们用一个实例来展示一下前向加噪过程,也就是扩散模型中的前向扩散过程。在这个过程中,原始图片逐渐被加入噪声,模拟从纯净数据到完全随机噪声的变化。
# 随便导入一张图片
# 通过 URL 从 COCO 数据集中导入一张示例图片
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # 请求图片并打开
image # 展示图片
通过一系列图像处理操作,将图片转化为适合扩散模型的格式。
# 变变变!
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
image_size = 128 # 设置图片大小
transform = Compose([
Resize(image_size), # 将图片缩放为 128x128
CenterCrop(image_size), # 从中心裁剪出 128x128 的图片
ToTensor(), # 将图片转换为 Tensor 格式,像素值范围从 [0, 255] 转为 [0, 1]
Lambda(lambda t: (t * 2) - 1), # 将像素值从 [0, 1] 转换到 [-1, 1] 区间,符合模型输入要求
])
# 对图片进行变换处理,并增加一个 batch 维度
x_start = transform(image).unsqueeze(0) # 结果是一个形状为 (1, 3, 128, 128) 的张量
x_start.shape # 查看处理后的图片形状
反变换用于将经过处理的图像(Tensor)还原为可视化的 PIL 图像,主要用于展示加噪后的效果。
# 图像还原
reverse_transform = Compose([
Lambda(lambda t: (t + 1) / 2), # 将像素值从 [-1, 1] 转回到 [0, 1]
Lambda(lambda t: t.permute(1, 2, 0)), # 将通道从 CHW 转换为 HWC 格式
Lambda(lambda t: t * 255.), # 将像素值恢复到 [0, 255] 范围
Lambda(lambda t: t.numpy().astype(np.uint8)), # 转为 NumPy 数组并转换为 8-bit 图像
ToPILImage(), # 转为 PIL 图像,便于展示
])
# 处理后的图片
reverse_transform(x_start.squeeze()) # 将原始处理后的张量还原为图片,并展示
在正向扩散过程中,我们通过对原始图像逐步加入噪声来生成不同时刻的噪声图像。这一过程的公式定义为:
其中:
表示第 个时间步的图像。 是原始的无噪声图像。 是从前向扩散过程中累乘得到的参数,决定了信号保留的比例。 是从标准正态分布中采样的随机噪声。
准备工作已经完成,接下来就可以开始定义正向扩散过程啦!
# 定义前向扩散过程函数 q_sample (使用 nice property),用于模拟逐步添加噪声的过程,逐步生成中间加噪图像。
def q_sample(x_start, t, noise=None):
"""
q_sample 函数实现前向加噪过程,在每个时间步 t 将噪声逐渐加入原始图像。
参数:
- x_start: 原始输入图像张量
- t: 当前时间步(噪声程度)
- noise: 加入的噪声,如果未指定则随机生成
返回:
- 加噪后的图像张量
"""
if noise is None:
noise = torch.randn_like(x_start) # 如果没有传入噪声,则生成与输入图像形状一致的随机噪声
# 提取 t 对应的 α 和 (1-α) 参数,用于当前时间步的加噪计算
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
# 返回加噪后的图像: \sqrt{\bar{α_t}} * x_start + \sqrt{1 - \bar{α_t}} * noise
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
# get_noisy_image 函数调用 q_sample 来生成加噪后的图像,并将其转回 PIL 格式以便展示。
def get_noisy_image(x_start, t):
"""
生成加噪图像,并将其转为 PIL 格式便于展示。
参数:
- x_start: 原始输入图像张量
- t: 当前时间步(噪声程度)
返回:
- 加噪后的 PIL 图像
"""
# 调用 q_sample 函数生成加噪图像
x_noisy = q_sample(x_start, t=t)
# 将加噪后的图像张量转换为 PIL 图像
noisy_image = reverse_transform(x_noisy.squeeze())
return noisy_image
# 示例:生成某一时间步的加噪图像
t = torch.tensor([40]) # 设置时间步 t 为 40
# 生成并展示加噪后的图像
get_noisy_image(x_start, t)
我们展示一下不同时间步 t 下的噪声图像效果!
# 设置随机种子以确保结果可重复性
torch.manual_seed(1234) # 设定随机种子为0,使得每次运行生成的随机数相同
# 定义用于绘制图像的函数,这是 PyTorch 官方文档中的示例函数
# 来源: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
# 如果imgs不是二维列表,则将其转换为二维列表,即使只有一行
if not isinstance(imgs[0], list):
imgs = [imgs]
# 获取图像的行数和列数
num_rows = len(imgs) # 行数
num_cols = len(imgs[0]) + with_orig # 列数,考虑是否包括原始图像
# 创建一个指定大小的子图矩阵,nrows是行数,ncols是列数
fig, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
# 遍历每一行图像
for row_idx, row in enumerate(imgs):
# 如果 with_orig 为 True,在每一行的图像前插入原始图像
row = [image] + row if with_orig else row
# 遍历每一列的图像
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx] # 获取当前子图的坐标位置
# 显示图像,np.asarray 将 PIL 图像转换为 NumPy 数组
ax.imshow(np.asarray(img), **imshow_kwargs)
# 去除图像的刻度和标签
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# 如果包含原始图像,设置第一个子图的标题为 "Original image"
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8) # 设置标题字体大小为8
# 如果行标题存在,为每一行第一个图像设置标题
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout() # 自动调整图像布局
plt.show() # 显示绘制的图像
# 绘制经过多次前向扩散后的图像,观察扩散过程
# `get_noisy_image` 函数通过正向扩散为每个时间步添加噪声
# 观察在不同时间步 t 下的噪声图像效果,t 分别为 0, 50, 100, 150, 199
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])
我们可以看到,时间步从 0 到 199,随着 增大,图像噪声逐渐增加,最后变得完全不可辨认,这样可以帮助我们直观地观察噪声对图像的逐步影响。
接下来,我们定义损失函数!主要就是计算模型预测噪声与实际噪声之间的损失,用来训练去噪模型 denoise_model
。
# 定义损失函数 p_losses,有三种损失类型:L1、L2 和 Huber,默认使用 L1 损失
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
# 如果没有提供噪声,则随机生成一个与 x_start 形状相同的噪声
if noise is None:
noise = torch.randn_like(x_start) # 生成标准正态分布的噪声,与 x_start 形状一致
# 使用生成的噪声对原始图像 x_start 进行加噪,得到带噪声的图像 x_noisy
# q_sample 函数实现的是正向扩散过程,根据时间步 t 添加相应的噪声
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# 使用去噪模型 denoise_model 对带噪声的图像 x_noisy 进行预测,输出预测的噪声
# denoise_model 需要输入带噪图像 x_noisy 和对应时间步 t
predicted_noise = denoise_model(x_noisy, t)
# 根据加噪后的图片和预测的噪声,计算与真实噪声的损失
# 根据不同的损失类型选择相应的损失函数
if loss_type == 'l1':
# L1 损失(Mean Absolute Error,MAE):计算真实噪声和预测噪声之间的绝对误差
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
# L2 损失(Mean Squared Error,MSE):计算真实噪声和预测噪声之间的平方误差
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
# Huber 损失:兼具 L1 和 L2 损失的优点,对异常值具有更强的鲁棒性
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
# 如果指定的损失类型不是上述三种,抛出未实现错误
raise NotImplementedError()
return loss # 返回计算的损失
定义数据集和 DataLoader,使用 MNIST 数据集构造一个 PyTorch DataLoader,每个 batch 包含 128 张经过标准化处理的图像。
# 使用 Hugging Face Datasets 加载 Fashion MNIST 数据集
# Fashion MNIST 是一个类似 MNIST 的数据集,但每个样本是不同类型的服装图片
dataset = load_dataset("fashion_mnist") # 加载 Fashion MNIST 数据集
# 尝试使用 torchvision 的 MNIST 数据集下载方式
# dataset = datasets.MNIST(root='./data/mnist', download=True)
# 设置图像尺寸和通道数(灰度图像)
image_size = 28 # MNIST 图像的大小为 28x28 像素
channels = 1 # 灰度图像只有一个通道
batch_size = 128 # 每个批次中包含 128 张图像
# Downloading readme: 100%|██████████| 9.02k/9.02k [00:00<00:00, 16.4kB/s]
# Downloading data: 100%|██████████| 30.9M/30.9M [00:07<00:00, 4.10MB/s]
# Downloading data: 100%|██████████| 5.18M/5.18M [00:01<00:00, 3.59MB/s]
# Generating train split: 100%|██████████| 60000/60000 [00:00<00:00, 152977.41 examples/s]
# Generating test split: 100%|██████████| 10000/10000 [00:00<00:00, 427881.05 examples/s]
定义图像变换过程。
# 定义图像变换过程
# 使用 torchvision.transforms 组合多个变换
# 1. RandomHorizontalFlip() 随机水平翻转图像,用于数据增强
# 2. ToTensor() 将 PIL 图像或 numpy 数组转换为 PyTorch 张量,并且值范围从 [0, 255] 转为 [0, 1]
# 3. Lambda(lambda t: (t * 2) - 1) 将张量值范围从 [0, 1] 转换为 [-1, 1]
transform = Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 转换为张量
transforms.Lambda(lambda t: (t * 2) - 1) # 将值变换到 [-1, 1] 区间
])
# 定义一个函数来应用上述的变换
def transforms(examples):
# 对每个样本的图像应用 transform 变换,并转换为单通道灰度图像 ("L")
examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
del examples["image"] # 删除原始图像数据,保留变换后的张量
return examples
# 使用自定义的 transforms 函数对数据集进行变换
# .with_transform() 用于将变换函数应用到数据集上
# .remove_columns() 用于移除数据集中不需要的列,比如标签(label)
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")
# 创建 DataLoader,分批次加载变换后的数据集
# 传入 transformed_dataset 的训练部分,设置批次大小和是否打乱数据
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
# # 示例:从 DataLoader 中取出一个批次并查看其键值
# batch = next(iter(dataloader)) # 获取 dataloader 的一个批次数据
# print(batch.keys()) # 打印批次数据中的键,通常是 'pixel_values'
现在,我们要开始采样过程啦!
采样过程发生在反向去噪阶段。在这个过程中,扩散模型从一张纯噪声图像开始,逐步去除噪声,最终生成一张逼近真实数据分布的图像。采样的核心在于如何定义和实现这一去噪的过程。
在采样算法的第四行中, 步的图像是由 步的图像减去一个噪声项得到的,而这个噪声项是由网络拟合并进行重新缩放的结果。然而,采样过程的一个关键点在于,除了去噪,每一步还会加入一个从正态分布中采样得到的纯噪声。理想情况下,这种反复迭代的去噪操作最终会生成一张仿佛从真实数据分布中采样得到的图像。
去噪过程中的均值计算公式如下:
其中, 是通过模型 预测得到的噪声项。
# 定义采样函数,使用@torch.no_grad()装饰器避免梯度计算,提高效率。
@torch.no_grad()
def p_sample(model, x, t, t_index):
# 通过使用模型和时间步长 t 对给定的图像 x 进行一步采样
# 提取与当前时间步长 t 对应的 betas 值
betas_t = extract(betas, t, x.shape)
# 提取与当前时间步长 t 对应的 sqrt(1 - alpha_cumprod) 值
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
# 提取与当前时间步长 t 对应的 sqrt(1 / alpha) 值
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# 依据论文中的公式11,使用模型预测去噪后的均值
# model_mean 代表模型预测的去噪后的图像
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
# 如果当前步是 t=0(最后一步),返回模型预测的均值,即去噪后的图像
if t_index == 0:
return model_mean
else:
# 否则,计算并加上后验分布的噪声
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x) # 生成与输入图像形状相同的标准正态分布噪声
# 按照算法2第四行,将噪声添加到模型预测的均值中
return model_mean + torch.sqrt(posterior_variance_t) * noise
# 定义采样循环函数,根据算法2逐步从噪声中生成图像
@torch.no_grad()
def p_sample_loop(model, shape):
# 逐步从随机噪声中采样,生成最终的图像
device = next(model.parameters()).device # 获取模型所在的设备(CPU/GPU)
b = shape[0] # 获取批量大小
# 从标准正态分布的纯噪声开始(每个批次的图像都从噪声开始生成)
img = torch.randn(shape, device=device)
imgs = [] # 用于存储每个时间步的图像
# 遍历时间步长,从最大步长到第0步(逐步去噪)
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
# 在每个时间步执行一步采样,得到去噪后的图像
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy()) # 将当前步生成的图像添加到列表中
return imgs # 返回所有时间步的图像(从噪声到最终生成的图像)
# 定义主采样函数,利用p_sample_loop生成指定批次和大小的图像
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
# 调用 p_sample_loop 函数,生成指定形状的图像
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
马上开始训练啦!
# 定义一个辅助函数,将一个数分成多个组,每组的大小为指定的除数
# 例如,如果num=10且divisor=3,则返回[3, 3, 3, 1]
def num_to_groups(num, divisor):
groups = num // divisor # 计算能分成的完整组数
remainder = num % divisor # 计算剩余的数量
arr = [divisor] * groups # 初始化每组大小为divisor
if remainder > 0:
arr.append(remainder) # 如果有剩余,则将剩余部分作为一组添加到数组中
return arr # 返回分组后的数组
# 定义保存结果的文件夹路径
results_folder = Path("./results")
# 创建文件夹,如果文件夹已经存在则不会报错
results_folder.mkdir(exist_ok=True)
# 定义每隔多少步保存一次生成的图片
save_and_sample_every = 1000
# 设置设备为GPU或CPU,取决于是否有可用的CUDA GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 实例化模型,这里使用的是UNet模型
model = Unet(
dim=image_size, # 图像的尺寸
channels=channels, # 图像的通道数
dim_mults=(1, 2, 4) # 不同分辨率层的通道倍数
)
# 将模型放置在指定设备上(GPU或CPU)
model.to(device)
# 使用Adam优化器优化模型的参数,学习率设置为1e-3
optimizer = Adam(model.parameters(), lr=1e-3)
开始训练!
# 设置训练轮数
epochs = 5
# 开始训练循环
for epoch in range(epochs):
# 遍历数据加载器中的每一个批次
for step, batch in enumerate(dataloader):
optimizer.zero_grad() # 每次优化前清零优化器的梯度
batch_size = batch["pixel_values"].shape[0] # 获取当前批次的大小
batch = batch["pixel_values"].to(device) # 将当前批次的数据移动到指定设备上
# 从算法1的第3行,随机为每个样本生成一个时间步 t
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
# 计算损失函数,使用Huber损失
loss = p_losses(model, batch, t, loss_type="huber")
# 每隔100步打印一次当前的损失值
if step % 100 == 0:
print("Loss:", loss.item())
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新模型参数
# 每隔指定的步数保存生成的图像
if step != 0 and step % save_and_sample_every == 0:
milestone = step // save_and_sample_every # 计算当前的里程碑(第几个保存点)
batches = num_to_groups(4, batch_size) # 将4个样本分成多个批次
# 使用生成的样本列表并连接成一个整体
all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = (all_images + 1) * 0.5 # 将生成的图像像素值从[-1, 1]转换到[0, 1]区间
# 将生成的图像保存到指定文件夹中,以里程碑的编号命名
save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)
# Loss: 0.4452235996723175
# Loss: 0.15749157965183258
# Loss: 0.08217678219079971
# Loss: 0.07007536292076111
# Loss: 0.05999240279197693
# Loss: 0.058308832347393036
# Loss: 0.06909437477588654
# Loss: 0.05325306951999664
# Loss: 0.06019213795661926
# Loss: 0.04859648644924164
# Loss: 0.052206672728061676
# Loss: 0.05211498588323593
# Loss: 0.054710570722818375
# Loss: 0.05077681690454483
# Loss: 0.047497622668743134
# Loss: 0.0527014285326004
# Loss: 0.04545003920793533
# Loss: 0.04585908353328705
# Loss: 0.04193577542901039
# Loss: 0.049871765077114105
# Loss: 0.04710713028907776
# Loss: 0.049489736557006836
# Loss: 0.058129820972681046
# Loss: 0.050070762634277344
# Loss: 0.040489036589860916
在模型训练完成后,我们就可以使用该模型进行推理(Inference),也就是从噪声开始一步步生成图像!冲!
# 使用已经训练好的模型生成64张图像
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)
# 随机选择一张生成的图像进行展示,这里选择索引为5的图像
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
# sampling loop time step: 100%|██████████| 200/200 [00:08<00:00, 23.84it/s]
# 展示从噪声生成图像的过程
import matplotlib.animation as animation
random_index = 53
fig = plt.figure()
ims = []
# 遍历所有时间步,记录每一步生成的图像
for i in range(timesteps):
im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
ims.append([im])
# 将生成的图像序列制作成动画
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
Image(url='diffusion.gif')
数学推导补充(不看也罢!)
这里我参照了周奕帆的人工智能 - 扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现 - 个人文章 - SegmentFault 思否(https://segmentfault.com/a/1190000043744225),个人觉得他讲得很好哎!大家有兴趣可以去看看!
均值方差推导
在前面,我们通过以下几个公式:
推导出如下的结果:
其中,均值 和方差 分别为:
现在,我们来详细推导均值 和方差 的过程。
首先,将其他几个公式代入贝叶斯公式的等式右边:
由于多个正态分布的乘积仍是一个正态分布,我们知道 也可以用一个正态分布公式 表达。
那么,问题就变成了如何把这个较长的式子化简,并计算出 和 。
首先,我们可以从指数函数的系数中直接得到方差 :
接下来,我们关注指数函数的指数部分。指数部分是一个关于 的二次函数。通过将关于 x_{t-1} 的项整理成 的形式,然后除以 -2 倍方差,就可以得到均值 。
指数部分为:
将与 相关的项进行计算和化简后,可以得到均值:
通过前面的讲解我们可以知道,在去噪过程中模型的输入是 和 。这意味着上述公式中的 是已知的,只有 是未知量,那我们就需要计算出 与 之间的关系。
根据正向过程的公式:
将这个 的表达式代入均值公式,均值最终会化简成我们熟悉的形式:
这样,我们就完整推导出了均值和方差。
损失函数推导
前面我们提到模型的优化目标是使加噪声和去噪声的均值尽可能接近,而这实际上意味着让生成的噪声 更加接近真实噪声。实际上,这一优化目标是经过简化得到的,扩散模型的最初优化目标是具有一定数学意义的。
扩散模型,即扩散概率模型(Diffusion Probabilistic Model),其中最简单的一类是去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM)。前面我们知道,DDPM 的框架主要基于两篇重要的论文。第一篇论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》首次提出了扩散模型的思想。在此基础上,《Denoising Diffusion Probabilistic Models》对最早的扩散模型进行了简化,大幅提升了图像生成效果,促使扩散模型得以广泛应用。我们在上一节中看到的公式,其实都是这些简化后的结果。
扩散模型的核心思想是通过学习反向过程生成真实数据的概率密度。具体来说,模型的目标是通过一个包含可学习参数θ的概率模型 ,描述从某个噪声分布到真实数据的过程。这个概率模型实际上是反向过程的积分形式,公式如下:
其中, 是反向过程的概率分布,可以分解为以下形式:
其中, 表示给定 的情况下生成 的概率,这通常被建模为一个正态分布 。
在优化过程中,目标是让去噪声操作 和加噪声操作的逆操作 尽可能相似。这可以通过最小化它们之间的 KL 散度 来实现:
这里,KL 散度 用于衡量两个概率分布 和 Q 之间的差异。对于正态分布 P 和 Q,KL 散度的公式是:
在具体优化过程中,第一项 可以忽略,因为它与可学习参数 无关。最终优化目标主要包括两部分:
最小化每一步去噪声操作和加噪声逆操作的相似度:
最大化最后复原原图 的概率:
其中,正态分布的均值项 和 的差异度是关键,在 DDPM 中,最终简化的优化目标是:
这个简单的优化目标就是我们经常在文献中见到的去噪声训练目标啦!
搞定!!!后面的部分嘛!大家随意享用!嘿嘿嘿!
文末碎碎念
那今天的分享就到这里啦!我们下期再见哟!
最后顺便给自己推荐一下嘿嘿嘿!
如果我的分享对你有用的话,欢迎关注点赞在看转发分享阿巴阿巴阿巴阿巴巴巴!这可是我的第一原动力!
蟹蟹你们的喜欢和支持!!!
参考资料
https://arxiv.org/pdf/2006.11239 https://huggingface.co/blog/annotated-diffusion https://www.bilibili.com/video/BV13h411V7vg https://www.bilibili.com/video/BV14c411J7f2 https://blog.csdn.net/tobefans/article/details/129728036 https://aistudio.baidu.com/projectdetail/4867936 https://blog.csdn.net/DFCED/article/details/132394895 https://segmentfault.com/a/1190000043744225 https://mp.weixin.qq.com/s/keu3TMLuxZOszv2GVrSalw https://github.com/yangqy1110/Diffusion-Models