Transformers Can Navigate Mazes With Multi-Step Prediction
一种能提升Transformer复杂规划任务表现的训练目标:MLM-U
NTP(Next Token Prediction)的训练目标仍很难让Transformer在长期规划任务上有优异表现(模型往往会采取捷径,因为它没有真正从长远考虑路径规划),尤其在需要多步前瞻的迷宫导航(Maze Navigation)中,因为这种训练目标未提供提前预测多步路径或回顾已走路径的显式机制,针对此,作者考察了在训练Transformer时显式要求多步预测(包括向前和向后预测)是否能提升迷宫导航性能。作者采用另外一篇论文里提出的训练目标MLM-U(Uniform-Rate Masked Language Modeling)从头训练Transformer用于迷宫导航,有以下发现:
MLM-U显著提升transformer的迷宫导航能力,无论迷宫类型和大小,MLM-U的表现都优于相同参数规模下的NTP模型
MLM-U训练在数据与训练时间上更高效
随着迷宫难度变大,用MLM-U作为训练目标的Transformer在规模变大时获益更多
迷宫导航任务如figure1左侧所示,相比NTP,MLM-U训练时的context很不同且目标更多。此前已经有人提出用多个预测头预测未来的多个token,受启发于此作者研究了MLUM-U目标,该目标在训练时对输入序列的任意子集进行mask,并要求模型根据剩余上下文同时预测被mask掉的前后任意位置的token。通过这种方式,模型在训练中不断习惯在不固定的上下文范围内预测多个步骤,从而显式体现出多步预测能力。其数学形式如下equation2,其中μ是mask率,m_\mu X是被mask后的序列,m_\mu^C X是未被mask的上下文,即用于预测的条件序列。
为研究学习目标在迷宫导航中的作用,作者从头训练transformer,让其在不断增加复杂度的迷宫上生成最短路径。首先是迷宫的两种表示形式,见figure2,
图中间是DFS迷宫,其特点是从随机起点开始深度搜索,生成较长路径且无模糊性:最短路径也是唯一不重叠的路径。迷宫通过枚举边连接关系的图元组序列进行文本化表示,再将起点、目标和解决路径拼接到文本中。
图右边是A迷宫,使用(x, y)坐标对每个单元格进行token化表示,相较DFS的图元组表示,会产生更长的输入序列。A迷宫可能存在多个等长最短路径,其数据集中会选择一条作为真值路径。
作者首先评估NTP对于迷宫导航的效果,他们训练模型根据已有的迷宫解码序列来预测后续token,他们测试多种transformer规模与架构,包括decoder-only架构和一种encoder-decoder架构,另外还有一种他人利用A*搜索路径额外监督训练的NTP模型。而对于MLU-U目标训练,与NTP基线相同,MLM-U也在迷宫解决路径的文本表示上训练,并在推断时以相同方式从左到右生成。由于MLM-U的均匀mask率训练使模型适应不同长度的上下文和预测序列,推断时不会出现分布偏移问题。
实验结果如下,作者从迷宫复杂度、训练数据效率和计算效率三个维度将两种训练目标做了比较:
MLM-U在DFS迷宫中全面优于NTP:见上table1,模型为8M参数的Transformer,MLM-U能在20x20大小以下的迷宫中实现完美导航(100%正确率),在更复杂的30x30迷宫中,MLM-U的性能几乎是next token训练的3倍。
MLM-U数据效率更高且小型迷宫上更具计算效率:见下figure3,在5x5迷宫中,若两者都能解决任务,MLM-U仅需25k训练样本即可达到完美表现,而NTP需全部100k样本才能达到类似效果。见figure4,在训练收敛所需的迭代次数上,MLM-U比NTP快约2.17倍。
无论是否有A*搜索监督,MLM-U都优于next token预测:见下table2,一个接近2倍参数规模(15M参数)的NTP模型在30x30迷宫上仅达成13.3%的精度,而MLM-U(8M参数)能达到85.5%。
对比MLM-U与next token的训练动力学,发现NTP训练更易过拟合,而MLM-U泛化更好,另外MLM-U在更复杂迷宫中从更大模型规模中获益更多。
撰文:戴剑波;编辑:戴剑波
未经本公众号授权不得转载,欢迎转发。