消息传递等变图神经网络-学习笔记01

学术   2024-12-21 00:36   中国台湾  


这是【未来已来】系列学习笔记,为将来会发生的事情而学习,做好准备。目前神经网络势函数发展的技术路线非常多,尤其是基于神经网络势函数的大模型,会是未来AI for Science领域最重要的发展方向,笔者认为最有希望取得决定性突破的技术路线是消息传递等变图神经网络(equivariant message-passing neural networks。这个学习笔记不保证内容的正确性,也不会过多介绍背景,仅仅是一个学习记录,大概会写十几篇笔记,从算法到代码实现,一点点来记录学习过程。


所谓的消息传递等变图神经网络,大致可以分成两个独立的技术,一个是消息传递神经网络(MPNN),一个是等变表示(equivariant representations)。根据发展的时间线看,MPNN是2017年提出来的,equivariant 是2020年以后了。所以,就按照时间线看看 图神经网络怎么消息传递的。

图(Graph)由节点(Node)和边(Edge)组成,用G=(V,E)表示,其中V 是节点集合,E 是边集合。每个节点和边可以携带特定的特征信息,分别用hv 和euv 表示。如下图:

很明显,这样图表示非常适合应用到分子/材料体系的表示,原子是node,原子之间的相互作用是edge:

MPNN的核心思想是通过边传递信息,使节点能够融合邻居的信息,从而更新自身的表示。【注意:这里的关键就是怎么更新自身的representation】以最原始的,MPNN框架为例子:

Message Passing(消息传递阶段)。在图中,每个节点从其邻居节点收集信息(称为消息),然后汇总这些信息来更新自身的表示。对于节点v,它接收邻居节点u 的消息:

可以分成3个部分。

(1)消息生成Mt。

(2)消息聚合

(3)节点状态更新

比如PaiNN网络:

消息生成Mt说白了就是把节点v周围的信息,揉搓到一起,啥信息呢,比如:

这里有一点要注意,最早2017年Justin Gilmer的设想是把元素的本征性质,比如杂化类型/是否芳香性/结合几个H原子等信息都引入作为节点信息,把化学键(单键/双键/三键)作为边的信息。实际后面做力场的时候根本不用这么多信息,直接就元素编号就可以了,比如周期表119个元素,直接构建一个119维向量,来表示不同的元素即可。边的信息也不需要预设化学键类型,这样在应用的时候更方便。

为了方便理解,可以用个CH4的例子,简单理解:

边的话可以用边列表表示法,或者稀疏邻接矩阵表示法。

边列矩阵为例子,每条边可以用一个向量表示键的类型,例如单键为[1]

稀疏矩阵为例:

那么具体网络构建的时候还是要比这个复杂一些的,后面再详细讨论。尤其是向量生成函数的处理方式可以很多样,然后再嵌套个MLP多层感知肌,套上激活函数。

消息聚合对所有邻居节点  的消息进行聚合,得到更新节点 表示的中间值。聚合的方式非常多,这也是不同网络里可以调节的东西,比如简单求和/均值/最大值:

真实情况比如paiNN里,还要先split把message分成三份,再分别处理

节点状态更新:通过更新函数 Ut,将聚合后的消息 mv(t+1)和当前节点状态 hv(t) 合并,得到新的节点状态 hv(t+1)

最简单的方式当然是直接覆盖,或者加上权重做非线性变换。但是这样会引入很多不稳定的因素,使得网络训练过程中不稳定不收敛,爆炸。当然可以在这个传递的过程中可以引入ResNet

当然消息传递可以进行多轮,每次称为一个1-hop。在torch_geometric就有MessagePassing框架可供使用。

import torchfrom torch_geometric.nn import MessagePassingfrom torch_geometric.utils import add_self_loops

学术之友
\x26quot;学术之友\x26quot;旨在建立一个综合的学术交流平台。主要内容包括:分享科研资讯,总结学术干货,发布科研招聘等。让我们携起手来共同学习,一起进步!
 最新文章