现在的大型语言模型,能力个个都挺强但,它们的计算能力都不太行:比如 GPT—3,没法做超过三位数的加法再比如它们当中的大多数都可以写代码,但是理解代码却很费劲 —— 稍微遇到点带循环的程序就GG
不过,来自 MIT 和谷歌的研究人员发现:不用修改模型的底层架构,只需训练它们学会像程序员 debug 时那样打断点,读代码的能力就唰唰唰地涨上去了。
将同样的思路用于大数加法,多项式计算,那就更不在话下了所以,语言模型的数学能力终于也要跟上了
教语言模型用打断点的方法做加法,读程序
前面说的打断点,其实指的是在计算步骤较多的程序中,让模型把每一步都编码为文本,并将它们记录到一个称为便签的暂存器中,或者叫草稿纸。
听起来是个笨方法,但正是这样才使得模型的计算过程变得清晰有条理,性能也就比以往直接计算的方式提升了很多。
具体操作也很简单。就比如在简单的加法计算中,计算29+57的方式就是像这样的:
其中 C 表示进位,#表注释先计算 9+7,进位 1,再计算 2+5 + 进位 1,最后得出 86
从上可以看出,这个训练示例由输入和目标组成训练时将两者都喂给模型,测试时,模型就能根据输入预测出正确的目标
而目标就是要发送到临时暂存器上的内容,通过关注其上下文就可以引用,实际操作中,还可以对草稿内容进行检查纠错。
显著提高语言模型的计算能力
研究人员选用了仅含解码器结构的 Transformer 语言模型来实验,其参数规模介于 200 万到 1370 亿之间原则上,任何序列模型都可以使用这个方法,包括编—解码器模型或循环网络等
首先,他们按这种打断点的方式训练语言模型进行 1—8 位数的整数加法训练包含 10 万个示例,并进行了 5000 步的微调,batch size 为 32然后分别用 1 万个数据来测试训练分布内的加法,1 千个数据来测试训练分布之外,也就是 9 位和 10 位数的加法
将结果分别与直接运算的语言模型进行比较,发现:即使超出临界模型大小,用了打断点法的模型也能够进行加法运算,而直接运算的基线模型就没法做到这一点。
而在分布外的任务中,直接运算的基线模型完全挂掉 ——没练过就不会做,而用了断点法的模型伴随着规模的增大 hold 住了 9—10 位数的加法。
好,大数加法搞定接下来上多项式他们生成了一个包含 1 万个多项式的训练数据集和 2000 个数据的测试集其中项数一般不超过 3 项,系数范围在—10 到 + 10 之间,结果在—1000 到 + 10000 之间
结果发现:无论是微调还是少样本训练之后,断点法的性能都优于直接预测。
最后就是读 Python 代码了。
代码的训练示例中,记录了正在执行的是哪行代码,以及此时各变量的值,用 json 格式表示。
此前的语言模型读代码的能力都表现不佳打断点的方式可以让它们一改常态么
首先,经过 200 个程序的测试发现,断点法整体执行精度更高与直接执行的模型相比,微调还可以将模型性能从 26.5% 提高到 41.5%
一个真实例子:
断点法经过 3 次 while 循环,最终给出了正确的变量值。
接着,他们又用包含了 1000 个程序的 MBPP 数据集进行训练和测试这些程序涉及多种数据类型的计算,包括整数,字符串,浮点数等,以及涉及循环,API 调用和递归等流程结构并添加训练数据之外的single line程序集和 CodeNet 程序集进行测试
结果发现,模型也可以很好地扩展。
当然,这个方法也有局限性:比如复杂的计算可能需要很长的暂存器,这可能需要进一步改进 Transformer 生成窗口的大小好在这也是 NLP 领域的一个活跃研究领域而在未来,他们可能会尝试在无监督情况下用强化学习让语言模型学会打断点
总之,语言模型的计算能力,读代码的能力会越来越强。
论文地址:
。