Equitorch: 基于pyg的模块化等变图神经网络包

学术   2024-11-12 09:33   韩国  

我们发布了等变图神经网络包Equitorch,以模块化的形式集成了大多等变算子,基于pyg构建图神经网络,并提供了详细的文档、示例与教程。

  背景

近两年来,随着AI for Science任务的兴起,3D等变图神经网络(3D-Equivariant GNN)获得了广泛的关注,出现了eSCN、NeuqIP、SEGNN、Equiformer(V2)、DPA(-2)等模型,并在分子几何生成、能量预测、粒子运动求解等领域获得了优异表现。


3D旋转等变函数:φ(Rx)=Rφ(x),若将函数的输入进行旋转,则函数的输出也会相应旋转。
(图片源自https://arxiv.org/abs/2110.02905)

然而,即使等变图神经网络已取得相当的进展,我们发现现有的等变神经网络实现风格及方式相当多样,不同的工作中存储、表示等变特征的形式甚至可能不同。即使对已经相当了解等变相关知识的人而言,将不同工作中的模块对齐也要花费大量精力,这一定程度上增加了将不同工作中的技术进行迁移、组合的成本;而对具有纯粹AI背景,希望切入等变神经网络领域进行探索的研究者而言,复杂的数学概念与不统一的实现更是大幅增加了学习的成本。

在这种背景下,我们提出了一个模块化的、基于Pytorch-Geometric的包Equitorch,希望能够使研究者能够更加灵活地构建等变图神经网络。

  特点

  • 在Equitorch中,我们以一种模块化的方式集成了现在工作中使用的大多等变算子以及其它相关辅助函数,并统一这些操作的数据格式约定。
  • 对图神经网络相关的部分,如消息传递机制、图数据表示等,我们完全基于Pytorch-Geometric这一广泛使用的图神经网络框架,更方便传统图神经网络的研究者迁移。

  • 我们正逐渐提供详细的文档、示例与教程,尽可能清晰地展示Equitorch中操作的定义与用途。

统一数据格式:

在Equitorch中,我们对数据的维度含义进行了约定:数据的第一个维度总是样本维度,即节点指标或边指标,最后一个维度总为特征维度,所有几何相关的维度均在中间。这里的几何维度可以是笛卡尔坐标x、y、z,等变特征的度(l)与阶(m),或是球面角(θ与φ)的网格坐标。

如下图即展示了旋转矩阵(左)与等变特征(右)的数据排布。

左边存储旋转矩阵的张量维度为4×3×3×4,右边存储等变特征的张量维度为4×9×4。

当前实现的操作:

在Equitorch中,我们现在主要实现的操作有:

  • 神经网络模块equitorch.nn

    • 等变线性变换

    • (加权)等变张量积

    • 等变激活函数

    • 等变层归一化

    • 基础注意力操作

    • (径向、角度与球谐)基展开

    • 软截断操作

  • 数学函数equitorch.math

    • 球谐张量、球谐函数、Wigner D矩阵的相关操作

    • 球谐变换与逆球谐变换

  • 数据变换equitorch.transform(基于Pytorch-Geometric的Data数据类型)

    • 通过点坐标构建几何图、获取边的方向向量、长度嵌入、球谐嵌入等

  • 其他通用功能性操作equitorch.utils

我们实现了当下等变神经网络中大部分基础操作,基于这些操作,可以十分灵活地搭建各种等变图神经网络架构。

文档与教程:

在实现操作之外,为明确Equitorch中使用的数学约定与操作定义,我们同时也编写了比较详细的文档,希望可以方便使用者对操作的理解。


TensorProduct张量积模块的文档

另外,我们也提供了面向操作的教程,尽可能针对操作规则进行介绍,减少复杂的数学背景引入,以方便使用者更快地上手等变神经网络。目前教程包括了对球张量、张量积、SO(3)等变线性操作、以及搭建SO(3)等变图神经网络的部分。



等变性介绍


张量积介绍



张量积规则说明



SO(3)线性层的等变性说明

目前,我们只实现了比较基础的操作,在将来我们还会基于这些操作逐渐实现更多的模型。我们已经将Equitorch上传到pip欢迎研究或希望研究等变神经网络各位老师与同学尝试我们的工作,并对我们的工作提出宝贵的意见和建议。

项目地址: https://github.com/GTML-LAB/Equitorch

文档地址(点击“阅读原文”跳转): https://equitorch.readthedocs.io/en/latest/ 

DrugAI
关注人工智能与化学、生物、药学和医学的交叉领域进展,提供“原创、专业、实例”的解读分享。
 最新文章