一文图解AF3原理
The Illustrated AlphaFold
AlphaFold3 模型架构的可视化解释,两万字长文,五十几张图,包含你可能没想找的更多细节和图表。
关键词
深度学习|结构预测|AlphaFold3
链接
原文:https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/ 知乎:https://zhuanlan.zhihu.com/p/721768100
目录
背景介绍
1.1 谁应该阅读本文? 1.2 模型总览 1.3 变量名和图表规则
输入准备
2.1 Tokenization 2.2 检索 Retrieval 2.3 构建Atom-Level表征 2.4 更新Atom-Level表征 2.5 Atom-Level → Token-Level
2.5.1 构建token-level的single表征 2.5.2 构建token-level的pair表征 2.5.3 AF3准备阶段的输出
表征学习
3.1 模版模块 3.2 MSA模块 3.3 Pairformer模块
3.3.1 为什么关注三角形(triangles)? 3.3.2 三角形更新 3.3.3 三角形注意力 3.3.4 带有对偏置的单一注意力 (Single Attention with Pair Bias)
扩散模块
4.1 扩散的基础 4.2 扩散模块
4.2.1 准备token-level的条件张量 4.2.2 准备atom-level的条件张量 4.2.3 在token-level应用注意力,并将结果投影回atom-level 4.2.4 在atom-level应用注意力,去预测atom-level的噪声更新
损失函数
5.1 损失函数和置信度头
5.1.1 5.1.2 5.1.3 5.2 置信度指标如何计算?
5.2.1 pLDDT 5.2.2 PAE 5.2.3 PDE
训练细节
6.1 循环利用 (Recycling) 6.2 交叉蒸馏 (Cross-distillation) 6.3 裁剪和多阶段训练 (Cropping and Training Stages) 6.4 冲突 (clashing) 6.5 批次大小 (Batch sizes)
思考总结
7.1 检索增强 7.2 对偏差注意力 7.3 自监督训练 7.4 损失设计 7.5 回收机制 7.6 数据的交叉蒸馏(Cross-distillation)
1. 背景介绍
小编之前虽有撰文《AF3万字长文解读》介绍了AF3,但总感觉自己属于囫囵吞枣,没有细嚼慢咽没掌握其细节。
最近Deepmind发布了基于AF3的Binder设计模型 AlphaProteo。
还有力文所基于AF3的一些模块,开发了Pallatom的全原子蛋白联合设计模型。
也看到小王随笔细嚼慢咽撰写了以下几篇关于AF3原理细节的文章:
这些激起了我再次探索AlphaFold3原理细节的欲望。直到我看到斯坦福大学 Elana Simon博士的博客文章《The Illustrated AlphaFold》,用图文并茂的方式阐述AF3的原理。我知道这就是我想要的,于是搬运到此。
1.1 谁应该阅读本文?
你想了解AlphaFold3是如何工作的吗?它的架构相当复杂,论文中的描述可能会让人感到不知所措、云里雾里,所以Elana制作了一个更加友好且详细的图表,进行可视化阐述其原理。
这篇文章主要为机器学习(ML)受众写的,多个部分假设你对注意力机制的步骤有所了解。如果你对这方面不太熟悉,可以看看 Jay Alammar的《图解 Transformer》https://jalammar.github.io/illustrated-transformer/, 它对模型架构进行了详细的可视化解释,也是本文图表和命名的灵感来源。
关于蛋白结构预测的意义、CASP竞赛等解释已经很多,所以本文不会专注于这些内容。相反,本文专注于:AF3如何预测分子结构,即AF3的算法原理:
各种分子在模型中是如何表征的? 以及进行了那些操作,最终如何将它们转换为预测结构的? 这可能要全面得多,但如果你想了解所有细节,并且喜欢通过图表学习,本文应该会对您有所帮助 :)
1.2 模型总览
本文将首先指出,AF3模型目标与之前的AF2有所不同。它不仅能预测单个蛋白序列(AF2)或蛋白复合物(AF-multimeter)的结构;AF3还能预测蛋白结构,可选地与其他蛋白、核酸、小分子、修饰等分子的复合物。所有预测的各种结构都是仅从序列出发。
因此,之前的AF2模型只需要表征标准氨基酸序列,但AF3必须表征更复杂的输入类型,因此有一个更复杂的特征化/标记化(Featurization/Tokenization)方案。输入的表征也在后文详细介绍。
Tokenization在本文,它要么代表单个氨基酸(对于蛋白),单个核苷酸(对于DNA/RNA),或者单个原子(对于修饰、小分子等)。
本文将AF3模型分为3个部分(上图1)去逐一阐述:
准备阶段(Input Preparation):用户提供一些分子的序列,以预测它们的结构,这些序列需要被嵌入到数值张量中。此外,模型检索一组被认为与用户提供的分子具有相似结构的其他分子。输入准备步骤识别这些分子,并将它们作为自己的张量嵌入。
表征学习(Representation learning):给定第1部分创建的单个和成对张量,我们使用许多变体的注意力机制来更新这些表征。
结构预测(Structure prediction):我们使用这些改进的表征,以及第1部分创建的原始输入,通过条件扩散来预测结构。
还有一些额外的部分本文也会谈及到:
损失函数、置信度头和其他相关的训练细节
从机器学习趋势的角度对模型的一些思考
1.3 变量名和图表规则
在整个模型中,蛋白复合物以两种主要形式表征:
“single”表征,它代表我们蛋白质复合物中的所有标记; “pair”表征,它代表复合物中所有氨基酸/原子对之间的关系(例如:距离、潜在的相互作用等)。 每种表征可以是原子级别和token级别的:
atom-level token-level 并且本文总是以下面这些名称和颜色显示(下图2):
: single 表征 & token-level, 浅的粉红色 : single 表征 & atom-level, 粉红色 : pair 表征& atom-level, 浅蓝色 : pair 表征 & token-level, 蓝色 除了以上4种表征之外,还有检索到的MSA表征和模版表征(下图3):
: MSA 表征, 橙色 : template 模版表征, 绿色 除了6种表征按照以上的颜色搭配,本文的介绍还有些特点如下:
图表省略了模型权重,仅可视化激活形状的变化;
激活张量总是用论文中使用的维度名称进行标记,图表的大小大致上,旨在反映这些维度的增长或缩小;
在可能的情况下,这些(以及每张)图表中张量上方的名称与AF3补充材料中使用的张量名称相匹配。通常,一个张量在模型中保持其名称。然而,在某些情况下,本文使用不同的名称来区分不同处理阶段的张量版本。例如,在原子级别的单(single)表征中, 代表初始的原子级别(atom-level)的单(single)表征,而 代表 这个表征在通过原子转换器后的更新版本;
: 初始的 single 表征 & atom-level; → 原子转换器 →
本文还为了简化忽略了大多数的LayerNorms,但它们被广泛使用。
2. 准备阶段
AlphaFold3模型的输入准备阶段,大致对应原文下图4的这些部分。
用户实际提供给AF3的输入是一个蛋白质的序列,以及可选的其他分子。输入准备阶段的目标是:将这些输入序列转换成一系列张量,这些张量将作为模型主干的输入。如下图5所示。这几种张量分别为:
: 初始的atom-level的single表征 : atom-level的pair表征 : atom-level的single表征 : token-level的single表征 : token-level的pair表征 : MSA表征 : template模版表征 简而言之,就是:
输入信息 → 4种表征(); 输入信息 → 检索Retrieval → 2种表征() 先有atom-level表征(),后有token-level的表征 () single表征,先有atom-level的,后才有token-level的 pair表征,先有atom-level的,后才有token-level的 输入准备阶段主要包含以下5个子模块,后文会对这几个子模块进行详细解释。这里先简单罗列如下:
Tokenization:会介绍分子是如何被Tokenization的,并阐明原子级别(atom-level)和标记级别(token-level)之间的区别。
检索 Retrieval 得到MSA和模版:解释如何将额外的数据库信息输入到模型中。这里将创建MSA表征 和结构模板表征 .
创建Atom-Level表征:这里创建原子级别表征单表示 和对表示 , 并包含了有关分子构象生成的信息。
更新Atom-Level表征:主要是“输入嵌入器(Input Embedder)”模块,也称为“原子转换器”,它重复3次并更新原子级别单表征 . 这里包含Atom Transformer,且里面的一些构建模块在模型后续部分也很重要,这些模块包括:
自适应LayerNorm (Adaptive LayerNorm) 带成对偏置的注意力(Attention with Pair Bias) 条件门控(Conditioned Gating) 条件转换(Conditioned Transition)
Atom-Level → Token-Level:将原子级别(atom-level)表征 (, ) 进行聚合,将多原子标记的一部分原子聚合,创建标记级别(token-level)表示 和 ,并包含来自 MSA () 以及任何用户提供的涉及配体的已知键的信息(上图5)。 2.1 Tokenization
Tokenization对应AF3原文模型框架图的位置,见下图6的蓝色部分:
在AF2中,由于模型仅表示标准的20种氨基酸组成的蛋白质,每种氨基酸都用其自己的token来表示即可。
在AF3中,这一点得以保留。但为了处理AF3能够处理的额外分子类型,也引入了额外的token(下图7):
标准氨基酸:1个 token (=AF2) 标准核苷酸:1个 token 非标准氨基酸或核苷酸(甲基化核苷酸、经过翻译后修饰的氨基酸等):每个原子1个 token 其他分子/离子:每个原子1个 token 因此,可以认为一些 token 与多个原子相关联,如:标准氨基酸/碱基。有些 token 仅与一个原子相关联,如:配体分子中的一个原子。意思就是AF3的 token 策略既有 atom-level 也有 residue-level的。
所以,虽然一个含有35个标准氨基酸的蛋白质,可能超过600个原子,但它在AF3由35个 token 表示。一个含有35个原子的配体小分子也由35个 token 表示。
此处 Tokenization 可能阐述的不是很清晰,推荐读者阅读小王随笔的《AlphaFold3的在线应用举例》。
2.2 检索 Retrieval
检索 Retrieval 是为了得到 MSA 表征和模版表征。检索对应AF3原文模型框架图的位置,见下图8的蓝色部分:
在 AF3 中的一个早期关键步骤是检索增强。找到与感兴趣的蛋白质/RNA序列相似的序列(收集到一个多序列比对中,称为“MSA”),以及与这些序列相关的任何结构(称为“模板”),然后将它们作为额外的输入,分别称为表征 和 包含在模型中(图5)。
为什么要获取 MSA 和模板?
在不同物种中发现的同一蛋白质,在结构和序列上可能非常相似。通过将这些序列对齐到一个多序列比对(MSA)中,可以观察蛋白质序列中某个特定位置是如何在进化过程中变化的。你可以将给定蛋白质的 MSA 想象成一个矩阵(上图9),其中每一行都是来自不同物种的类似蛋白质的序列。已经证明,蛋白质特定位置的列中发现的保守模式,可以反映该位置需要某些氨基酸存在的重要性,不同列之间的关系反映了氨基酸之间的关系(即如果两个氨基酸在物理上相互作用,它们在进化过程中的变化可能会相关联)。也就是说对齐的序列(MSA)携带着共进化信息。因此,MSA 通常用于丰富增强单个蛋白质的表示。
同样,如果这些同源蛋白中有任何已知结构,这些结构也可能会提供预测蛋白结构的信息。不是搜索完整的结构,而是只使用蛋白的单条链。这类似于很早期的同源建模方法,即基于已知蛋白结构的模板来建模,查询需要预测蛋白的结构。
MSA 和模版是如何检索的呢?
首先,进行遗传搜索,寻找任何类似于输入蛋白或 RNA 链的蛋白质或 RNA 链。这并不涉及任何模型的训练,而是依赖于现有的基于隐马尔可夫模型(HMM)的方法(jackhmmer, HHBlits, nhmmer等) 来扫描多个蛋白质数据库和 RNA 数据库以寻找相关的同源序列。
接着,这些序列相互对齐,构建一个包含 序列的 MSA。由于模型的计算复杂性与 成比例,限制了 < 。通常,MSA 是从单个蛋白质链构建的,但正如在 AF-multimer 中描述的,不是简单地将单独的 MSA 串联在一起形成一个块对角矩阵,来自同一物种的某些链可以像这里描述的那样被“配对”(下图10)。这样,MSA 就不必那么大且稀疏,并且可以学习链之间关系上的进化信息。
最后,对于每个蛋白质链,使用另一种基于 HMM 的方法(hmmsearch)在蛋白质数据银行(PDB)中寻找与构建的 MSA 相似的序列。选择最高质量的结构,并从中抽取多达4个样本作为“模板”包含在内。
最近,港中文的李煜老师基于蛋白语言模型开发的 Retrieval 工具 DHR,极大的提高了检索速度,同时保持了结构预测精度。推荐继续阅读《Nat. Biotechnol.|基于蛋白语言模型的超快速MSA检索算法》。
在此处,对搜索序列的数据库、检索工具、对齐工具等讲述的明显不够全面,推荐继续阅读小王随笔的《Alphafold2的大超凡序列库》。
与 AF-multimer 相比,AF3 这些检索中唯一的新部分,是现在除了蛋白质序列检索外,还对 RNA 序列进行检索。
如何表征这些模板?
如何表征这些模板,也等于是说怎么得到模版表征的张量 .
从模板搜索中,可得到每个模板的 3D 结构,以及有关哪些 token 位于哪些链中的信息。首先,计算给定模板中所有 token 对之间的欧几里得距离。对于与多个原子相关的 token(如:标准氨基酸),使用一个代表性的“中心原子”来计算距离(下图11)。对于氨基酸来说,这将是 原子,对于标准核苷酸来说,这将是 原子。
这为每个模板生成了一个 x 的矩阵。然而,并不是将每个距离表示为数值,而是将距离离散化为一个“距离的直方图(histogram of distances)”。具体来说,这些值被划分为38个区间,范围在3.15埃到50.75埃米之间,还有一个额外的区间用于任何大于这个范围的距离。
然后,为每个距离直方图附加元数据,包括:
每个 token 属于哪个链的信息; 这个 token 在晶体结构中是否已解析; 以及每个氨基酸内部局部距离的信息。 因为没有尝试选择多链的模板,来获取关于链间相互作用的信息。所以遮盖 Mask 这个矩阵,以便只查看每条链内部的距离(例如:忽略链 A 和链 B 之间的距离)。
注意,尽管模板中没有链间的相互作用,但在构建多序列比对MSA 时,它们确实包含了这些相互作用。
2.3 构建Atom-Level表征
构建Atom-Level表征对应AF3原文模型框架图的位置,见下图12的蓝色部分:
为了创建原子级单表征 ,需要提取所有原子级特征。
第一步是为每个氨基酸、核苷酸和配体计算一个参考构象体(reference conformer). 尽管还不知道整个复合物的结构,但对每个单独组分的局部结构有很强的先验知识。构象体简称confomer,来自构型异构体 conformational isomer,是通过围绕单键旋转采样产生的分子中原子的三维排列。每个氨基酸/碱基都有一个“标准”构象体,这是这个氨基酸可以存在的低能量构象之一,可以通过查找获得。然而,每个小分子都需要生成自己的构象。这些构象是使用 RDKit 的 ETKDGv3 算法生成的,该算法结合实验数据和扭转角度偏好来产生三维构象体。
第二步将这个构象体的信息(相对位置)与每个原子的电荷、原子序数和其他 token 进行拼接 concat。
矩阵 存储了序列中所有原子的这些信息(上图13左边矩阵)。
这里得到了模型最初的第一个张量 , 它是初始的atom-level的single表征, 代表conformer的意思。
然后使用 来初始化原子级的对表示 ,以存储原子之间的相对距离 . 因为只知道每个token内部的参考距离,所以使用一个掩码矩阵 (上图13中间部分) 来确保这个初始距离矩阵,只代表一个构象体中计算的距离。 还包括距离的倒数平方的线性嵌入,加上 和 的投影(上图13顶部部分),并用几个带有残差连接的线性层进行更新(即图13中的3 linear layers with residual connections)。
AF3论文并没有阐明,为什么执行这个额外的倒数距离步骤,也没有包含对其效果的消融研究。因此,只能假设它们被经验性地证明是有用的。
在AF3补充材料中,通常以向量形式 来指代张量 (这代表了原子 和原子 之间的关系)。
最后,复制原子级单表征 。这个矩阵 是将要更新的,但这里的 也被保存并在后面有所使用。
上图13右下角部分,复制原子级单一表示 , 并用 atom transformer 去更新 , 得到最终的原子级别的单表示 . 这里的内容,将在紧接着的下文(更新Atom-Level表征)进行展开。
2.4 更新Atom-Level表征
更新Atom-Level表征,里面嵌入的模块特别多。这部分还会用到Atom transformer(下图14),有关的所有细节将在这里展开。
已经生成的原子级的表征 和 , 现在希望根据附近的其他原子来更新这些表征。
每当AF3在原子级(atom-level)应用注意力机制时,都会使用Atom Transformer模块,它的特点是:
Atom Transformer其实是一系列模块组合而成的(上图14); Atom Transformer的目的,是使用注意力机制结合 和 的原始表征来更新 ; 由于 不会被Atom Transformer 更新(图13的右下角), 这里类似于神经网络的残差连接。 Atom Transformer大体上遵循传统的Transformer架构,使用层归一化、注意力机制,然后是MLP转换。
然而,Atom Transformer每一步都这里都有调整,包括来自 和 的额外输入(在这里包括次级输入有时被称为“条件作用”)。在注意力和 MLP 块之间还有一个‘门控’机制。下文根据下面的总览,详细说明Atom Transformer的四个步骤。
2.4.1 自适应层归一化 (Adaptive LayerNorm)
在介绍自适应层归一化(Adaptive LayerNorm)之前,先回顾一下标准的层归一化(下图15)。
自适应层归一化 (AdaNorm) 是层归一化 (LayerNorm) 的一个变体,它有一个简单的扩展。
传统的层归一化,对于给定的输入矩阵,是去学习两个参数 (一个缩放因子 和一个偏移因子 ),它们调整我们矩阵中每个通道的均值和标准差。
自适应层归一化 (AdaNorm),不是为 和 学习固定的参数,而是学习一个函数。根据输入矩阵自适应地生成 和 。然而,并不是基于输入 得到重新缩放的参数; 而是使用一个次级输入 , 来预测重新缩放 的均值和标准差的 和 。简而言之,图16的 input1 等同于 , 而 input2 等同于 .
2.4.2 带有对偏置的注意力 (Attention with Pair Bias)
原子级带有对偏置的注意力(Attention with Pair Bias),可以被视为自注意力(self-attention)的扩展(下图17)。
就像在自注意力中一样,虽然查询 (queries)、键 (keys)和值 (values) 都来自同一个一维序列,此处序列为原子级single表征 。但是这里的注意力机制,有三个不同之处:
对偏置(Pair-biasing): 在计算查询(queries)和键(keys)的点积之后,pair表征的线性投影被添加为偏置(图17中间浅蓝色),以缩放注意力权重。注意,这个操作不涉及使用来自 的任何信息来更新 , 只是从pair表征 到 的单向流动。这样做的原因是,具有更强成对关系的原子应该更强烈地相互关注,而 已经有效地编码了一个注意力。
门控(Gating): 除了查询(queries)、键(keys)和值(values)之外,创建了 的一个额外投影,并通过一个 sigmoid 函数传递,将值压缩在 0 和 1 之间。输出在所有头重新组合之前乘以这个“门控”(图17底部部分)。这有效地迫使模型忽略它在这一注意力过程中学到的某些内容。这种类型的门控在AF3中经常出现,并且在后文思考部分有更多的讨论。简单来说,由于模型不断地将每个部分的输出添加到残差流中,这种门控机制可以被视为模型指定哪些信息被保存或不被保存在这一残差流中的方式。它之所以被称为“门控”,是因为在 LSTM 中也有类似的“门控”,LSTM 使用 sigmoid 函数来学习一个过滤器,决定哪些输入被添加到运行的状态中。
稀疏注意力(Sparse attention): 由于原子的数量可能远大于 token 的数量,AF3在这一步不运行完整的注意力机制,而是使用一种稀疏注意力(称为:序列局部原子注意力),在这种注意力中,实际上是在局部组中运行注意力,其中每次有32个原子可以全部关注其他128个原子(下图18)。
2.4.3 条件化门控 (Conditioned Gating)
AF3对数据还应用了另一个门控,但这一次门控是从原始的原子级单表征 生成的(下图19)。
这里不清楚AF3为什么要这样做。以及不清楚在原始表征 上进行条件化,或从单表征 中学习门控机制相比,有什么好处。
2.4.4 条件化转换 (Conditioned Transition)
这一步骤相当于 Transformer 中的 MLP 层,之所以被称为“条件化”,是因为 MLP 被夹在自适应层归一化(Atom Transformer 的第1步),和条件门控(Atom Transformer 的第3步)之间(见图14),这两个步骤都依赖于表征 .
本节中另一个值得注意的点是,AF3 在 transition 过渡块中使用 SwiGLU 而不是 ReLU。从 ReLU 过渡到 SwiGLU 是从AF2到AF3的变化,并且是许多架构中常见的变化,因此在此处进行了可视化(下图19)。
在使用基于 ReLU 的过渡层(如AF2中的)时,获取激活值,将它们投影到原始大小的 4 倍,应用 ReLU,然后将其投影回原始大小。
当使用 SwiGLU(在AF3中)时,输入激活会产生两个中间向上投影,其中一个通过 swish 非线性(ReLU 的改进变体)进行处理,然后将它们相乘后再进行向下投影。
2.5 Atom-Level → Token-Level
聚合Atom-Level级的表征 → token-level的表征,对应AF3原文模型框架图的位置,见下图20的蓝色部分:
2.5.1 构建token-level的single表征
到目前为止,所有数据都以原子级别(atom-level)存储。但从这里开始,AF3的表征学习部分将在 token 级别上操作。为了创建这些 token 级别的表征:
首先要将原子级别表征投影到一个更大的维度( =128 , =384)。
然后,对分配给同一 token 的所有原子取平均值。需要注意的是,这只适用于与标准氨基酸和核苷酸相关的原子(通过计算同一 token 上所有原子的平均值);而其余的原子保持不变(下图21)。
现在已经在token level中工作,将 token level 的特征和 MSA 中的统计信息(在可用的情况下)进行拼接concat(图21)。 这个矩阵为 .
在这些拼接之后有所增长,被投影回 ,这个矩阵称为 , 它是序列的起始表征,它将在表征学习部分中被更新。
注意, 在表征学习部分会得到更新,但 被保存下来,以备后续在结构预测部分使用。
2.5.2 构建token-level的pair表征
现在已经创建了 ,初始化的单表征且是token-level的。
下一步是初始化pair表征 。pair表征是一个三维张量,但最容易将其想象为类似热图的二维矩阵,具有隐含的深度维度 =128 通道。 pair表征中的 是一个 维向量(图22),旨在存储关于token序列中token 和token 之间关系的的信息。AF3创建了一个类似原子级的矩阵 ,并且在token-level遵循一个类似的过程。
为了初始化 ,AF3使用线性投影(linear projection,图22左下角)来使序列表示的通道维度与成对表示的通道维度相匹配(384 → 128),并添加得到的 和 。对此,AF3添加一个相对位置编码 。如果用户还指定了token之间的键合方式,那么这些键合在这里被线性嵌入(图22右下角黄色),并添加到pair表征中。
2.5.3 AF3准备阶段的输出
现在已经成功创建并嵌入了AF3模型其余部分将使用的所有输入,图23展示了这些所有的输入:
在紧接着的MSA/模版/pairformer模块中,AF3将搁置原子级别的表征( , , ); 但会在 和 的帮助下,集中更新迭代的token-level表征 和 .
3. 表征学习
AlphaFold3 模型的表征学习(Representation Learning)阶段,对应原文大致就是下图的这些部分,也就是AF3模型的MSA/模版/Pairformer模块。
这一部分是模型的主体核心,通常被称为“主干骨架”,因为它是大部分计算发生的地方。这里将其称为模型的表征学习部分,因为它的目标是学习上面准备阶段得到的 token level的single表征 和pair表征 , 然后对它们进行迭代更新。
表征学习这一部分包含3大模块:
模板模块:使用结构模板表征 更新pair表征 。
MSA模块:首先更新MSA表征 ,然后将其添加到 token level的pair表征 。在这一部分,将花费大量时间在两个操作上:
外积均值(Outer Product Mean):使 能够影响 。 仅使用对偏置的MSA行感应的门控自注意力 (MSA Row-wise Gated Self-Attention Using Only Pair Bias):根据表征 更新 ,这是对偏置注意力的简化版本 (适用于MSAs)。
Pairformer:使用受几何启发(三角形)的注意力更新 和 。这一部分主要描述了三角形操作 (在AF2和AF3中都广泛使用)。
为什么关注三角形?解释了三角形操作的一些直觉。 三角形更新和三角形注意力:都使用类似自注意力的方法更新 ,但受到三角形不等式启发。 对偏置的单一注意力 (Single Attention With Pair Bias):根据 更新 ,是和偏置注意力在 token-level上的等价的 (适用于单一序列)。 以上3个模块,被多次重复,然后整个部分的输出再次作为输入反馈到自身,并且该过程被重复(这被称为回收recycling)。
3.1 模版模块
模版模块(Template Module)对应AF3原文模型框架图的位置,见下图25的蓝色部分:
首先,每个模板 (在示意图中用了2个模版, =2) 通过线性投影 (linear projection),并与AF3的pair表征( )的线性投影相加 (下图26,浅蓝色为pair表征)。
接着,这个新组合的矩阵经过一系列操作,称为 Pairformer Stack (下文将详细描述)。
最后,所有模板被平均在一起,并通过另一个线性层。这里最后的线性层使用了 ReLU 作为非线性激活函数,它是AF3中仅使用 ReLU 作为非线性激活函数的两个地方之一。
3.2 MSA模块
MSA模块对应AF3原文模型框架图的位置,见下图27的蓝色部分:
这个模块与AF2中的Evoformer模块非常相似(下图28),其目标是同时改进MSA表征和pair表征。它对这两个表征独立执行一系列操作,然后还允许它们之间进行交互。
第一步是对MSA的行进行子采样,而不是使用之前生成的MSA的所有行 (这可能多达16,000行),然后向这个子采样的MSA添加AF3单表征的投影版本 (上图28,左上角绿色)。
3.2.1 外积均值 (Outer Product Mean)
外积均值的目的是,利用MSA表征更新pair表征。
接下来,获取到的MSA表征,并通过“外积均值”(图28,紫色模块)将其纳入pair表征 中。
比较MSA的两列可以揭示序列中两个位置之间的关系信息 (例如,这两个位置在进化过程中的相关性如何)。对于每一对 token 索引 , ,遍历所有进化序列,取 和 的外积,然后对所有进化序列进行平均。然后压平这个外积,将其投影回原维度,并将其添加到pair表征 (上图29)。
虽然每个外积仅在给定序列 内比较数值,但当取这些的平均值时 (mean across,图29),就混合了跨序列的信息。这是模型中唯一一个在进化序列之间共享信息的点。这是为了减少AF2中Evoformer的计算复杂性所做的重大改变。
3.2.2 仅使用对偏置的行感应门控自注意力 (Row-wise gated self-attention using only pair bias)
这里的目的是,根据pair表征来更新MSA表征。这种特定的更新模式被称为,仅使用对偏置的行感应门控自注意力 (Row-wise gated self-attention using only pair bias) (下图30)。它是在Atom Transformer部分讨论的带有对偏置的自注意力(self attention with pair bias)的简化版本,独立应用于MSA中的每个序列(行)。
它虽然受到注意力机制的启发,但不是使用查询(queries)和键(keys)来确定每个 token 应该关注哪些其他位置,而是只使用pair表征 中存储的 token 之间已有的关系(下图31)。
在pair表征中,每个 是一个向量,包含关于 token 和 token 之间关系的信息。当张量 被投影到矩阵时,每个 向量变成了一个标量,可以用来确定 token 应该在多大程度上关注 token 。在应用行感应softmax之后,这些现在相当于注意力分数,用于像典型的注意力图一样创建值的加权平均。
需要注意的是,在 MSA 中,由于它是独立地为每一行运行的,所以在进化序列之间没有信息共享。
3.2.3 更新pair表征
MSA模块的最后一步,是通过一系列称为三角形更新和注意力的步骤来更新pair表征。这些三角形操作将在下面与Pairformer模块一起描述,在那里它们被再次使用。还有一些转换块(transition)使用 SwiGLU 来上下投影矩阵,就像在Atom Transformer中所做的那样。
3.3 Pairformer模块
Pairformer模块对应AF3原文模型框架图的位置,见下图32的蓝色部分:
在根据模板模块、MSA模块更新了AF3的pair表征之后,现在在模型的其余部分将忽略表征 和 。相反,只有更新后的pair表征 和single表征 进入 Pairformer,并用于彼此更新。
由于转换块(transition)在2.4.4小节已经被描述过了,本小节重点介绍三角形更新和三角形注意力,然后简要解释“带有对偏置的单一注意力” (Single Attention with Pair Bias) 与前文描述的变体有何不同。这些基于三角形的层首次在AF2中引入,不仅保留在AF3中,而且现在在AF3架构中更加突出(下图33)。
3.3.1 为什么关注三角形(triangles)?
这里的指导原则是三角形不等式的概念:“一个三角形的任意两边之和大于或等于第三边”。回想一下,pair表征中的每个 编码了序列中位置 和 之间的关系。虽然它并不直接编码 token 对之间的物理距离,让我们暂时设想它确实如此。如果我们想象每个 是两个氨基酸之间的距离,我们知道 =1 且 =1。根据三角形不等式, 不能大于. 知道两个距离让我们对第三个距离必须是什么有了大致的要求。三角形更新和三角形注意力的目标是,尝试将这些几何约束编码到模型中。
三角形不等式并不是在模型中强制执行的,而是通过确保每个位置 通过同时查看所有可能的三元组位置 (,,) 来更新,从而鼓励这种不等式。因此, 是基于所有其他原子 的 和 来更新的。因为 代表这些 token 之间的复杂物理关系,而不仅仅是它们之间的距离,这些关系可能是有方向的。所以对于 , 希望它与所有原子 的 和 保持一致性。如果将原子视为一个图,将 视为一个有向邻接矩阵,那么 AlphaFold3 将这些称为“outgoing edges”和“incoming edges”是有意义的。
考虑这个邻接矩阵的第 =0 行,假设我们想要更新 ,它已经用紫色突出显示。更新背后的思想是,如果知道了 0→1 和 2→1 之间的距离,那就给出了 0→2 可能是什么的一些约束(下图33)。同样,如果知道了 0→3 和 2→3 之间的距离,这也给出了 0→2 的一个约束。这将适用于所有原子 。
因此,在三角形更新和注意力机制中,AF3有效地检查了这个图中所有3个节点的所有有向路径。这也就是三角形,这就是它的名字的由来!
3.3.2 三角形更新
从图论的角度仔细研究了三角形操作后,从AF3可以看到这是如何通过张量操作实现的。
在出边(outgoing)更新中,pair表征中的每个位置 都会根据同一行 () 中的其他元素的加权组合独立更新,其中每个 的权重基于其出边三角形中的第三个元素 。
实际上,对 进行三次线性投影(inear projection),得到 a、b 和 g。 和 除了线性投影,还进行了sigmoid(下图35)。
为了更新 ,取 a 中第 行和 b 中第 行的逐元素乘积。然后对所有这些行求和(不同的 值),并用投影 g 进行门控(下图35)。
对于入边(incoming)更新,实际上执行相同的操作,但是将行与列翻转,因此为了更新 ,取同一列 () 中其他元素的加权和,其中每个 的权重基于其出边三角形中的第三个元素 () 。创建相同的线性投影后,取 a 中第 列和 b 中第 列的逐元素乘积,并对此矩阵的所有行求和。你会发现这些操作完全反映了上述图论邻接视角的描述。
3.3.3 三角形注意力
在两个三角形更新步骤之后,AF3还使用出边的三角形注意力和入边的三角形注意力来更新每个 。AF3论文将“出边”称为“围绕起始节点”的注意力,将“入边”称为“围绕结束节点”的注意力。
为了逐步理解三角形注意力,从典型的一维序列自注意力开始可能会有所帮助。回想一下,查询(queries)、键(keys)和值(values)都是原始一维序列的转换。一种称为轴向注意力的注意力变体,通过在二维矩阵的不同轴上独立应用一维自注意力来扩展这一点 (首先是行,然后是列)。三角形注意力将之前讨论的三角形原则添加到这一点上,通过结合所有原子 的 和 来更新 。
围绕起始节点的注意力。 具体来说,在“起始节点”的情况下,为了计算第 行的注意力分数(以确定 应受 影响的程度),像通常那样在 和 之间进行查询(queries)-键(keys)比较,然后根据 偏置注意力,如下图37所示。
围绕结束节点的注意力。 对于“结束节点”的情况,我们再次将行换成列。对于 ,键和值都将来自pair表征 的第 列,而偏置Bias将来自第 列。因此,当比较查询(queries) 与键(keys) 时,根据 来偏置那个注意力分数。然后,一旦我们对所有的 有了注意力分数,就使用来自第 列的值向量(下图38)。
3.3.4 带有对偏置的单一注意力 (Single Attention with Pair Bias)
现在已经通过这四个三角形步骤 (outgoing 和 incoming三角形更新,outgoing 和 incoming的三角形注意力)更新了pair表征, 接着将pair表征通过前文描述的 transition 转换块。最后,希望使用这个最新的pair表征 ,来更新模型的single表征 ,因此将使用基于对偏置的单一注意力(Single Attention with Pair Bias),如下图39所示。
这与Atom Transformer部分描述的“基于对偏置的单一注意力”基本相同,但这里是在token-level的。由于它在 token-level 上操作,它使用完全注意力。而没有使用Atom Transformer原子级别(atom-level)操作时那样,采用块状的稀疏模式的注意力策略。
AF3重复 Pairformer模块48次之后,最终得到的single表征称为 ,pair表征称为 。
4. 扩散模块
AlphaFold3模型采用扩散模块,进行结构预测,对应原文大致就是下图40的这些内容。
4.1 扩散的基础
经过表征学习阶段,提取到了精炼的表征,接着使用这些表征 和 来预测分子复合物的结构。AF3 相对 AF2 引入的变化之一是,整个结构预测是基于原子级别的扩散模型。
扩散模型的基本思想,是从真实数据的蛋白原子坐标开始,向数据中添加随机噪声,然后训练一个模型来预测添加了哪些噪声。
在扩散期间,在一系列 时间步长中,通过迭代向数据中添加噪声,创建每个数据点的 个变体。原始数据点称为 ,完全噪声化的版本称为 。
在训练期间,在时间步 ,模型被给予 ,模型预测在 和 之间添加了哪些噪声。根据预测的噪声与实际添加的噪声之间的差异,去训练模型拥有正确的去噪能力。
然后,在推理时从随机噪声开始,这相当于 。对于每一个时间步,预测模型认为已经添加了噪声,并预测噪声和移除噪声。经过预先指定的若干时间步后,最终得到一个完全“去噪”的预测数据点 ,它应该类似于训练数据集中的原始数据。
条件扩散允许模型根据某些输入对这些去噪的预测进行“条件化”。实际上,这意味着对于模型的每一步,它接收三个输入:
生成的当前噪声迭代 当前所处时间步的表征 期望的条件化信息 结果,扩散模型最终生成的不是一个随机的东西,而是类似于训练数据分布的蛋白结构,且应该特别匹配由当前条件向量所代表的样本。
在 AF3 中,模型学习去噪的数据是一个矩阵 (见下图41),这个矩阵包含蛋白结构中所有原子的 、、 坐标。在训练期间,向这些坐标添加高斯噪声,直到它们变成完全随机的值。然后在推理时,模型从随机坐标开始。
首先,在每一个时间步,随机旋转和平移模型预测的整个蛋白复合物结构。这种处理增强并教会了模型,蛋白复合物的任何旋转和平移都是等价的,它完全取代了 AF2 中更复杂的 IPA 模块。
然后,向坐标添加少量噪声,以鼓励生成更多样化的结构。
最后,使用扩散模块预测一个去噪的噪音。
4.2 扩散模块
本文将在下面更详细地介绍这个 AF3 的扩散模块(下图42)。
在每个去噪扩散步骤中,AF3 会对输入序列的多种表征进行条件化预测,这些表征有:
主干网络输出的表征,即 post-Pairformer 更新后的 和 ,现在称为 和 输入嵌入器创建的初始 atom-level 的single表征 ,初始的 token-level 的 single 表征 ,这些表征未通过主干网络 AF3 的扩散模块可以分解为:token → 原子 → token → 原子,更详细来讲,这4个步骤是:
准备 token-level 的条件张量 准备 atom-level 的条件张量,使用原子级注意力,并将它们聚合回 token-level 在 token-level 应用注意力机制,并将结果投影回 atom-level 在 atom-level 应用注意力机制以预测原子级别的噪声更新 4.2.1 准备token-level的条件张量
准备token-level的pair表征。为了初始化 token-level 条件表征,需要将 与相对位置编码连接起来,然后将这个更大的表示投影回较小的尺寸,并通过几个带有残差连接的转换块(下图43)。
准备token-level的single表征。同样地,对于 token-level的 single表征,将模型开始时创建的输入最初表征( )和我们当前的表征( )连接拼接起来,然后将其投影回原始大小(下图44)。然后根据当前的扩散时间步创建一个傅里叶嵌入(下图44,底部部分),将其添加到 single 表征中,并将该组合通过几个转换块(transition)。得到最终的条件化的 。
通过在此处的条件输入中包含扩散时间步,它确保了模型在进行去噪预测时知道扩散过程中的时间步,并因此预测出该时间步应该移除的噪声的正确规模。
4.2.2. 准备atom-level的条件张量
准备 atom-level 的条件张量,使用原子级注意力进行更新,然后将它们聚合回 token-level。
模型条件向量存储的是每个 token-level 的信息,但AF3 还想在 atom-level 运行注意力机制。为了解决这个问题,在此处 AF3 利用了模型前面的 atom-level 的single表征( 和 ),并根据当前的 token-level 的表征去更新它们(下图45、46),以创建 atom-level 的条件张量。
接下来,将原子的当前坐标 按数据的方差进行缩放,有效地创建了具有单位方差的“无维度”坐标,称为 。然后根据 更新 ,使得 现在知道了原子的当前位置。最后,使用原子变换器(Atom Transformer)更新 (该变换器也接受 pair 表征 作为输入),并将原子重新聚合 token-level 得到 (下图47),就像前文描述的那样。
在这一步的最后,模型输出的返回值有:
:结合了有关原子坐标信息更新后的 atom-level 的 single 表征(图47,中间) : 的 token-level 聚合形式,并捕获了坐标和序列信息(图47) :融合了主干single表征 的 atom-level 的 single表征(图45) :更新的条件化 atom-level 的pair表征(图46) 4.2.3. 在token-level应用注意力,并将结果投影回atom-level
在 token-level 应用注意力机制,并将结果投影回 atom-level。这一步的目标是,应用注意力机制来更新关于原子坐标和序列信息的 token-level表征,即应用注意力机制更新 。这里使用的是 Diffusion Transformer(图48),它基本上和前文2.4小节介绍的 Atom Transformer 一样(图14)。但 Diffusion Transformer 是针对 token-level 表征的,Atom Transformer 是针对 atom-level 表征的,所以在此处不再深入介绍 Diffusion Transformer。
4.2.4. 在atom-level应用注意力,去预测atom-level的噪声更新
现在,返回到原子空间。使用更新后的 (基于当前“中心原子”位置的 token-level 表征),来更新 (基于当前位置的所有原子的 atom-level 表征)。 这里使用的原子变换器(Atom Transformer)。
正如上一步所做的,将 token-level 表征广播,以匹配开始时的原子数量(有选择地复制代表多个原子的 token),并运行原子变换器(Atom Transformer)。
最重要的是,最后一个线性层将这个原子级别表征 映射回 。
这是至关重要的一个步骤:使用所有这些条件表征,生成所有原子的更新坐标 。
现在,因为是在“无维度”空间 中生成这些更新坐标的,仔细地将更新从 重新缩放到具有非单位方差的形态 ,并应用更新到 。
通过这些介绍,这里已经完成了对AlphaFold3主体架构的介绍!下文将介绍有关损失函数、置信度头和训练细节的一些额外知识。
5. 损失函数
5.1 损失函数和置信度头
损失函数是机器学习中用于衡量模型预测与实际观测之间差异的量化方法,它是一个非负可测函数,用于指导模型训练过程中参数的优化。在AlphaFold 3中,损失函数由3个部分组成,以加权和的形式组成总的损失函数。
:评估预测的距离直方图在 token-level 的准确度。 :评估 atom-level 预测的距离直方图的准确度,包括所有原子对之间的距离,并考虑邻近原子和参与蛋白质-配体键合的原子之间的距离。 :评估模型对结构预测准确性的自我评估,即模型对自己预测结构的准确度有多自信。 这里的 和 是准确度或精确度的损失,它一定是和分子的真实结构去计算的。而 是置信度Head,它是一个信心度指标。就好比你考试完,但还不知道标准答案,“感觉”自己能考多少分。
这3种损失函数共同作用,帮助 AlphaFold3 在训练过程中优化其参数,以更准确地预测蛋白质和其他生物分子的结构。通过这种方式,AlphaFold3能够提高其预测的准确性,并在必要时对模型的预测结果进行校正。
5.1.1
AF3 模型的输出是原子级的坐标,这些坐标可以很容易地用来创建原子级(atom-level)的 distogram。回想一下,最初的 distogram 是通过将原子之间的成对距离分箱来创建的(图11的位置有介绍)。
然而,这里的 损失是评估的 token-level 的 distogram。要获取 token 的 坐标,我们只需使用“中心原子”的坐标。预测的 distogram 然后通过与真实的 distogram 进行比较,计算交叉熵损失。
5.1.2
扩散损失 ,它本身是三个项的加权和,每一项都是根据原子位置计算的,另外还根据当前时间步添加的噪声量进行了缩放。
其中,,代表当前时间步采样的噪声水平; ,是数据的方差,它调节每个时间步的噪声量。扩散损失 的包含下面3个损失函数:
,是一种 distogram 损失,我们刚刚讨论过,但它是针对所有原子而不仅仅是针对“中心原子”(并且对DNA、RNA和配体原子进行了加权)。此外,它查看位置之间的均方误差,而不是将它们分箱到 distogram 中。对比 , 仅针对中心原子,进行了分箱,且是token-level的。
,旨在通过在预测和真实 distograms 的原子对差异上添加额外的 MSE 损失,确保蛋白质-小分子配体键的键长准确性,这些原子对是蛋白质-小分子配体键的一部分。所以 是仅针对蛋白-小分子配体体系设置的。在训练的不同阶段, 在初始阶段被设置为0,所以这个参数是在后面阶段才引入的。
(平滑局部距离差异测试),是 distogram 损失的另一种变体,试图捕捉局部距离的准确性。如果原子对的预测距离在原子对的真实距离的给定阈值内,则“通过测试”。为了使这个度量平滑且可微分,将预测和真实 distograms 之间的差异,通过一个以测试阈值为中心的 sigmoid 函数。可以将其视为生成一个概率(介于0和1之间),这个原子对通过了测试。模型取四个“测试”的平均值,阈值越来越紧(4, 2, 1和0.5 Å,图49)。使用这个损失鼓励模型减少未通过每个测试的概率。最后,为了使测试“局部化”,如果原子对的真实距离很大,忽略原子对的损失,因为只希望模型专注于准确预测一个原子与附近原子的距离。
举例来说,对于一个原子对 ,如果原子 是核苷酸的一部分,并且 和 之间的距离超过 30Å,AF3会忽略 和 的损失。如果 和 之间的距离超过 15Å,并且 不是核苷酸,而是蛋白质或配体的一部分,也忽略 和 的损失。因为 只关注局部距离的准确性。
5.1.3
置信度损失的目标不是提高结构的准确性,而是教导模型预测自己的准确性,给预测结构提供一个置信度打分。这个损失是4个项的加权和(下图50),每个项都对应一种评估预测结构质量的方法:
lDDT,原子级的“局部距离差异测试(local distance difference test)”,捕捉一个原子预测距离与附近原子的预期准确性。
PDE,token 之间预测的距离误差(Predicted distance error),捕捉所有 token 之间预测差异的准确性。
, 实验上解析的预测,模型预测哪些原子是实验上解析的。并非每个晶体结构中的每个原子都是实验上解析的,有些原子、氨基酸有时候在晶体结构中是缺失的。
PAE,预测的对齐误差(Predicted alignment error),是第 个 token 的预测位置和真实位置之间的误差。首先将预测的 token 和真实 token 旋转和平移到 token 的框架中。也就是说,如果暂时假设 token 完全在其真实位置,AF3 预测 token 与它应该在的位置有多接近,这是基于它与token 的关系。
为了得到以上每个指标的置信度损失,AF3预测这些误差指标的值,然后在预测结构上计算这些误差度量,损失基于这两个之间的差异。所以即使晶体结构确实不正确,且本身 PAE 就很高;如果预测的 PAE 也很高,但损失 此时就会很低。
这些置信度指标的预测,是在扩散过程的中间阶段生成的。在选定的扩散步骤 时,使用的预测坐标 用于更新表征学习主干中创建的 single 和 pair 表示(上图50)。然后,基于更新后的 pair 表征(对于 和 )或更新后的single 表征(对于 和 ${L_{\text{resolved}}$ )计算预测误差。然后,基于同样生成的原子坐标计算实际误差指标。如果感兴趣,下一段有详细的介绍。
虽然这些术语被包括在置信度 Head 损失中,但这些术语的梯度仅用于更新置信度预测 Head,不影响模型的其余部分。
5.2 置信度指标如何计算?
5.2.1 pLDDT
pLDDT,原子 的 LDDT 是按以下方式计算的。在当前预测的结构中,计算原子 与一组原子 之间的距离, 由 索引,并将此与真实的等效距离进行比较。要成为这个集合的一部分,原子 必须是高分子链的一部分,根据 m 所属的分子,与 的距离在 15 或 30 Å 之内,并且是标记的中心原子。然后我们计算四个二元距离测试,阈值越来越严格(4, 2, 1, 和 0.5 Å),并取平均通过率,并在 R 中的原子上求和。我们将这个百分比分到 0 到 1 之间的 50 个区间内。
在推理时,AF3 有一个 pLDDT 头。这个头取一个给定标记的单个表示,将其重复扩展到这个标记“附加”的所有原子上,并将所有这些原子级表示投影到我们的 pLDDT_l 的 50 个区间上。我们将这些视为 50 个“类别”的 logits,使用 softmax 转换为概率,并在区间上使用多类分类损失。
5.2.2 PAE
预测对齐误差(PAE):每个标记都被视为有一个框架,即由三个原子(称为 a、b、c)创建的 3D 坐标框架,这些原子涉及该标记。这三个原子中的 b 形成这个框架的原点。在每个标记只有一个原子“附加”的情况下,框架的中心原子是标记的单个原子,另外两个最近的相同实体(例如,相同的配体)的标记构成框架的基础。对于每一对标记(i,j),我们使用标记 j 的框架重新表达标记 i 的中心原子的预测坐标。我们对标记 i 的中心原子的真实坐标也做同样的处理。这些转换后的真实和预测坐标之间的欧几里得距离就是我们的对齐误差,分到 64 个区间内。我们从标记对 zi,j 的配对表示中预测这个对齐误差,将其投影到 64 维,我们将其视为 logits 并使用 softmax 转换为概率。我们用分类损失训练这个头,每个区间作为一个类别。更多细节请参阅这里。
5.2.3 PDE
第三,AF3 预测标记之间的距离误差(PDE)。真实的距离误差是通过计算每一对标记的中心原子之间的距离,并将其分到 64 个均匀大小的区间内,从 0 Å 到 32 Å 来计算的。预测的距离误差来自于将配对表示 zi,j 加上配对表示 zj,i 投影到 64 维,我们再次将其视为 logits,并再次使用 softmax 转换为概率。
最后,AF3 预测在真实结构中每个原子是否被实验解析。与 pLDDT 头类似,我们将 单个表示重复扩展到这个标记代表的原子数量,并投影到 2 维上,并使用二元分类损失。
6. 训练细节
现在架构已经介绍完毕,最后的部分是一些 AF3 模型额外的训练细节。
6.1 循环利用 (Recycling)
正如在 AF2 中引入的那样,AF3 的表征学习阶段也同样循环利用其权重;也就是说,模型并没有变得更深,而是通过重复使用权重,并将输入多次传递给模块,以不断改进 single 和 pair 表征。
在推理预测结构时,扩散模块本质上使用了循环利用,因为模型被训练以整合时间步信息,并在每个时间步使用相同的模型权重。
6.2 交叉蒸馏 (Cross-distillation)
AF3 使用了通过自身生成的数据(自蒸馏)以及通过 AF2 生成的数据,即交叉蒸馏的方式。具体来说,作者指出,AF3 通过切换到基于扩散的生成模块,模型不再能产生的“意大利面”正确区域(下图51),这些区域允许 AF2 的用户直观地识别低置信度和可能的无序区域。 仅仅通过观察基于扩散的生成结果,所有区域看起来都同样高度可信,这使得出现错误的幻觉结构(下图51,右边结构)。
为了解决这个问题,作者在 AF3 的训练数据中包括了来自 AF2 和 AF-Multimer 的生成数据,允许模型学习到无规则区域。当 AF2 对其预测不够自信时,它应该输出这些展开的区域,并“指导”AF3 也这样做。
6.3 裁剪和多阶段训练 (Cropping and Training Stages)
虽然模型的任何部分都没有对输入序列的长度有明确的限制,但随着序列长度的增加,内存和计算需求会显著增加。回想一下,有多个 操作。
因此,为了效率,蛋白质会被随机裁剪Cropping。正如在 AF-multimer 中介绍的一样,因为想要模拟多个链之间的相互作用,随机裁剪需要包含所有这些链。AF3使用了3种裁剪(Cropping)方法,这3种方法根据训练数据的不同比例使用。这三种 Cropping 策略为:
连续裁剪(Contiguous cropping):为每条链选择氨基酸的连续序列。 空间裁剪(Spatial cropping):基于与参考原子的距离选择氨基酸,通常这个原子是特定链或感兴趣的结合界面的一部分。 空间界面裁剪(Spatial interface cropping):与空间裁剪类似,但是基于与特定结合界面的原子的距离。 虽然在随机裁剪的384个序列上训练的模型可以应用于更长的序列,但为了提高模型处理这些序列的能力,AF3 会在更大的序列长度上进行迭代微调。数据集的混合和其他训练细节也会在每个训练阶段变化,如下图52所示。
6.4 冲突 (clashing)
作者指出,AF3 的损失函数中虽然不包括对重叠原子的冲突惩罚,即clashing。但 AF3 在对生成的结构进行排名时确实采用了冲突惩罚。 转为基于扩散的结构预测模块,意味着模型在理论上可以预测两个原子位于同一位置,但在训练好的 AF3 出现 clashing 的这种情况似乎很少见。
6.5 批次大小 (Batch sizes)
尽管扩散过程听起来相当复杂,但它在计算上仍然比模型的主干部分(表征学习阶段)要简单得多。因此,AF3的作者发现,从训练的角度来看,在主干之后扩大模型的批次大小(Batch sizes)更为高效。因此,对于每个输入结构,它都会通过嵌入和trunk,然后应用48个独立的数据增强版本的结构,这48个结构都会并行训练。
训练过程就是这样!还有一些其他的小细节,这里可能没有介绍到。如果你已经读到这里,剩下的内容您应该很容易从阅读AF3的补充材料中掌握。
7. 思考总结
在前文深入探讨了 AF3 架构以及与 AF2 的比较之后,有趣的是,AF3 作者们所做的改进,是很符合更广泛的机器学习趋势的。
7.1 检索增强
在 AlphaFold2 发布之前,模型推理时包含数据库的检索并不常见。AlphaFold 是利用了多序列比对 MSA和模板的检索。检索 MSA 的方法被用于蛋白质建模,但在深度学习的其他领域,这种检索的情况很少存在。
尽管 AF3 与 AF2 相比减少了对 MSA 的强调,MSA 不再在 Evoformer/Pairformer 的 48 个块中操作和更新。但 AF3 单独包含了模版模块和 MSA 模块。 即使其它蛋白质预测模型(如:ESMFold)已经放弃了检索,转而支持完全参数化的推理。
有趣的是,一些最大和最成功的深度学习模型现在通常在推理时包含类似的额外信息。虽然检索系统的细节并不总是公开的,但大型语言模型经常在推理时使用检索增强,例如:传统的网络搜索,以将模型导向相关信息(即使这些信息可能已经在其训练数据中)来指导推理。未来在推理时使用直接相关示例,这个方向发展将是一件有趣的事情。
7.2 对偏差注意力
对偏差注意力(Pair-Bias Attention)是 AF2 中的一个主要组件,在 AF3 中更是如此。
像自注意力(self-attention),查询(query)、键(key)和值(value)都来自同一来源。而这里,对偏差注意力(Pair-Bias Attention,从另一个来源添加了一个偏差项(Bias)到注意力机制中。这有效地充当了信息共享的轻量级版本,而没有使用完全的交叉注意力(cross attention)。对偏差注意力几乎出现在每个模块中。虽然这种类型的注意力现在在其他蛋白质结构预测模型中使用,但其他领域很少使用这种特殊类型的注意力。也许它之所以在这里效果良好,是因为 pair 表征自然就类似于自注意力,但对偏差注意力是纯自注意力或纯交叉注意力的一个有趣的替代方案。
7.3 自监督训练
像 ESM 这样的自监督训练(Self-supervised training)模型,通过使用自监督预训练替换检索 MSA,也能进行蛋白结构预测任务。AF2 模型有一个额外的任务,即预测 MSA 中的掩码 token,实现了类似的自监督。但在 AF3 中这一任务被移除,我们没有看到 AF3 的作者对此进行解释或评论。这样处理实际上减少了处理 MSA 的计算量。AF3 取消使用自监督学习来初始化 MSA 嵌入的3个可能原因是:
他们认为庞大的预训练阶段是计算资源的次优使用方式 他们尝试过,并发现包含一个小的 MSA 模块的性能优于预训练嵌入,并且值得增加推理成本 对于他们的混合原子-标记(atom-token)结构,使用预训练嵌入对氨基酸 token 和随机初始化嵌入对 DNA/RNA/配体进行训练可能不兼容或表现不佳。 7.4 损失设计
与 AF2 一样,AF3 继续使用 MSE(均方误差)和分箱分类损失的混合。分类部分很有趣,因为如果模型预测的 distogram(距离直方图)箱仅“差一”,它就不会因为接近而不是完全错误而获得“信用”。目前尚不清楚是什么影响了这一损失设计决策,但也许作者发现与使用几种不同的 MSE 损失相比,这种设计的梯度更稳定,而且每个原子的损失经过如此多的梯度步骤,以至于连续损失的额外信号可能并没有证明是有益的。
7.5 回收机制
AF3的架构的 Recycling 策略,让人联想到循环神经网络(例如:LSTMs)的设计元素,这些元素在传统的 Transformer 中通常不会见到:
广泛的门控机制(Gating):AF3在其架构中使用了门控机制,,来控制残差流中的信息流。这与 LSTM 或 GRU中 的门控更为相似,这不是标准 Transformer 的特点。
迭代循环使用权重:AF3多次应用相同的权重来逐步改进其预测。这个过程,包括 Recycling 和扩散模型的迭代步,类似于循环网络 RNN 如何在时间步骤上使用一组共享的权重处理序列数据。它与标准 transformer 不同,它通常在单次前向传递中做出预测。这种方法允许 AF3 迭代改进预测的蛋白质结构,而不会增加参数数量。
自适应计算(Adaptive):Recycling 也类似于在扩散中使用的迭代更新,并且与自适应计算时间(Adaptive compute time,ACT,https://arxiv.org/abs/1603.08983)的概念非常相近,自适应计算时间 ACT 最初是为了动态决定 RNNs 应该使用多少计算资源而引入的,最近在 Mixture-of-Depths(https://arxiv.org/pdf/2404.02258)中用于实现类似的目的,与 Transformer 的固定深度形成对比。理论上,这将允许模型对具有挑战性的输入,应对更多的处理。
在AF2的消融研究中表明,循环利用是非常重要的,但对于门控的重要性讨论不多。可以推测它有助于训练变得稳定。但门控在许多其它基于 Transformer 的架构中却不那么常见。
7.6 数据的交叉蒸馏(Cross-distillation)
使用 AF2 生成的数据,来重新引入其在低置信度区域的独特风格是一个非常有趣的做法。如果从中吸取教训,那可能是最实用的一点是:如果你的旧模型(如:AF2)在某个特定方面比你的新模型(如:AF3)做得更好,你可以尝试交叉蒸馏的策略,让新模型保留原有的能力!
小编总结
AlphaFold3 作为我们 AI4Protein 领域的基础模型,值得很多下游任务的模型进行借鉴和学习。小编资质愚钝,写到这里还是不太掌握 AlphaFold3 的细节。 所以必须尽量将本文写的清晰、准确,方便自己也方便大家,后面进行追溯和反复阅读。
两万字长文,难免有错误、遗漏、疏忽之处,十分欢迎读者在留言区批评指正,共同学习一起进步,笔芯。
参考资料
https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/