MAR(Masked AutoRegressive): 破除封建迷信——谁说自回归图像生成一定需要 VQ的!

文摘   2024-10-21 08:01   北京  

近期文章回顾(更多热门文章请关注公众号与知乎Rocky Ding哦)

写在前面

WeThinkIn最新福利放送:大家只需关注WeThinkIn公众号,后台回复“简历资源”,即可获取包含Rocky独家简历模版在内的60套精选的简历模板资源,希望能给大家在AIGC时代带来帮助。

Rocky最新发布Stable Diffusion 3和FLUX.1系列模型的深入浅出全维度解析文章,点击链接直达干货知识:https://zhuanlan.zhihu.com/p/684068402


导读

 

文章讨论了MAR模型和VQ技术在自然语言处理中的应用。文章首先介绍了autoregressive模型的基本原理,然后指出了LLMs在处理这类模型时可能遇到的挑战。接着,文章重点介绍了VQ技术,这是一种将连续值向量映射到离散表示的方法,有助于提高模型的效率和性能。 

前言

提到自回归(autoregressive),相信有人会立马举手说:

这个我熟!就是 _从左到右按顺序一个个地进行预测_,现在如火如荼的 LLMs 就是这么玩的。

没毛病~ 这种认知似乎已经成为一种刻板印象烙在我们脑子里了。

进一步,如果将自回归生成用于图像,那么就需要对连续(continuous-valued)的像素进行离散化,变为离散的 token,从而才能在预测时实现对 token 的分类预测,这种离散化的技术被称作 "VQ(Vector Quantization)".

嗯,这又是一个刻板印象,或者说已经成为了一种封建迷信:

自回归图像生成需要 VQ,而且是必须!

然而,近来由恺明大神带队完成的一篇 paper(https://arxiv.org/abs/2406.11838) 却破除了以上谈到的封建迷信和刻板印象,即:

VQ 在自回归图像生成中并非是必需的,且自回归可以按随机顺序一次性预测多个,只要是根据之前已知的去预测未知的即可。

这对于习惯照搬隔壁 NLP 那套来搞自回归图像生成的 CVer 们来说可能会造成些打击,但~无论如何,作为炼丹者,千万不能本本主义,接受现实、拥抱变化才是正解。要明白既然玩的本是玄学,那么就一切皆有可能~

自回归图像生成的封建迷信

论文开头第一句就揪出了封建迷信所在:

Conventional wisdom holds that autoregressive models for image generation are typically accompanied by vector-quantized tokens.

随后,作者就当机立断地破除了它:

it is not a necessity for autoregressive modeling.

开篇立意明确,一阵“爽朗”之风迎面吹来,em.. 这篇文章在高考场上应该能拿高分!呃,sorry,跑偏了,现在回到正题。

当今流行的自回归图像生成玩法都是借(照)鉴(抄)隔壁 NLP 的,NLP 的自回归生成是基于之前(已经生成)的 token 来预测下一个 token,通常是从左到右 one-by-one 地生成整个 token 序列。由于自然语言天然是离散的,因此每个 token 就顺理成章地被建模为类别分布(Categorical distribution),属于离散随机变量分布。这种简单直白的玩法在大力出奇迹的信念下取得了出奇好的效果,造就了如今 LLMs 不可一世的姿态。

看见隔壁 NLP 如此气盛,CV 小可爱们难免眼红。于是,CVer 们的心声:既然自回归这种简单无脑的玩法这么好使,何不拿(抄)过来试试?BUT! 下一秒他们便发现,直接抄是行不通的,因为图像宝宝们天然是连续(continuous-valued)的啊.. 卧勒个去!

但 CV 界从来都是人才济济,稍加思索,他们便想到了法子——基于图像数据集训练一个离散的 tokenizer 用于对图像进行离散化,从而将一批“特性”相似的连续值像素用一个共同的离散值表示(实际上该离散值背后还是对应着一个连续值的向量,离散值可看作是这个向量的“编号”),这法子在圈内叫作 "VQ(Vector Quantization)",经典代表有 VQ-VAE 等。

于是,对图像进行离散化后,也照样可以将像素如 NLP 的 token 一样建模为类别分布了(从而被叫作 "image token"),也同样可以自回归地基于已经生成的像素去预测(分类)下一个像素了。由此,后面就诞生了一批 "autoregressive with vq" 的代表:iGPT, DALL-E, VQ-GAN, MAGE, MaskGIT

虽然这么做是 work 了,但本论文作者不免觉得别扭,他由心地发出疑问:

Is it necessary for autoregressive models to be coupled with vector-quantized representations?

毕竟大伙有目共睹,VQ tokenizer 是真的难训,其中 quantized vector 的采样(从 codebook 中)是不可导的,于是通常采用 straight-through(https://blog.csdn.net/weixin_43135178/article/details/140160466) 这样的梯度估计方法将 quantized vector 的梯度(来自 Decoder)直接复制给 encoder output vectors,这种近似而不准确的梯度是导致其不容易训好的原因之一。

不妨来重新思考下 autoregressive 与 vq 的关系:自回归代表的仅仅是“基于已知的预测未知的”,与“数据值本身是离散还是连续”应该是毫不相干的,VQ 是基于照抄隔壁 NLP 的念头(从而才能将像素也变成像 language token 一样是离散的)才被理所当然地加入到自回归的玩法中了,这念头本身就政治不正确!

就像中国要发展特色社会主义一样,图像天然是连续的,没有必要盲目模仿自(资)然(本)语(主)言(义)而整容成离散的,要本其优势寻求合适的方法去发展壮大。也就是说,像素不一定要建模为类别分布(天然就不合适),在隔壁 NLP 中是因为自然语言天生是离散的所以才很自然地将 token 建模为类别分布,它们很好地利用了自己先天的“势”,找到了合适的“术”,从而在通往“道”的方向上前进了一大步,这个思想是很值得 CVer 们借鉴的,但切忌照搬他人之术。

由此可知,真正的关键在于要合适地建模每个像素的分布,这个分布要使得我们可以从中采样,并且有相应的 loss 函数去衡量建模的好坏。

用扩散模型来建模分布

要说当下图像生成的流量明星,那自然是扩散模型啦!既然刚刚说了关键点在于建模每个像素的分布,那么何不把扩散模型拿过来使呢,并且其天然就适合建模连续型分布(扩散模型反而在建模离散型分布方面有些棘手)。

另外,上文一直在针对每个像素的分布来论述,然而实际上是可以像 LDM 一样在 latent space 里玩,从而建模的就是每个 latent 变量的分布了。为了便于与隔壁 NLP 统一(人们都偏好简单粗暴地对不同形式的事物进行统一),我们也将 latent 变量叫作 token,只不过这 token 是连续值的,美名曰:"continuous-valued token".

至于如何将像素变成 latent(token),已经有诸多前辈(e.g. VAE)为我们铺好路了,实质上就是对原始图像进行压缩,使其变成更为“紧凑”的向量表征,同时提取了抽象语义。对于扩散模型来说,把它当成图像那样玩即可,什么扩散加噪、去噪生成等过程都不用改。

但是,与通常扩散模型建模图像分布不同,在那里是等价于要建模所有像素的联合分布,而在此处则是变为建模每个 token 的分布。引用论文原话表述就是:

in our case, the diffusion model is for representing the distribution for each token.

也因此模型的体量自然就无需那么大了,用个简单的 MLP 即可,而在建模图像分布的情况下则通常会用到 U-Net 甚至是 Transformer 等庞然大物并且结合 attention 机制(也就是说这里的扩散模型并没用上注意力机制)。

自回归网络辅助扩散模型做条件生成

既然扩散模型充当了建模 token 分布的角色,于是它就相当于用作预测的头部(prediction head),用于生成 token,就像图像分类网络的头部一样,预测结果是由它这里输出的。那么自回归网络那部分生成的就不是 token 而是某种辅助扩散模型去生成 token 的条件变量(也是连续值的),它与 token 是一一对应的关系。

也就是说, 自回归网络基于已知 token 去预测未知 token 所对应的条件变量, 然后进一步把它给到扩散模型去辅助生成对应的未知 token。 记已知 token 为 , 未知 token 所对应的条件变量为 , 那么自回归网络建模的过程就是 , 而扩散模型建模的则是 。结合如 DDPM 里用到的重参数化技巧, 扩散模型训练的 loss 函数就可以表示为:

其中 是标准高斯噪声, 就是在时间步 下的噪声扰动向量(此处是 token)。

这实际上训的就是条件扩散模型,以自回归网络的输出为条件变量来做条件生成,正是 CFG(Classifier-Free Guidance) 那套,于是训练方法还可以白嫖一波~

最吃香的是,这个 loss 不仅能训练扩散模型,而且还能将自回归网络也一并训了! 因为梯度能从 传过去,这就是没有 VQ 的好处 —— z 是自回归网络输出的 continuous-valued latent,而非从 codebook 中采样而来(采样操作不可导)。

考虑到这个 loss 在这里起到这么关键的作用,作者觉得务必给它起个名字,大名曰:"Diffusion Loss" .

重新审视自回归的意义

在破除了“自回归图像生成需要和 VQ 绑定”这个封建迷信后,作者进一步重新审视了“自回归本身的意义”,即:到底什么是“真•自回归”?

如本文“前言”一节所述,大部分人对于自回归的刻板印象就是:“从左到右(raster order)”、“一个个(one-by-one)地”、“基于已知的去预测未知的”。然而最贴近自回归本身意义的,应该仅仅是“基于已知的去预测未知的”,而“从左到右”和“一个个地(每次只预测一个)”并非是必须的,既非充分也非必要条件。

基于这种觉悟,作者“重塑”了大家对自回归的认知。

首先是预测的顺序,不一定非得是先预测左边再预测右边(对于图像这种二维结构则延伸为从左上到右下),毕竟对于图像来说,像素之间并没有明确的顺序规定;其次是预测的数量,每次不只是预测一个,而是预测一批,引用论文中的表述就是“next set-of-tokens prediction”,这样,在相同的迭代步骤下就能更快地预测完所有 tokens,从而起到加速作用。将这两方面结合起来,就变成先随机预测一批 tokens,然后再基于已经预测的这批 tokens 去预测未知的下一批(也是随机选择的) tokens。

另外,通常以 Transformer 架构去玩自回归时,会用到 causal attention,这是一种从左(前)到右(后)的单向注意力,于是后面的 tokens 就看不到前面的。然而作者认为只要遵循“基于已知的去预测未知的”就符合自回归的定义了,与 token 之间是如何交互的没有关系,也就是说自回归不应当受到单向注意力的约束。

the goal of autoregression is to predict the next token given the previous tokens; it does not constrain how the previous tokens communicate with the next token.

于是,作者毫不犹豫地采取了双向注意力(bidirectional attention)机制(顺便 Q一下还有多少人记得 BERT),这样能够使得 tokens 之间的交互更加充分。最后,作者进一步结合 MAE 的做法——基于未 mask 的 tokens 去预测 masked tokens 中随机挑选的一批;新预测的这批 tokens 的 mask 被放开(成为 unmasked tokens),它们与之前的 unmasked tokens 再一起去预测剩下的 masked tokens 中随机挑选的一批。这是利用了掩码生成模型天然维持了自回归的特性——基于已知(unmasked)的去预测未知(masked)的。

Conceptually, masked generative models predict multiple output tokens simultaneously in a randomized order, while still maintaining the autoregressive nature of “predicting next tokens based on known ones.

就这样不断重复执行自回归预测,masked tokens 便逐步减少,最终所有 mask 都被放开,于是生成了所有 tokens.

作者将他这么玩的模型称作 "Masked Autoregressive (MAR)" models:

MAR is a random-order autoregressive model that can predict multiple tokens simultaneously.

Workflow

前面我们讲了这篇论文的主要方法和关键部分,但整个模型具体是怎么 work 的或许还没讲清楚。这一节会将模型的各部分串起来,从输入到输出,包括训练和推理流程,都扒得明明白白。

训练流程

  • from pixel to latent space

上文已经提到,MAR 是在 latent space 中玩的,原始图像 pixel 会先经过编码转换成 latent,前面谈论的 token 也是处于 latent space 中。对于 pixel 和 latent 之间的切换,作者采用了预训练的 VAE 来实现,其中的 Encoder 负责将 pixel 编码为 latent,而 Decoder 则负责将 latent 解码回 pixel。

所以训练流程的第一步就是将输入图像喂给 VAE Encoder 将其编码为 latent space 中的向量。

  • patchify

VAE Encoder 输出的 latent vectors 和输入图像一样是 (b, c, h, w) 的 4-dims 结构,为了方便接下来的 AR(自回归) 网络(通常是 Transformer)进行处理,于是将其划分成为 patches(如 ViT 一样的做法),成为 (b, l, d) 的 3-dims 结构(和隔壁 NLP 玩 token 序列时一样),其中 l = (h // p) x (w // p), d = c x p x p,p 代表 patch size,实质上这就是 reshape 操作。

划分后的每个 patch 被视作 image token(continuous-valued), 同时会将它们克隆一份作为 ground truth latents, 作为扩散模型的输入 , 在每个时间步 对它们按照 noise schedule 进行加噪就得到被噪声扰动的latents , 如扩散模型操作像素空间(i.e. 对图像进行加噪)一般。

  • random masking

接下来就是 MAE 的随机 mask 掉部分 tokens 的操作:设置一个最小的 mask ratio(通常是 70%),然后从截断的正态分布(Truncated Normal distribution)中采样一个掩码比例,使得比例值在最小 mask ratio 与 100% 之间,然后对 tokens 按比例(在数量上)进行 mask,masked tokens 是随机挑选的。

需要注意的是,同一个 batch 的掩码比例相同,但每个样本中哪些位置的 tokens 要被 mask 则是不同的,也就是每个样本单独随机挑选 masked tokens。

  • MAE

然后就是 MAE encode + decode 的流程了。首先 Encoder 接收 unmasked tokens 进行编码,然后 Decoder 将 masked tokens 连同 Encoder 的编码结果一起进行解码,输出提供给扩散模型的条件变量 zzz,其中在 Encoder 和 Decoder 中都要为 tokens 加上位置编码(在编/解码操作前)。

不过,这其中还藏着有别于 naive MAE 的操作。由于输入图像先经过 VAE 下采样变为 tokens(“尺寸”相比原图变小了),然后又 mask 掉一部分而仅把剩下 unmasked 的部分给到 Encoder,因此 Encoder 拿到手的 token 序列就可能非常短。为了充分利用上计算资源,作者就在 token 序列的开头补上 64 个 [cls] tokens(也要加上位置编码) 而后再丢给 Encoder 去编码。同时,为了能够直接把 CFG 的那套训练方法拿过来用,每个样本所对应的 64 个 [cls] tokens 会以一定概率全部设置为真实的 class embeddings 或 fake latent(也就是假的条件向量,用于无条件生成)。

Decoder 解码后会先将 [cls] tokens 所对应的解码结果丢掉,然后再次加上位置编码(与前面的位置编码向量是独立的),这才是最终给到扩散模型条件向量 zzz 。

不过 CW 认为最后这次的位置编码是否必须得加非常值得怀疑!

于是抱着又社恐又难以按捺的心情问了作者大大,没想到个人所想与作者的契合度还蛮高:

Decoder 的解码结果本身已包含了位置信息,因此就逻辑上来说,最后的这次位置编码是没有必要的。

无奈作者所用的预训练模型也是用了最后这次的位置编码来进行训练的,所以就把这个逻辑保留在代码中了。

  • Diffusion loss

最后一步就是计算 Diffusion loss 了,其实就是扩散模型的常规训练方法。

将先前在 patchify 阶段 clone 下来的 gt latents 和 MAE Decoder 解码出来的 conditioning vectors 一同喂给扩散模型, 然后随机采样一个时间步 , 根据 noise schedule 计算出该时间步对应的噪声强度从而对 加噪得到 , 接着扩散模型根据 去预测噪声 ,最后用 MSE loss 计算和真实噪声的误差即是。

需要注意的是,真正用于计算梯度的仅仅是 masked tokens 那部分的 loss,只需将计算出来的 loss tensor 对应乘上 mask 即可,因为 loss tensor 和是一样的 shape.

另外,由于时间步是采样出来的,为了让模型在每个时间步学习得更加充分(每个时间步对应不同噪声强度,不同信噪比,模型需要懂得区分它们以便正确去噪),作者在每个时间步下都会将样本复制4份以达到对同一时间步采样4次的等价效果。并且前面也提到了,扩散模型的结构非常小(small MLP network),因此这么做并不会带来太大的负担。

推理流程

在进行推理时,由于没有输入图像(目标就是要生成图像),因此直接在 latent space 开玩,待 MAE + Diffusion models 自回归地生成所有 tokens 后,再由 VAE 的 Decoder 解码成图像。

那么究竟是如何自回归的呢?上文没有讲到,毕竟训练流程是体现不出来的(就像隔壁 NLP 训练 autoregressive models 一样,在训练过程中是并行解码的)。概括来说,就是 MAE encode + decode 后将条件向量给到扩散模型,后者结合该条件向量进行去噪生成(如常规扩散模型的采样生成一般),生成的 tokens 作为已知(unmasked) tokens 再回馈给 MAE 去预测未知(masked) tokens 所对应的条件向量 zzz ,然后喂给扩散模型再次进行去噪生成,生成的结果又作为已知 tokens 给到 MAE 进行下一轮的生成。

就 MAE + Diffusion models 这个整体来说,扩散模型在其中才像是真正的 Decoder ——MAE encode 出富含语义的条件向量辅助扩散模型去 decode 出未知 tokens.

OK,以上仅仅是简单粗暴的概述,接下来 CW 就为大家详细剖析清楚推理流程中各主要环节的具体操作。

  • sample order, autoregressive steps & mask schedule

首先,在采样前,会预先为每个样本随机指定不同的采样次序,从而规定了生成 tokens 的顺序。

接着,要设置自回归的步数,也就是你打算分几步来完成整个生成过程,论文中作者使用了64步。

然后,根据这个步数,定义一种 mask 策略,使得 mask 比例随步数增加而减少,从 100% 降至 0(实际上不会到 0,最后一步预测要保证至少有1个 masked token,而这一步结束后所有 tokens 就都预测完了),从而使模型能够顺利根据已知(unmasked)的去预测未知(masked)的 tokens. 作者使用的是 cosine schedule,使得 mask 比例呈余弦曲线下降的趋势。

假设指定好的 token 生成顺序为 [10, 13, 18, 15, 40, 50, 66, 70, ...],数字代表各 token 在原序列的位置,那么根据以上做法,可能产生的效果就是:在第一轮先是生成 10, 13, 18 这3个位置的 tokens,下一轮再生成 15, 40, 50, 66, 70 这几个位置 tokens. 一开始所有位置的 tokens 都被 mask 住,随着自回归迭代,mask 逐步放开,越来越多的 tokens 成为 unmasked(被生成了),但它们之间 mask 被放开的相互顺序是在采样前就预先指定好的

  • classifier-free guidance

由于使用了 CFG 的训练方法,因此 MAR 天然就可以实现条件生成,比如生成 ImageNet 数据集里其中一个类别的图片。

如果要实现条件生成,那么在采样时就会额外给模型输入一个指定 label,然后将其编码成 class embeddings;同时将无条件的 fake latent 等量(在 batch 维度)进行复制,接着将其与 class embeddings 在 batch 维度拼接(concat)在一起,这是因为 CFG 需要同时预测含 label 情况下的条件噪声与无条件(i.e. 不含 label)的噪声,为避免让模型分别进行两次前向过程,就选择拓展样本数以达到同时预测不同类型噪声的效果(从而这种操作会使得 batch size x 2)。 最后,拼接后的结果就会作为 64 个(这 64 指的是在 sequence 维度) [cls] tokens 喂给 MAE Encoder;相对地,如果是不含 label 的无条件生成(i.e. without guidance),那么就全部使用 fake latent 作为 64 个 [cls] tokens.

另外,由于在有 guidance 的情况下,batch size “被迫”增加了一倍,因此处理的 tokens 和 mask 都得相应增加,即直接在 batch 维度复制多一份。

  • MAE

在推理阶段 MAE encode + decode 的流程与训练时是一样的,这里就不再赘述了~

  • computing mask

根据当前自回归迭代的步数与预定义的 mask schedule 计算出 mask 比例,并设置好下一轮自回归生成所要用到的 mask(mask_next),同时还要计算另一种特殊的 mask(mask_to_pred) 用于指定该轮生成的 tokens 在哪些位置,以便从 MAE Decoder 的解码结果中取出这些 tokens 所对应的 conditioning vectors( zzz ).

mask_to_pred 是根据当前 maskmask_next 来进行计算的:在当前 mask 中值为 True 而在 mask_next 中为 False 的那些位置就是本轮需要预测(生成)的 tokens 位置,这代表它们在本轮是 masked tokens 而在下一轮是 unmasked tokens,于是对应在 mask_to_pred 中这些位置就为 True.

你或许会问:那在当前 mask 中本来为 False 但在 mask_next 中却变为 True 的那些位置咋办?

很抱歉,没有这种情况。因为前面已经说过,采样次序是在采样开始前就预先指定好的,mask 只是根据这个次序逐步放开。由于 tokens 生成的顺序已经被固定,因此当前已经是 unmasked 的位置在之后也会一直 keep 住是 unmasked 的。

  • token sampling by Diffusion models

利用计算好的 mask_to_pred 在 MAE Decoder 的解码结果中将所要生成的 tokens 的 conditioning vectors( ) 取出来, 然后喂给扩散模型做去噪生成(如常规扩散模型一般,从纯高斯噪声开始迭代去噪)。如果是含 label 的条件生成(i.e. with guidance),那么初始噪声(采样起始点)需要在 batch 维度复制多一倍,因为 CFG 需要同时估计含 label 的条件噪声与无条件噪声(前面也已经说过),此时的 也是 2 x batch size 的,包含了等量的 class embeddings 和 fake latents 的编解码结果。

另外,在扩散模型采样时,作者还参考了 Classifier-Guidance(CG) 这篇 paper 中的建议,使用了温度参数 来 scale 每步采样时的噪声,从而达到调节生成多样性的效果。待扩散模型生成 tokens 后,就将它们提供给 MAE 进行下一轮的自回归生成。

  • from latent to pixel

当自回归流程全部完成后,就生成了所有的 tokens,但它们是 (b, l, d) 的 3-dims 结构并且是处于 latent space 中的,所以我们需要先进行 unpatchify(i.e. reshape),将其变为 (b, c, h, w) 的 4-dims 结构,然后利用预训练 VAE 的 Decoder 将其解码回图像空间。

到此为止,整个推理流程就结束了,这就是由 latent vectors 生成图像的整个过程。最后有一点提一下:一开始进行自回归生成时(第一轮) mask 全为 True,代表全为 masked tokens,从而 MAE Encoder 的输入仅仅是那 64 个 [cls] tokens,这也体现了 pad 这些 [cls] tokens 的作用,否则 MAE Encoder 就只有玩空气的份~

核心源码解析

这一节会对 MAR 的“原创”代码实现进行解析,与上一节的理论剖析相对应。所谓“原创”即其核心思想逻辑但不包括从其它 codebase 搬运过来的部分,比如有关 VAE 的输入输出流程 以及 扩散模型的计算逻辑,诸如这些 CW 就不在这里展示了,有精神的友友们可以自行参考官方库。

附完整源码:https://github.com/LTH14/mar

训练流程

由于省略了 VAE 将 pixel 编码至 latent space 这部分,因此以下所涉及的 code 都是在 latent space 中玩的。尽管有些变量命名为 img ,但千万别当真,它其实是 latents.

  • 主要逻辑

MAR 自然是会被封装为一个类(继承 nn.Module)的,训练的主要逻辑(输入、输出 & 计算 loss)就放在了其 forward() 方法中。值得注意的是,这里用到了三个位置编码,并且每一个都是可学习的。

建议先不看初始化(__init__())方法,直接看 forward(),之后再调头回来看~

class MAR(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""


def __init__(self, img_size=256, vae_stride=16, patch_size=1,
encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm,
vae_embed_dim=16,
mask_ratio_min=0.7,
label_drop_prob=0.1,
class_num=1000,
attn_dropout=0.1,
proj_dropout=0.1,
buffer_size=64,
diffloss_d=3,
diffloss_w=1024,
num_sampling_steps='100',
diffusion_batch_mul=4,
grad_checkpointing=False,
):
super().__init__()

# --------------------------------------------------------------------------
# VAE and patchify specifics
self.vae_embed_dim = vae_embed_dim

self.img_size = img_size
self.vae_stride = vae_stride
self.patch_size = patch_size
self.seq_h = self.seq_w = img_size // vae_stride // patch_size
self.seq_len = self.seq_h * self.seq_w
self.token_embed_dim = vae_embed_dim * patch_size**2
self.grad_checkpointing = grad_checkpointing

# --------------------------------------------------------------------------
# Class Embedding
self.num_classes = class_num
self.class_emb = nn.Embedding(1000, encoder_embed_dim)
self.label_drop_prob = label_drop_prob
# Fake class embedding for CFG's unconditional generation
self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

# --------------------------------------------------------------------------
# MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
self.mask_ratio_generator = stats.truncnorm(
(mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

# --------------------------------------------------------------------------
# MAR encoder specifics
self.z_proj = nn.Linear(self.token_embed_dim,
encoder_embed_dim, bias=True)
self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
self.buffer_size = buffer_size
self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(
1, self.seq_len + self.buffer_size, encoder_embed_dim))

self.encoder_blocks = nn.ModuleList([
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
self.encoder_norm = norm_layer(encoder_embed_dim)

# --------------------------------------------------------------------------
# MAR decoder specifics
self.decoder_embed = nn.Linear(
encoder_embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(
1, self.seq_len + self.buffer_size, decoder_embed_dim))

self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

self.decoder_norm = norm_layer(decoder_embed_dim)
self.diffusion_pos_embed_learned = nn.Parameter(
torch.zeros(1, self.seq_len, decoder_embed_dim))

self.initialize_weights()

# --------------------------------------------------------------------------
# Diffusion Loss
self.diffloss = DiffLoss(
target_channels=self.token_embed_dim,
z_channels=decoder_embed_dim,
width=diffloss_w,
depth=diffloss_d,
num_sampling_steps=num_sampling_steps,
grad_checkpointing=grad_checkpointing
)
self.diffusion_batch_mul = diffusion_batch_mul


def forward(self, imgs, labels):
# class embed (B, D)
class_embedding = self.class_emb(labels)

''' patchify and mask (drop) tokens '''

# (B, C, H, W) -> (B, l = (H // P) * (W // P), C x P x P)
x = self.patchify(imgs)
# 相当于 x_0, 作为扩散模型训练的 gt, 根据 noise schedule 加噪可得 x_t
gt_latents = x.clone().detach()
# 对每个样本单独打乱 tokens 次序, 结合以下从而做到随机 mask 的效果
orders = self.sample_orders(bsz=x.size(0))
# 计算 mask 比例 r%, mask 掉以上 orders 中前 r% 位置的 tokens
# 由于 orders 是随机顺序, 因此实现了随机 mask 的效果
mask = self.random_masking(x, orders)

''' MAE encode & decode '''

# mae encoder
# 在 token 序列前 pad 上 64 个 [cls] tokens,
# 然后与 unmasked tokens 一起(加上位置编码)进入到 encoder 进行编码
x = self.forward_mae_encoder(x, mask, class_embedding)

# mae decoder
# 将 encoder 的编码结果与 masked tokens 一起(再次加上位置编码)进行解码,
# 解码后去掉 64 个 [cls] tokens 对应的解码结果(最后再加一次位置编码).
z = self.forward_mae_decoder(x, mask)

# diffloss
# 与常规扩散模型的 loss 计算类似, 这里是对 `gt_latents` 加噪得到 x_t,
# 然后将 x_t, t, z 输入扩散模型去估计噪声, 采用与真实噪声的 MSE 进行训练,
# 但是 loss 只取 masked tokens 所对应的部分
loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

return loss
  • 随机 mask

随机 mask 实际上是通过随机采样 token 次序而实现的,看以下代码就懂了。

def sample_orders(self, bsz):
# generate a batch of random generation orders
orders = []
for _ in range(bsz):
order = np.array(list(range(self.seq_len)))
np.random.shuffle(order)
orders.append(order)
orders = torch.Tensor(np.array(orders)).cuda().long()

return orders

def random_masking(self, x, orders):
# generate token mask
bsz, seq_len, _ = x.shape
# 从截断的正态分布中采样出 mask 比例
mask_rate = self.mask_ratio_generator.rvs(1)[0]
num_masked_tokens = int(np.ceil(seq_len * mask_rate))
mask = torch.zeros(bsz, seq_len, device=x.device)
# 因为 orders 是随机的 tokens 次序, 所以计算出需要 mask 的 token 数量后,
# 将 orders 前面这么多数量的 tokens 掩盖掉即实现了随机 mask 的效果
mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
src=torch.ones(bsz, seq_len, device=x.device))

return mask
  • MAE

MAE 编解码的实现如下所示,重点我都在以下进行注释了,结合上一节的解释一起搭配食用即可。

def forward_mae_encoder(self, x, mask, class_embedding):
# 将最后一维映射到 encoder embedding dim
x = self.z_proj(x)
bsz, _, embed_dim = x.shape

# 提前预留出 64(即 `buffer_size`) 个 [cls] tokens 的位置, 初始化为 0, 拼接在原 token 序列前面
x = torch.cat([torch.zeros(bsz, self.buffer_size,
embed_dim, device=x.device), x], dim=1)
# mask 也要相应拓展, 值为 0 表示 [cls] tokens 均不会被 mask
mask_with_buffer = torch.cat(
[torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

# random drop class embedding during training
# CFG 的那套玩法, 在训练时以一定概率 drop 掉条件项(此处以 `fake_latent` 作为无条件的表示),
# 从而实现有条件噪声与无条件噪声估计的训练
if self.training:
drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
drop_latent_mask = drop_latent_mask.unsqueeze(
-1).cuda().to(x.dtype)
class_embedding = drop_latent_mask * self.fake_latent + \
(1 - drop_latent_mask) * class_embedding

# 将 [cls] tokens 放到序列的前 64 个位置
x[:, :self.buffer_size] = class_embedding.unsqueeze(1)

# encoder position embedding
x = x + self.encoder_pos_embed_learned
# 过一层 LayerNorm
x = self.z_proj_ln(x)

# dropping
# 仅拿 unmasked tokens 喂给 encoder
x = x[(1-mask_with_buffer).nonzero(as_tuple=True)
].reshape(bsz, -1, embed_dim)

''' encoder 编码 '''

# apply Transformer blocks
if self.grad_checkpointing and not torch.jit.is_scripting():
for block in self.encoder_blocks:
x = checkpoint(block, x)
else:
for block in self.encoder_blocks:
x = block(x)

# 最后过一个归一化层
x = self.encoder_norm(x)

return x

def forward_mae_decoder(self, x, mask):
# 将最后一维映射为 decoder embedding dim
x = self.decoder_embed(x)
# 对原始 mask 拓展出 64 个 [cls] tokens 的位置, 值为 0 表示它们均不被 mask
mask_with_buffer = torch.cat(
[torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

# pad mask tokens
# 由于 masked 仅仅是1个维度为 decoder embedding dim 的向量,
# 因此要进行维度的扩展(在 batch 和 sequence 维度进行复制)
mask_tokens = self.mask_token.repeat(
mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
# 先全部初始化为 masked tokens, 而后把 encoder 的编码结果放到 unmasked 部分
x_after_pad = mask_tokens.clone()
x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = \
x.reshape(x.shape[0] * x.shape[1], x.shape[2])

# decoder position embedding
x = x_after_pad + self.decoder_pos_embed_learned

''' decoder 解码 '''

# apply Transformer blocks
if self.grad_checkpointing and not torch.jit.is_scripting():
for block in self.decoder_blocks:
x = checkpoint(block, x)
else:
for block in self.decoder_blocks:
x = block(x)

# 经过一个归一化层
x = self.decoder_norm(x)

# 去掉 [cls] tokens 所对应的解码结果
x = x[:, self.buffer_size:]
# 最后再加上另一个位置编码(与前面的位置编码不同)
x = x + self.diffusion_pos_embed_learned

return x
  • Diffusion loss

以下相当于是计算 loss 前的“准备工作”,真正的计算逻辑并不在此处展现。

def forward_loss(self, z, target, mask):
bsz, seq_len, _ = target.shape

# 之所以要在个数上复制 `diffusion_batch_mul` 这么多倍,
# 是为了实现在每个时间步下采样多次从而达到充分训练的效果, 如论文中所述
target = target.reshape(
bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)

loss = self.diffloss(z=z, target=target, mask=mask)
return loss

Diffusion loss 被封装成一个类,其所用的扩散模型相关的计算逻辑“抄”自大名鼎鼎之 OpenAI 的 ADM(https://github.com/openai/guided-diffusion/tree/main).

class DiffLoss(nn.Module):
    """Diffusion Loss"""

    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
        super(DiffLoss, self).__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2,  # for vlb loss
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing
        )

        self.train_diffusion = create_diffusion(
            timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(
            timestep_respacing=num_sampling_steps, noise_schedule="cosine")

    def forward(self, target, z, mask=None):
        t = torch.randint(0, self.train_diffusion.num_timesteps,
                          (target.shape[0],), device=target.device)
        model_kwargs = dict(c=z)

        loss_dict = self.train_diffusion.training_losses(
            self.net, target, t, model_kwargs)
        loss = loss_dict["loss"]
        # 仅取 masked tokens 所对应的 loss
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()

        return loss.mean()

以上的 self.net 代表一个小型的扩散模型,用于估计噪声,使用带 AdaLN 的 MLP 结构来实现,其对于时间步的编码采用了正余弦编码的方式,而对于 conditioning vectors(即 MAE Decoder 的解码结果)则直接使用一个全连接层映射到特定维度,整个输入输出的流程非常简单,如下:

class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """


    ...  # 省略, 懒得贴

  
def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C x ...] Tensor of outputs.
        """

        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

采样过程

  • 预备工作

在正式进入自回归生成前需要做一些预备工作,以下初始化 mask 全为 True,代表一开始全部都是 masked tokens;同时,还将 tokens 初始化为 0,但实际上“0并不发挥作用”,更多地像是起到了占位符的效果,原因 CW 写在以下注释中了;而上一节所说的在采样开始前确定采样次序即对应以下 sample_orders()

def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):

        ''' init and sample generation orders '''
        
        # 一开始 mask 掉所有 tokens
        mask = torch.ones(bsz, self.seq_len).cuda()
        # 虽然初始 token 设为 0, 但由于一开始全被 mask 掉, 因此实际上是 64 个 [cls] tokens 和 `self.mask_token` 
        # 分别在 encoder 和 decoder 起作用
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
        # 采样前先确定采样次序
        orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        ... # 省略,下文会接上
  • CFG 的相关设置

接下来就正式进入自回归迭代生成的流程了。

首先需要对当前的情况做判断,看是否是含 label 的条件生成,如果是,则需要将样本多复制一倍以便让网络同时估计含 label 的条件噪声和无条件噪声;否则,就将 class embedding 替换为 fake latent 按常规估计无条件噪声即可。

MAE 编解码的过程比较无聊,在前面的训练部分也已经展示过其中的代码逻辑了,于是就顺便在此处带过了。

        # 接以上内容

        ''' generate latents '''
        
        # 自回归迭代
        for step in indices:
            cur_tokens = tokens.clone()

            ''' class embedding and CFG '''
            
            # 含 label 的条件生成
            if labels is not None:
                class_embedding = self.class_emb(labels)
            # 无条件生成
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            
            # w CFG
            # CFG 在采样时需要同时估计含 label 的条件噪声和无条件噪声,
            # 于是需要将 class embedding & fake latent 拼起来并且将
            # tokens 和 mask 都复制多一倍样本以便让网络同时估计两种情况下的噪声
            if not cfg == 1.0:
                tokens = torch.cat([tokens, tokens], dim=0)
                class_embedding = torch.cat(
                    [class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            ''' MAE encode & decode '''

            # mae encoder
            x = self.forward_mae_encoder(tokens, mask, class_embedding)

            # mae decoder
            z = self.forward_mae_decoder(x, mask)
  • computing mask

接着是计算 mask,包括下一轮用于指示 masked tokens 位置的 mask(mask_next) 以及 本轮指示 predicted tokens(i.e. 下一轮将会成为 unmasked tokens) 位置的 mask(mask_to_pred),原理在上一节已经阐述过。

其中,对于 mask_to_pred 的计算,采取对本轮的 mask 和下一轮的 mask_next 实施 XOR(亦或) 操作来实现。

            # 接以上内容

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            # 根据 mask 比例和序列长度计算需要被 mask 掉的 token 数量
            mask_len = torch.Tensor(
                [np.floor(self.seq_len * mask_ratio)]).cuda()

            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            ''' get masking for next iteration and locations to be predicted in this iteration '''
            
            # 设置下一轮 masked tokens 的位置
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            
            # 计算本轮需要预测的 tokens 对应在序列的哪些位置
            if step >= num_iter - 1:
                # 若本轮是最后一轮, 则需要预测的 tokens 位置就是之前 mask 掉的所有位置
                mask_to_pred = mask[:bsz].bool()
            else:
                # 本轮是 masked(=True) 但下一轮是 unmasked(=False) 的位置即为本轮需要预测的 tokens 位置
                # 使用 XOR(亦或) 操作即可实现
                mask_to_pred = torch.logical_xor(
                    mask[:bsz].bool(), mask_next.bool())
            
            # CFG 需要多复制一倍样本
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
            
            mask = mask_next

其中,以上 mask_by_order() 与之前在训练部分展示的 random_masking() 是差不多的逻辑,只不过以上是提前把 mask 比例和将要 mask 掉的 token 数量在方法外部提前计算好,然后再传参过去。

def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len).cuda()
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
    
    return masking
  • 扩散模型采样

然后就是自回归步骤里的最后一个环节——扩散模型采样生成了。

先根据 mask_to_pred 取出所要生成的 tokens 对应的 conditioning vectors(z),同时设置好 guidance scale(cfg_iter) ,然后连同采样所用的温度参数一并投给扩散模型去操作,最后将扩散模型的采样结果放到 token 序列的对应位置随即开启下一轮的自回归生成。

            # 接以上内容
            # sample token latents for this step
            # 取出本轮预测的 tokens
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            
            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                # 1 ~ `cfg`
                cfg_iter = 1 + (cfg - 1) * (self.seq_len -
                                            mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            
            # 扩散模型采样生成
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            # w CFG
            if not cfg == 1.0:
                # CFG 情况下, 样本多复制了一倍, 因此取出采样结果的一半即为目标
                sampled_token_latent, _ = sampled_token_latent.chunk(
                    2, dim=0)  # Remove null class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            # 将采样结果放到序列对应的位置
            cur_tokens[mask_to_pred.nonzero(
                as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()

以上看不到采样的具体逻辑,让我们进一步潜到 self.diffloss.sample() 中去探探:

class DiffLoss(nn.Module):
    """Diffusion Loss"""

    ... # 省略部分内容

    def sample(self, z, temperature=1.0, cfg=1.0):
        # diffusion loss sampling
        if not cfg == 1.0:
            noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).cuda()
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward

        sampled_token_latent = self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
            temperature=temperature
        )

        return sampled_token_latent 

其中主要是对是否是含 label 的条件生成做了区分,含 label 的条件生成对应 self.net.forward_with_cfg() 方法;无条件生成则对应 self.net.forward(),对应上一节训练部分展示的 SimpleMLPAdaLN 类的 forward() 方法,就是根据输入直接估计噪声。

所以,以下就单独来看 self.net.forward_with_cfg()

class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """


    ...  # 省略部分内容

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C x ...] Tensor of outputs.
        """


        ...  # 省略

    def forward_with_cfg(self, x, t, c, cfg_scale):
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:,
                              :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)

        return torch.cat([eps, rest], dim=1

不出所料,就是 CFG 的采样方法:

其中 就是对应时间步估计的噪声, 代表 label 对应的 class embedding, fake latent 被省略了。

至于 self.gen_diffusion.p_sample_loop(),是从 ADM 搬过来的采样实现,就是常规扩散模型在迭代采样过程中所涉及的数学计算,不属于 MAR 的原创,CW 就不在这里展示了。

  • latent to pixel

待所有的自回归步骤都实施完毕,就对生成的 tokens 进行 unpatchify(i.e. reshape),变回像图像一样的 (b,c,h,w) 4-dims 结构。

def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
    ...  # 省略部分内容

    ''' generate latents '''
        
    # 自回归迭代
    for step in indices:
        cur_tokens = tokens.clone()

        ...  # 省略部分内容
    
    # unpatchify
    tokens = self.unpatchify(tokens)
    return tokens

最终再经由 VAE 解码回图像空间,得到生成的图片,在外部看来就类似于:

model = MAR(...)
sampled_tokens = model.sample_tokens(...)
sampled_images = vae.decode(sampled_tokens)

"per-token" Distribution

作者在论文中反复宣称扩散模型在这里建模的是每个 token 的分布,而非像常规玩法那样建模的是所有 token 的联合分布。

unlike common usages of diffusion models for representing the joint distribution of all pixels or all tokens, in our case, the diffusion model is for representing the distribution for each token.

但是,何以体现 MAR 建模的是每个 token 的分布而非所有 token 的联合分布呢? CW 在看论文时就一直带有这个疑问,毕竟这点在论文中被反复强调(不信你可以到 paper 中去搜关键字),论文看完了也仍然没有找到答案,我确实没有从论文中找到充分的证据来帮作者说服自己,当时不禁产生了“作者在骗我”的情绪..

直至前阵子代码开源后就释怀了,答案藏在以下两方面:

  • 各 token 的噪声强度相互独立,这源于它们对应的时间步独立采样(噪声强度根据时间步设置)
  • 扩散模型单独根据每个 noised token 和对应的 conditioning vector 去预测噪声,期间没有使用 attention 来对多个 tokens 进行关联和交互

以上第一点在前面展示 Diffusion loss 的源码时就有所体现:

class MAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    ...  # 省略

    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)

        ... # 省略


class DiffLoss(nn.Module):
    """Diffusion Loss"""
    ...  # 省略

    def forward(self, target, z, mask=None):
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)

这就是为何 targetz 等要 reshape 成为 (bsz * seq_len, \-1) 二维结构并且 t 要采样出 bsz * seq_len 那么多个的原因了。bsz -> bsz * seq_len 代表将原本属于同一个样本(i.e. 图片)的 tokens “断绝关系”,让每个 token 自己单独作为一个样本。

并且,在采样时 z 也是形如 (n, d) 的二维结构,于是每个 token 也遵循独立采样。

class MAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    ...  # 省略    
    
    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        ...  # 省略

        # generate latents
        for step in indices:    
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            ...  # 省略
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)

其中 z[mask_to_pred.nonzero(as_tuple=True)] 的结果就是二维的。

总结

用论文里的话来说,MAR 就是这样一种模型:

modeling the interdependence of tokens by autoregression, jointly with the per-token distribution by diffusion.

感恩作者大大替我想出了如此漂亮话,真滴十分精准到位~

论文的实验部分展示了 MAR 方法的有效性,包括 Diffusion loss 的灵活适配性(甚至能够用在 VQ 模型上)和对比交叉熵损失的优越性、maksed AR(MAE) 比起常规 AR 在速度和精度上的权衡等,感兴趣的友友们可以自行食用论文。

完爆交叉熵损失

以上值得注意的是,在 AR 那部分,Diffusion loss 比起交叉熵损失带来的 FID 提升(越小越优)比较少,可能是因为传统 AR(i.e. raster order + causal attention + only predict 1 token per time) 为扩散模型提供的 conditioning context 有限,作者大大对此也没有定论。

Diffusion loss 的灵活性
MAR vs AR 在速度与精度方面的权衡

虽然 MAR 只在 ImageNet 上进行了实验,也没有 scale 到大规模(数据 & 模型尺寸),其有效性在许多方面还有待检验,但其方法本身在一定程度上还是能起到“耳目一新”的效果,破除了诸如以下的刻板印象:

  • 自回归就是从左到右一个个地预测
  • 自回归生成就是在做分类预测(是离散的),对于图像生成就得使用 VQ 进行离散化

感觉现在冒出越来越多“拿扩散模型打辅助”的例子,如最近这篇 Diffusion Feedback Helps CLIP See Better 就是利用了预训练扩散模型来提升 CLIP 的细粒度视觉能力,只不过其中的扩散模型是不参与训练的,仅用于为 CLIP 提供梯度;又如去年的这篇 Diffusion-TTA: Test-time Adaptation of Discriminative Models via Generative Feedback 则是利用了预训练扩散模型来提升判别模型的准确性,同样扩散模型也是用于提供梯度。

从这种视角来看,MAR 也算是以上两者的好友——其中的扩散模型同样起到提供梯度的作用。而也正因如此,其中的自回归网络(AR) 便“摆脱”了 VQ,充分体现出“拿扩散模型为 AR 打辅助的精神”。

这么一想,不小心发现 MAR 貌似也可以此为题:

Diffusion Feedback Helps Autoregressive Model Get Rid of Vector Quantization


推荐阅读

1、加入AIGCmagic社区知识星球

AIGCmagic社区知识星球不同于市面上其他的AI知识星球,AIGCmagic社区知识星球是国内首个以AIGC全栈技术与商业变现为主线的学习交流平台,涉及AI绘画、AI视频、ChatGPT等大模型、AI多模态、数字人、全行业AIGC赋能等50+应用方向,内部包含海量学习资源、专业问答、前沿资讯、内推招聘、AIGC模型、AIGC数据集和源码等

那该如何加入星球呢?很简单,我们只需要扫下方的二维码即可。知识星球原价:299元/年,前200名限量活动价,终身优惠只需199元/年。大家只需要扫描下面的星球优惠卷即可享受初始居民的最大优惠:

2、Sora等AI视频大模型的核心原理,核心基础知识,网络结构,经典应用场景,从0到1搭建使用AI视频大模型,AI视频大模型性能测评,AI视频领域未来发展等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Sora等AI视频大模型文章地址:https://zhuanlan.zhihu.com/p/706722494

3、Stable Diffusion3和FLUX.1核心原理,核心基础知识,网络结构,从0到1搭建使用Stable Diffusion 3和FLUX.1进行AI绘画,从0到1上手使用Stable Diffusion 3和FLUX.1训练自己的AI绘画模型,Stable Diffusion 3和FLUX.1性能优化等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

Stable Diffusion 3和FLUX.1文章地址:https://zhuanlan.zhihu.com/p/684068402

4、Stable Diffusion XL核心基础知识,网络结构,从0到1搭建使用Stable Diffusion XL进行AI绘画,从0到1上手使用Stable Diffusion XL训练自己的AI绘画模型,AI绘画领域的未来发展等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

Stable Diffusion XL文章地址:https://zhuanlan.zhihu.com/p/643420260

5、Stable DiffusionV1-V2核心原理,核心基础知识,网络结构,经典应用场景,从0到1搭建使用Stable Diffusion进行AI绘画,从0到1上手使用Stable Diffusion训练自己的AI绘画模型,Stable Diffusion性能优化等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

Stable Diffusion文章地址:https://zhuanlan.zhihu.com/p/632809634

6、ControlNet核心基础知识,核心网络结构,从0到1使用ControlNet进行AI绘画,从0到1上手构建ControlNet高级应用等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

ControlNet文章地址:https://zhuanlan.zhihu.com/p/660924126

7、LoRA系列模型核心基础知识,从0到1使用LoRA模型进行AI绘画,从0到1上手训练自己的LoRA模型,LoRA变体模型介绍,优质LoRA推荐等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

LoRA文章地址:https://zhuanlan.zhihu.com/p/639229126

8、最全面的AIGC面经《手把手教你成为AIGC算法工程师,斩获AIGC算法offer!(2024年版)》文章正式发布

码字不易,欢迎大家多多点赞:

AIGC面经文章地址:https://zhuanlan.zhihu.com/p/651076114

9、10万字大汇总《“三年面试五年模拟”之算法工程师的求职面试“独孤九剑”秘籍》文章正式发布

码字不易,欢迎大家多多点赞:

算法工程师三年面试五年模拟文章地址:https://zhuanlan.zhihu.com/p/545374303

《三年面试五年模拟》github项目地址(希望大家能给个star):https://github.com/WeThinkIn/Interview-for-Algorithm-Engineer

10、Stable Diffusion WebUI、ComfyUI、Fooocus三大主流AI绘画框架核心知识,从0到1搭建AI绘画框架,从0到1使用AI绘画框架的保姆级教程,深入浅出介绍AI绘画框架的各模块功能,深入浅出介绍AI绘画框架的高阶用法等全维度解析文章正式发布

码字不易,欢迎大家多多点赞:

AI绘画框架文章地址:https://zhuanlan.zhihu.com/p/673439761

11、GAN网络核心基础知识、深入浅出解析GAN在AIGC时代的应用等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

GAN网络文章地址:https://zhuanlan.zhihu.com/p/663157306

12、其他

Rocky将YOLOv1-v7全系列大解析文章也制作成相应的pdf版本,大家可以关注公众号WeThinkIn,并在后台 【精华干货】菜单或者回复关键词“YOLO” 进行取用。

Rocky一直在运营技术交流群(WeThinkIn-技术交流群),这个群的初心主要聚焦于技术话题的讨论与学习,包括但不限于算法,开发,竞赛,科研以及工作求职等。群里有很多人工智能行业的大牛,欢迎大家入群一起学习交流~(请添加小助手微信Jarvis8866,拉你进群~)


WeThinkIn
Rocky相信人工智能,数据科学,商业逻辑,金融工具,终身成长,以及顺应时代的潮流会赋予我们超能力。
 最新文章