AlphaTensor的矩阵乘法算法自动化解读

文摘   科技   2022-10-07 21:27   上海  
本公众号的推送以互联网大数据技术为主,是《人工智能安全》《Python爬虫大数据采集与挖掘》等课程的配套号。内容涉及大数据采集、存储、分析挖掘的模型算法、隐私等技术问题,其特色为原创性、技术性。
Python爬虫大数据采集与挖掘(PPT、代码、视频)
人工智能安全(PPT、在线实验、视频)


矩阵乘法是一种经典计算,在大数据、人工智能算法中广泛使用。Deepmind团队2022.10.5(实际上投稿于一年前)在Nature上发表了AlphaTensor一种矩阵乘法算法自动发现的强化学习方法(Discovering faster matrix multiplication algorithms with reinforcement learning)。

我们先看看AlphaTensor做了什么?对于矩阵乘法(4*4*4)运算,目前最好的结果是Strassen于1969年所提出的算法,需要49次基本乘法。

然而AlphaTensor却找到了47次乘法的计算步骤,实现了50多年来的突破,也许有些数学家奋斗一生惨于一旦,公式推导过程并不重要,无需AI可解释性,此时人工智能表现出II型的适应性威胁(更多的,请阅读人工智能安全)。它生成的计算公式(部分)如下,可以很快转换成为成为代码,所谓算法自动化

这只是个例子,AlphaTensor可以计算任意大小矩阵的乘法。根据论文在4*4*4、5*5*5、3*4*5、4*4*5和4*5*5的矩阵乘法中生成的算法都获得了比SOTA更少的乘法次数,而其他大小矩阵乘法与当前最好结果一样。

AlphaTensor自动发现矩阵乘法算法的方法在于,

1. 问题的形式化

如图2*2*2的例子是标准的矩阵乘法,运算过程可以看作是定义在一个3D张量(4*4*4)上的一系列运算,如a1*b1+a2*b3=c1,图中深色元素为1,浅色为0。由此论文将矩阵乘法作为一个张量分解问题,分解公式如下,

其中,R是张量的秩,u  v  w都是向量。例如,上面的例子用u v w就可以表示Strassen算法,进行了7次乘法运算。张量中元素值为0,1或-1,

2. 求解

AlphaTensor通过强化学习来填充张量的元素值u v w。为此,将该问题看作是单人棋盘对弈,初始状态就是矩阵相乘所代表的张量,在每个步骤,玩家选择一组u v w,然后更新棋盘状态,重复这个过程直到0张量为止。在每个步骤,设置了奖励值以鼓励获取到0张量的最短路径。与AlphaZero类似,AlphaTensor也使用深度神经网络来指导蒙特卡洛搜索树进行路径规划。对弈完成后的棋盘作为神经网络的输入去学习获得网络参数的优化。AlphaTensor采用基于transformer的神经网络结构。


点击阅读原文,查看《人工智能安全》图书信息。

互联网大数据处理技术与应用
互联网大数据与安全相关的各种技术,包括爬虫采集提取、大数据语义、挖掘算法、大数据安全、人工智能安全、相关技术平台以及各种应用。同时也会分享相关技术研究和教学的心得体会。
 最新文章