零基础万字长文实践diffusion模型

科技   2024-10-30 22:00   广东  
↑ 点击蓝字 关注极市平台
作者丨mikeeee
来源丨WeThinkIn
编辑丨极市平台

极市导读

 

文是关于去噪扩散概率模型(也称为DDPM或扩散模型)的详细实践指南,从基本原理出发,逐步介绍了如何在PyTorch中实现这一模型,并通过实验展示了模型在图像生成上的应用。文章还探讨了扩散模型的数学描述、神经网络架构以及训练和采样过程,为读者提供了一个全面的学习资源。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

本文为The Annotated Diffusion Model博客的中文翻译版,仅用作学术交流分享。

在这篇文章中,我们将深入研究去噪扩散概率模型(也称为 DDPM、扩散模型、基于分数的生成模型或简称为 自动编码器),因为研究人员已经能够利用它们在(无)条件图像/音频/视频生成方面取得显著成果。在2022年时,流行的例子包括 OpenAI 的 GLIDE 和 DALL-E 2、海德堡大学的 Latent Diffusion 和 Google Brain 的 ImageGen。

我们将回顾 (Ho et al., 2020) 的原始 DDPM 论文,并基于 Phil Wang 的 实现 在 PyTorch 中逐步实现它,而 Phil Wang 的 实现 本身又基于 原始 TensorFlow 实现。请注意,生成建模的扩散概念实际上已在 (Sohl-Dickstein et al., 2015) 中引入。然而,直到 (Song et al., 2019)(斯坦福大学)和 (Ho et al., 2020)(Google Brain)才独立改进了该方法并取得了显著成效,之后diffusion模型才开始流行起来。

请注意,关于扩散模型有 几种观点。在这里,我们从离散时间(潜在变量模型)视角出发进行解读,但请读者务必对其他视角也进行一定的了解。OK,那我们就开始吧!

我们将首先安装并导入所需的库(假设你已经安装了 PyTorch)。

!pip install -q -U einops datasets matplotlib tqdm  
  
import math  
from inspect import isfunction  
from functools import partial  
  
%matplotlib inline  
import matplotlib.pyplot as plt  
from tqdm.auto import tqdm  
from einops import rearrange, reduce  
from einops.layers.torch import Rearrange  
  
import torch  
from torch import nn, einsum  
import torch.nn.functional as F  

什么是diffusion model ?

扩散模型(或去噪扩散模型)相比其他生成模型(如Normalizing Flows、GANs或VAEs)并不复杂:这些模型都是将某种简单分布中的噪声转换为数据样本。这种方式同样适用于扩散模型,其中神经网络学习逐渐去噪数据,从纯噪声开始逐步生成一个数据样本。

更具体来说,对于图像数据,模型包含两个过程:

  • 一个我们可以选择的固定(或预定义)前向扩散过程 , 该过程逐渐将高斯噪声添加到图像中, 直到最终得到纯噪声。
  • 一个学习到的反向去噪扩散过程 , 通过训练神经网络逐渐对图像去噪, 从纯噪声逐步生成一张真实图像。

前向和反向过程均基于时间步骤 进行, 且在有限时间步骤 内完成(DDPM 论文中采用 )。在 时,从数据分布中采样一个真实图像 (例如 ImageNet 数据集中的一张猫的图片),前向过程在每个时间步骤 中从高斯分布中采样一些噪声并将其添加到前一时间步骤的图像上。通过选择足够大的 和合理的噪声添加调度, 最终可以在 时通过渐进过程得到一个各向同性高斯分布。

从数学角度来描述

让我们更正式地描述这一过程,因为我们最终需要一个可计算的损失函数,供神经网络进行优化。

为真实数据分布(例如"真实图像")。我们可以从这个分布中采样一个图像, 即 。我们定义前向扩散过程 , 该过程在每个时间步骤 添加高斯噪声,根据一个已知的variance scheduler0 进行

回顾一下,正态分布(也称高斯分布)由两个参数定义:均值 和方差 。基本上,每个时间步 的新图像是从一个 条件高斯分布 中生成的, 其均值 且方差 。我们可以通过采样 并设置 来实现。

注意, 在每个时间步骤 中不是常数(因此带有下标)。实际上,人们定义了一种所谓的 "方差调度", 可以是线性的、二次的、余弦的等等(类似学习率调度)。

因此, 从 开始, 最终得到 , 若调度设置合理, 将为纯高斯噪声。

现在, 假设我们知道条件分布 , 那么可以运行反向过程:从一些随机高斯噪声 采样, 并逐渐"去噪", 最终得到一个来自真实分布的样本

然而, 我们并不知道 。这在计算上是不可行的, 因为需要知道所有可能图像的分布来计算这一条件概率。因此,我们将借助神经网络来 逼近(学习)这一条件概率分布,将其记作 , 其中 为神经网络的参数, 通过梯度下降进行更新。

因此,我们需要一个神经网络来表示反向过程的(条件)概率分布。假设这一反向过程也为高斯分布,回忆一下任何高斯分布都由两个参数定义:

  • 均值由 参数化;
  • 方差由 参数化;

因此我们可以参数化该过程为

其中均值和方差也依赖于噪声水平

因此,我们的神经网络需要学习/表示均值和方差。然而,DDPM 论文的作者决定固定方差,仅让神经网络学习(表示)这一条件概率分布的均值。论文中提到:

首先, 我们设 为未训练的时间依赖常数。实验表明, (参见论文)有类似的效果。

在后续的 改进的扩散模型 论文中,这一方法得到改进,神经网络不仅学习反向过程的均值,还学习方差。

因此,我们继续假设神经网络仅需要学习/表示这一条件概率分布的均值。

定义需要优化的目标函数

为了推导出一个用于学习反向过程均值的目标函数, 作者指出, 的组合可以视为变分自编码器(VAE)(Kingma et al.,2013)。因此,可以使用变分下界(也称 ELBO)最小化相对于真实数据样本 的负对数似然(关于 ELBO 的详细信息可参考 VAE 论文)。该过程的 ELBO 实际上是每个时间步骤 的损失的总和, 即 。通过前向过程 和反向过程的构建, 损失的每一项(除了 )实际上是两个高斯分布之间的 KL 散度,可以用关于均值的 L2 损失显式表示!

根据 Sohl-Dickstein 等人的结果, 前向过程 的一个直接结果是, 我们可以在任意噪声水平上以 为条件采样 (因为高斯分布的和仍是高斯分布)。这非常方便:我们不需要反复应用 来采样 。我们有

其中 。我们将此等式称为"优良特性"。这意味着我们可以采样高斯噪声, 适当缩放并将其添加到 上以直接得到 。注意, 是已知的 方差调度的函数,因此也是已知的,可以预先计算。这使得我们在训练过程中,可以优化损失函数 的随机项(换句话说, 在训练过程中随机采样 并优化 )。

这一特性的另一个好处是,如 Ho 等人所展示的,可以(通过一些数学推导,详细推导请参考这篇优秀博文)重新参数化均值,使神经网络学习(预测)在 KL 项中作为损失的噪声(通过网络。这意味着我们的神经网络成为一个噪声预测器,而不是直接的均值预测器。均值可以如下计算:

最终的目标函数 如下所示(对于给定 的随机时间步骤 ):

这里, 是初始 (真实且未受污染的) 图像, 我们看到由固定的前向过程给出的直接噪声水平 样本。 是在时间步骤 采样的纯噪声, 而 是我们的神经网络。神经网络通过真实噪声和预测高斯噪声之间的简单均方误差 (MSE) 进行优化。

训练算法如下所示:

换句话说:

  • 我们从真实的、可能复杂的数据分布 中随机采样一个样本
  • 我们在 1 和 之间均匀采样噪声水平 (即随机时间步骤)
  • 我们从高斯分布中采样一些噪声, 并在水平 将其添加到输入上(使用上面定义的优良特性)
  • 神经网络被训练来基于被污染的图像 预测此噪声(即在已知调度 基础上对 应用噪声)

实际上,这些操作都是在数据批次上完成的,因为我们使用随机梯度下降优化神经网络。

神经网络(模型部分)

神经网络需要在特定时间步骤上接收一个带噪声的图像并返回预测的噪声。注意,预测的噪声是一个与输入图像大小/分辨率相同的张量。因此,网络实际上输入和输出具有相同形状的张量。那我们可以使用哪种类型的神经网络呢?

通常使用的神经网络与自编码器非常相似,您可能在“深度学习入门”教程中见过自编码器。自编码器在编码器和解码器之间具有所谓的“瓶颈”层。编码器首先将图像编码为一个较小的隐藏表示,称为“瓶颈”,然后解码器将该隐藏表示解码回实际图像。这强迫网络在瓶颈层中仅保留最重要的信息。

在架构方面,DDPM 的作者选择了 U-Net,由 (Ronneberger et al., 2015) 提出(当时在医学图像分割中实现了最先进的效果)。这种网络与任何自编码器一样,在中间有一个瓶颈,确保网络只学习最重要的信息。重要的是,它在编码器和解码器之间引入了残差连接,大大改善了梯度流动(灵感来源于 He et al., 2015 的 ResNet)。

如图所示,U-Net 模型首先对输入进行下采样(即在空间分辨率方面缩小输入),然后进行上采样。

下面,我们一步步实现这个网络。

辅助函数

首先,我们定义一些在实现神经网络时会使用到的辅助函数和类。特别地,我们定义一个 Residual 模块,它简单地将输入添加到特定函数的输出中(换句话说,为特定函数添加一个残差连接)。

我们还为上采样和下采样操作定义了别名。

def exists(x):  
    return x is not None  
  
def default(val, d):  
    if exists(val):  
        return val  
    return d() if isfunction(d) else d  
  
  
def num_to_groups(num, divisor):  
    groups = num // divisor  
    remainder = num % divisor  
    arr = [divisor] * groups  
    if remainder > 0:  
        arr.append(remainder)  
    return arr  
  
  
class Residual(nn.Module):  
    def __init__(self, fn):  
        super().__init__()  
        self.fn = fn  
  
    def forward(self, x, *args, **kwargs):  
        return self.fn(x, *args, **kwargs) + x  
  
  
def Upsample(dim, dim_out=None):  
    return nn.Sequential(  
        nn.Upsample(scale_factor=2, mode="nearest"),  
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),  
    )  
  
  
def Downsample(dim, dim_out=None):  
    # No More Strided Convolutions or Pooling  
    return nn.Sequential(  
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),  
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),  
    )  

Position embeddings

由于神经网络的参数在时间(噪声水平)上是共享的,作者借鉴了 Transformer 的位置编码 (Vaswani et al., 2017),使用正弦位置编码对 进行编码。这使得神经网络可以“知道”它在每张批次中的图像上处于哪个特定时间步骤(噪声水平)。

SinusoidalPositionEmbeddings 模块接收一个形状为 (batch_size, 1) 的张量作为输入(即批次中若干带噪声图像的噪声水平),并将其转换为形状为 (batch_size, dim) 的张量,其中 dim 为位置嵌入的维度。然后,该嵌入将被添加到每个残差块中,后续我们会进一步说明。

class SinusoidalPositionEmbeddings(nn.Module):  
    def __init__(self, dim):  
        super().__init__()  
        self.dim = dim  
  
    def forward(self, time):  
        device = time.device  
        half_dim = self.dim // 2  
        embeddings = math.log(10000) / (half_dim - 1)  
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)  
        embeddings = time[:, None] * embeddings[None, :]  
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)  
        return embeddings  

ResNet block

接下来,我们定义 U-Net 模型的核心构建块。DDPM 作者使用了一个宽残差网络块(Wide ResNet block) (Zagoruyko et al., 2016),但 Phil Wang 将标准卷积层替换为“权重标准化”的版本,这在与组归一化结合时表现更佳(详情请参见 (Kolesnikov et al., 2019))。

class WeightStandardizedConv2d(nn.Conv2d):  
    """  
    https://arxiv.org/abs/1903.10520  
    weight standardization purportedly works synergistically with group normalization  
    "
""  
  
    def forward(self, x):  
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3  
  
        weight = self.weight  
        mean = reduce(weight, "o ... -> o 1 1 1""mean")  
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))  
        normalized_weight = (weight - mean) / (var + eps).rsqrt()  
  
        return F.conv2d(  
            x,  
            normalized_weight,  
            self.bias,  
            self.stride,  
            self.padding,  
            self.dilation,  
            self.groups,  
        )  
  
  
class Block(nn.Module):  
    def __init__(self, dim, dim_out, groups=8):  
        super().__init__()  
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)  
        self.norm = nn.GroupNorm(groups, dim_out)  
        self.act = nn.SiLU()  
  
    def forward(self, x, scale_shift=None):  
        x = self.proj(x)  
        x = self.norm(x)  
  
        if exists(scale_shift):  
            scale, shift = scale_shift  
            x = x * (scale + 1) + shift  
  
        x = self.act(x)  
        return x  
  
  
class ResnetBlock(nn.Module):  
    """https://arxiv.org/abs/1512.03385"""  
  
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):  
        super().__init__()  
        self.mlp = (  
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))  
            if exists(time_emb_dim)  
            else None  
        )  
  
        self.block1 = Block(dim, dim_out, groups=groups)  
        self.block2 = Block(dim_out, dim_out, groups=groups)  
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()  
  
    def forward(self, x, time_emb=None):  
        scale_shift = None  
        if exists(self.mlp) and exists(time_emb):  
            time_emb = self.mlp(time_emb)  
            time_emb = rearrange(time_emb, "b c -> b c 1 1")  
            scale_shift = time_emb.chunk(2, dim=1)  
  
        h = self.block1(x, scale_shift=scale_shift)  
        h = self.block2(h)  
        return h + self.res_conv(x)  

Attention module

接下来,我们定义注意力模块,DDPM 作者在卷积块之间添加了这个模块。注意力机制是著名的 Transformer 架构的构建块 (Vaswani et al., 2017),它在从自然语言处理到视觉以及蛋白质折叠等各个 AI 领域中取得了巨大成功。Phil Wang 使用了两种注意力机制的变体:一种是常规的多头自注意力(与 Transformer 中使用的相同),另一种是线性注意力变体 (Shen et al., 2018),其时间和内存需求相对于序列长度呈线性增长,而常规注意力呈二次增长。

关于注意力机制的详细解释,请参考 Jay Allamar 的精彩博文。

class Attention(nn.Module):  
    def __init__(self, dim, heads=4, dim_head=32):  
        super().__init__()  
        self.scale = dim_head**-0.5  
        self.heads = heads  
        hidden_dim = dim_head * heads  
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)  
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)  
  
    def forward(self, x):  
        b, c, h, w = x.shape  
        qkv = self.to_qkv(x).chunk(3, dim=1)  
        q, k, v = map(  
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv  
        )  
        q = q * self.scale  
  
        sim = einsum("b h d i, b h d j -> b h i j", q, k)  
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  
        attn = sim.softmax(dim=-1)  
  
        out = einsum("b h i j, b h d j -> b h i d", attn, v)  
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)  
        return self.to_out(out)  
  
class LinearAttention(nn.Module):  
    def __init__(self, dim, heads=4, dim_head=32):  
        super().__init__()  
        self.scale = dim_head**-0.5  
        self.heads = heads  
        hidden_dim = dim_head * heads  
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)  
  
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),   
                                    nn.GroupNorm(1, dim))  
  
    def forward(self, x):  
        b, c, h, w = x.shape  
        qkv = self.to_qkv(x).chunk(3, dim=1)  
        q, k, v = map(  
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv  
        )  
  
        q = q.softmax(dim=-2)  
        k = k.softmax(dim=-1)  
  
        q = q * self.scale  
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)  
  
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)  
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)  
        return self.to_out(out)  

Group normalization

DDPM 的作者在 U-Net 的卷积/注意力层之间插入了组归一化 (Wu et al., 2018)。下面,我们定义一个 PreNorm 类,用于在注意力层之前应用组归一化,后续会进一步说明。需要注意的是,关于在 Transformer 中应在注意力前还是注意力后应用归一化的问题,一直存在争论。

class PreNorm(nn.Module):  
    def __init__(self, dim, fn):  
        super().__init__()  
        self.fn = fn  
        self.norm = nn.GroupNorm(1, dim)  
  
    def forward(self, x):  
        x = self.norm(x)  
        return self.fn(x)  

Conditional U-Net

现在我们已经定义了所有构建块(位置嵌入、ResNet 块、注意力和组归一化),是时候定义整个神经网络了。回顾一下,网络的任务是接收一批带噪声的图像及其相应的噪声水平,并输出添加到输入中的噪声。更正式地说:

  • 网络接收一个形状为 (batch_size, num_channels, height, width) 的带噪声图像批次和一个形状为 (batch_size, 1) 的噪声水平批次作为输入,返回一个形状为 (batch_size, num_channels, height, width) 的张量

该网络的构建步骤如下:

  • 首先,对带噪声图像批次应用卷积层,并为噪声水平计算位置嵌入
  • 接着,应用一系列下采样阶段。每个下采样阶段由 2 个 ResNet 块 + 组归一化 + 注意力 + 残差连接 + 一个下采样操作组成
  • 在网络中间,再次应用 ResNet 块并插入注意力层
  • 然后,应用一系列上采样阶段。每个上采样阶段由 2 个 ResNet 块 + 组归一化 + 注意力 + 残差连接 + 一个上采样操作组成
  • 最后,应用一个 ResNet 块和一个卷积层

最终,神经网络像堆叠乐高积木一样堆叠各层(不过理解它们的工作原理很重要)。

class Unet(nn.Module):  
    def __init__(  
        self,  
        dim,  
        init_dim=None,  
        out_dim=None,  
        dim_mults=(1, 2, 4, 8),  
        channels=3,  
        self_condition=False,  
        resnet_block_groups=4,  
    ):  
        super().__init__()  
  
        # determine dimensions  
        self.channels = channels  
        self.self_condition = self_condition  
        input_channels = channels * (2 if self_condition else 1)  
  
        init_dim = default(init_dim, dim)  
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3  
  
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  
        in_out = list(zip(dims[:-1], dims[1:]))  
  
        block_klass = partial(ResnetBlock, groups=resnet_block_groups)  
  
        # time embeddings  
        time_dim = dim * 4  
  
        self.time_mlp = nn.Sequential(  
            SinusoidalPositionEmbeddings(dim),  
            nn.Linear(dim, time_dim),  
            nn.GELU(),  
            nn.Linear(time_dim, time_dim),  
        )  
  
        # layers  
        self.downs = nn.ModuleList([])  
        self.ups = nn.ModuleList([])  
        num_resolutions = len(in_out)  
  
        for ind, (dim_in, dim_out) in enumerate(in_out):  
            is_last = ind >= (num_resolutions - 1)  
  
            self.downs.append(  
                nn.ModuleList(  
                    [  
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),  
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),  
                        Downsample(dim_in, dim_out)  
                        if not is_last  
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),  
                    ]  
                )  
            )  
  
        mid_dim = dims[-1]  
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)  
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))  
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)  
  
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):  
            is_last = ind == (len(in_out) - 1)  
  
            self.ups.append(  
                nn.ModuleList(  
                    [  
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),  
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),  
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),  
                        Upsample(dim_out, dim_in)  
                        if not is_last  
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),  
                    ]  
                )  
            )  
  
        self.out_dim = default(out_dim, channels)  
  
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)  
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)  
  
    def forward(self, x, time, x_self_cond=None):  
        if self.self_condition:  
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))  
            x = torch.cat((x_self_cond, x), dim=1)  
  
        x = self.init_conv(x)  
        r = x.clone()  
  
        t = self.time_mlp(time)  
  
        h = []  
  
        for block1, block2, attn, downsample in self.downs:  
            x = block1(x, t)  
            h.append(x)  
  
            x = block2(x, t)  
            x = attn(x)  
            h.append(x)  
  
            x = downsample(x)  
  
        x = self.mid_block1(x, t)  
        x = self.mid_attn(x)  
        x = self.mid_block2(x, t)  
  
        for block1, block2, attn, upsample in self.ups:  
            x = torch.cat((x, h.pop()), dim=1)  
            x = block1(x, t)  
  
            x = torch.cat((x, h.pop()), dim=1)  
            x = block2(x, t)  
            x = attn(x)  
  
            x = upsample(x)  
  
        x = torch.cat((x, r), dim=1)  
  
        x = self.final_res_block(x, t)  
        return self.final_conv(x)  

定义diffusion前向过程

前向扩散过程在若干时间步骤 内逐渐向来自真实分布的图像添加噪声。这个过程依赖于一个方差调度。原始的 DDPM 作者采用了线性调度:

我们将前向过程的方差设置为常数,从 线性增加到

然而,(Nichol et al., 2021) 研究表明,使用余弦调度可以获得更好的结果。

下面,我们为个时间步定义了不同的调度(稍后我们会选择一个)。

def cosine_beta_schedule(timesteps, s=0.008):  
    """  
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672  
    "
""  
    steps = timesteps + 1  
    x = torch.linspace(0, timesteps, steps)  
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2  
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]  
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])  
    return torch.clip(betas, 0.0001, 0.9999)  
  
def linear_beta_schedule(timesteps):  
    beta_start = 0.0001  
    beta_end = 0.02  
    return torch.linspace(beta_start, beta_end, timesteps)  
  
def quadratic_beta_schedule(timesteps):  
    beta_start = 0.0001  
    beta_end = 0.02  
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2  
  
def sigmoid_beta_schedule(timesteps):  
    beta_start = 0.0001  
    beta_end = 0.02  
    betas = torch.linspace(-6, 6, timesteps)  
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start  

首先, 我们使用线性调度设置 个时间步骤, 并从 定义我们将需要的各种变量,如方差的累积乘积 。下面的每个变量都是一维张量, 存储从 的值。重要的是, 我们还定义了一个 extract 函数, 用于从批次索引中提取适当的 索引。

timesteps = 300  
  
# define beta schedule  
betas = linear_beta_schedule(timesteps=timesteps)  
  
# define alphas   
alphas = 1. - betas  
alphas_cumprod = torch.cumprod(alphas, axis=0)  
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)  
  
# calculations for diffusion q(x_t | x_{t-1}) and others  
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  
  
# calculations for posterior q(x_{t-1} | x_t, x_0)  
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)  
  
def extract(a, t, x_shape):  
    batch_size = t.shape[0]  
    out = a.gather(-1, t.cpu())  
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)  

我们将使用一张猫的图像来展示在扩散过程的每个时间步骤中如何添加噪声。

from PIL import Image  
import requests  
  
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'  
image = Image.open(requests.get(url, stream=True).raw) # PIL image of shape HWC  
image  
output_cats

噪声是添加到 PyTorch 张量上的,而不是 Pillow 图像。因此,我们首先定义图像转换,以便从 PIL 图像转换为 PyTorch 张量(在其上可以添加噪声),反之亦然。

这些转换相对简单:首先将图像标准化, 通过除以 255 使它们处于 范围内, 然后确保它们处于 范围内。根据 DPPM 论文:

我们假设图像数据由 中的整数组成, 线性缩放到 。这确保神经网络的反向过程在从标准正态先验 开始时对一致缩放的输入进行操作。

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize  
  
image_size = 128  
transform = Compose([  
    Resize(image_size),  
    CenterCrop(image_size),  
    ToTensor(), # turn into torch Tensor of shape CHW, divide by 255  
    Lambda(lambda t: (t * 2) - 1),  
      
])  
  
x_start = transform(image).unsqueeze(0)  
x_start.shape  
Output:  
torch.Size([1, 3, 128, 128])  

我们还定义了反向转换,它接收一个包含范围内值的 PyTorch 张量,并将其转换回 PIL 图像:

import numpy as np  
  
reverse_transform = Compose([  
     Lambda(lambda t: (t + 1) / 2),  
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC  
     Lambda(lambda t: t * 255.),  
     Lambda(lambda t: t.numpy().astype(np.uint8)),  
     ToPILImage(),  
])  

让我们来验证一下:

reverse_transform(x_start.squeeze())  
output_cats_verify

现在我们可以按照论文中的描述定义前向扩散过程:

# forward diffusion (using the nice property)  
def q_sample(x_start, t, noise=None):  
    if noise is None:  
        noise = torch.randn_like(x_start)  
  
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)  
    sqrt_one_minus_alphas_cumprod_t = extract(  
        sqrt_one_minus_alphas_cumprod, t, x_start.shape  
    )  
  
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise  

让我们在特定的时间步骤上测试它:

def get_noisy_image(x_start, t):  
  # add noise  
  x_noisy = q_sample(x_start, t=t)  
  
  # turn back into PIL image  
  noisy_image = reverse_transform(x_noisy.squeeze())  
  
  return noisy_image  
# take time step  
t = torch.tensor([40])  
  
get_noisy_image(x_start, t)  
output_cats_noisy

让我们在不同的时间步骤上进行可视化:

import matplotlib.pyplot as plt  
  
# use seed for reproducability  
torch.manual_seed(0)  
  
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py  
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):  
    if not isinstance(imgs[0], list):  
        # Make a 2d grid even if there's just 1 row  
        imgs = [imgs]  
  
    num_rows = len(imgs)  
    num_cols = len(imgs[0]) + with_orig  
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)  
    for row_idx, row in enumerate(imgs):  
        row = [image] + row if with_orig else row  
        for col_idx, img in enumerate(row):  
            ax = axs[row_idx, col_idx]  
            ax.imshow(np.asarray(img), **imshow_kwargs)  
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])  
  
    if with_orig:  
        axs[0, 0].set(title='Original image')  
        axs[0, 0].title.set_size(8)  
    if row_title is not None:  
        for row_idx in range(num_rows):  
            axs[row_idx, 0].set(ylabel=row_title[row_idx])  
  
    plt.tight_layout()  
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])  
output_cats_noisy_multiple

这意味着我们现在可以根据模型定义损失函数,如下所示:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):  
    if noise is None:  
        noise = torch.randn_like(x_start)  
  
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)  
    predicted_noise = denoise_model(x_noisy, t)  
  
    if loss_type == 'l1':  
        loss = F.l1_loss(noise, predicted_noise)  
    elif loss_type == 'l2':  
        loss = F.mse_loss(noise, predicted_noise)  
    elif loss_type == "huber":  
        loss = F.smooth_l1_loss(noise, predicted_noise)  
    else:  
        raise NotImplementedError()  
  
    return loss  

denoise_model 将是我们上面定义的 U-Net。我们将使用真实噪声和预测噪声之间的 Huber 损失函数。

定义PyTorch数据集和数据加载器

这里我们定义了一个常规的 PyTorch 数据集。该数据集仅包含来自真实数据集(如 Fashion-MNIST、CIFAR-10 或 ImageNet)的图像,线性缩放到范围内。

每张图像被调整到相同的尺寸。值得注意的是,图像也会随机水平翻转。论文中提到:

我们在 CIFAR10 训练期间使用了随机水平翻转;我们尝试了有翻转和无翻转的训练,发现翻转略微提高了样本质量。

这里我们使用 🤗 Datasets 库 方便地从 hub 加载 Fashion MNIST 数据集。该数据集包含分辨率相同的图像,即 28x28。

from datasets import load_dataset  
  
# load dataset from the hub  
dataset = load_dataset("fashion_mnist")  
image_size = 28  
channels = 1  
batch_size = 128  

接下来,我们定义一个将在整个数据集上即时应用的函数。我们使用 with_transform 功能 实现这一点。该函数仅应用了一些基本的图像预处理:随机水平翻转、重新缩放,并最终将它们的值置于范围内。

from torchvision import transforms  
from torch.utils.data import DataLoader  
  
# define image transformations (e.g. using torchvision)  
transform = Compose([  
            transforms.RandomHorizontalFlip(),  
            transforms.ToTensor(),  
            transforms.Lambda(lambda t: (t * 2) - 1)  
])  
  
# define function  
def transforms(examples):  
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]  
   del examples["image"]  
  
   return examples  
  
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")  
  
# create dataloader  
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)  
batch = next(iter(dataloader))  
print(batch.keys())  
Output:  
dict_keys(['pixel_values'])  

采样

由于在训练期间我们会从模型中采样(以跟踪进展),我们在下面定义了相关代码。采样过程在论文中总结为算法 2:

从扩散模型生成新图像的过程是通过逆转扩散过程实现的:我们从 开始,从高斯分布中采样纯噪声, 然后使用神经网络逐渐去噪(使用它学到的条件概率),直到我们到达时间步骤 。如上所述, 我们可以通过使用均值的重新参数化, 利用我们的噪声预测器, 得出一个稍微去噪的图像 。请记住, 方差是提前已知的。

理想情况下,我们得到的图像看起来像是来自真实数据分布的图像。

下面的代码实现了这一过程。

@torch.no_grad()  
def p_sample(model, x, t, t_index):  
    betas_t = extract(betas, t, x.shape)  
    sqrt_one_minus_alphas_cumprod_t = extract(  
        sqrt_one_minus_alphas_cumprod, t, x.shape  
    )  
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)  
      
    # Equation 11 in the paper  
    # Use our model (noise predictor) to predict the mean  
    model_mean = sqrt_recip_alphas_t * (  
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t  
    )  
  
    if t_index == 0:  
        return model_mean  
    else:  
        posterior_variance_t = extract(posterior_variance, t, x.shape)  
        noise = torch.randn_like(x)  
        # Algorithm 2 line 4:  
        return model_mean + torch.sqrt(posterior_variance_t) * noise   
  
# Algorithm 2 (including returning all images)  
@torch.no_grad()  
def p_sample_loop(model, shape):  
    device = next(model.parameters()).device  
  
    b = shape[0]  
    # start from pure noise (for each example in the batch)  
    img = torch.randn(shape, device=device)  
    imgs = []  
  
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):  
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)  
        imgs.append(img.cpu().numpy())  
    return imgs  
  
@torch.no_grad()  
def sample(model, image_size, batch_size=16, channels=3):  
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))  

请注意,上述代码是原始实现的简化版本。我们发现这种简化(与论文中的算法 2 一致)与原始的、更复杂的实现一样有效,原始实现中采用了裁剪。

训练模型

接下来,我们以常规的 PyTorch 方式训练模型。同时,我们定义了一些逻辑,使用上述定义的 sample 方法定期保存生成的图像。

from pathlib import Path  
  
def num_to_groups(num, divisor):  
    groups = num // divisor  
    remainder = num % divisor  
    arr = [divisor] * groups  
    if remainder > 0:  
        arr.append(remainder)  
    return arr  
  
results_folder = Path("./results")  
results_folder.mkdir(exist_ok = True)  
save_and_sample_every = 1000  

下面,我们定义模型并将其移动到 GPU 上。同时定义一个标准优化器(Adam)。

from torch.optim import Adam  
  
device = "cuda" if torch.cuda.is_available() else "cpu"  
  
model = Unet(  
    dim=image_size,  
    channels=channels,  
    dim_mults=(1, 2, 4,)  
)  
model.to(device)  
  
optimizer = Adam(model.parameters(), lr=1e-3)  

让我们开始训练吧!!

from torchvision.utils import save_image  
  
epochs = 6  
  
for epoch in range(epochs):  
    for step, batch in enumerate(dataloader):  
      optimizer.zero_grad()  
  
      batch_size = batch["pixel_values"].shape[0]  
      batch = batch["pixel_values"].to(device)  
  
      # Algorithm 1 line 3: sample t uniformally for every example in the batch  
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()  
  
      loss = p_losses(model, batch, t, loss_type="huber")  
  
      if step % 100 == 0:  
        print("Loss:", loss.item())  
  
      loss.backward()  
      optimizer.step()  
  
      # save generated images  
      if step != 0 and step % save_and_sample_every == 0:  
        milestone = step // save_and_sample_every  
        batches = num_to_groups(4, batch_size)  
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))  
        all_images = torch.cat(all_images_list, dim=0)  
        all_images = (all_images + 1) * 0.5  
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)  
    Output:  
    ----------------------------------------------------------------------------------------------------  
    Loss: 0.46477368474006653  
    Loss: 0.12143351882696152  
    Loss: 0.08106148988008499  
    Loss: 0.0801810547709465  
    Loss: 0.06122320517897606  
    Loss: 0.06310459971427917  
    Loss: 0.05681884288787842  
    Loss: 0.05729678273200989  
    Loss: 0.05497899278998375  
    Loss: 0.04439849033951759  
    Loss: 0.05415581166744232  
    Loss: 0.06020551547408104  
    Loss: 0.046830907464027405  
    Loss: 0.051029372960329056  
    Loss: 0.0478244312107563  
    Loss: 0.046767622232437134  
    Loss: 0.04305662214756012  
    Loss: 0.05216279625892639  
    Loss: 0.04748568311333656  
    Loss: 0.05107741802930832  
    Loss: 0.04588869959115982  
    Loss: 0.043014321476221085  
    Loss: 0.046371955424547195  
    Loss: 0.04952816292643547  
    Loss: 0.04472338408231735  

采样(推理)

要从模型中采样,我们可以直接使用上面定义的 sample 函数:

# sample 64 images  
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)  
  
# show a random one  
random_index = 5  
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")  
output

看起来模型已经能够生成一件不错的 T 恤了(可能在读者看来效果并不好)!请注意,我们训练的数据集分辨率较低(28x28)。

进阶阅读

DDPM 论文表明扩散模型在(无)条件图像生成方面是一个有前途的方向。此后,它得到了(极大的)改进,尤其是在文本条件图像生成方面。以下列出了一些重要的工作,我们强烈建议对diffusion模型感兴趣的读者做进一步的阅读:

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021):发现学习条件分布的方差(除了均值外)有助于提高性能
  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021):引入级联扩散,这是一条多重扩散模型管道,通过生成逐渐提高分辨率的图像,实现高保真图像合成
  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021):通过改进 U-Net 架构以及引入分类器引导,展示了扩散模型可以实现优于当前最先进生成模型的图像样本质量
  • Classifier-Free Diffusion Guidance (Ho et al., 2021):展示了通过单个神经网络联合训练条件和无条件扩散模型,可以在没有分类器的情况下引导扩散模型
  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022):使用先验将文本标题转换为 CLIP 图像嵌入,然后扩散模型将其解码为图像
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) (Saharia et al., 2022):展示了结合大型预训练语言模型(例如 T5)与级联扩散对于文本到图像合成非常有效


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

点击阅读原文进入CV社区

收获更多技术干货


极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
 最新文章