LONGNET:将Transformer扩展到10亿个标记

在本篇文章中,我们将详细讨论一个近期发布的先进模型——“LongNet”。该模型由微软亚洲研究院研发,于大约两周前正式公布。LongNet基于Transformer模型构建,其核心理念在于拓展Transformer的应用规模。值得一提的是,研究团队成功地将其扩展至处理10亿个令牌的规模。对于熟悉语言模型的人来说,会明白序列长度对模型性能的影响,因为序列长度决定了在执行注意力机制时,能够关联的令牌数量,从而影响模型可以获取的上下文信息长度。例如,我们希望像GPT这样的模型能拥有更长的上下文,使得模型可以参考更久之前的单词来预测下一个令牌。而LongNet就成功地将这个能力扩展到了10亿个令牌。以下图为例,可以清晰看出,GPT的序列长度仅为512,而Power Transformer的序列长度可扩展至12、000、64、262、000、甚至1000万,然而LongNet将序列长度扩展至惊人的10亿个令牌。试想一下,我们可以将所有维基百科的文本信息输入到模型中,模型可以利用所有这些令牌进行注意力计算。接下来,让我们首先来了解一下LongNet的工作原理。

摘要

在大型语言模型时代,扩展序列长度已成为一个重要需求。然而,现有方法在处理计算复杂性或模型表达能力时遇到困难,导致最大序列长度受限。为了解决这个问题,我们引入了LONGNET,这是一种Transformer变体,可以将序列长度扩展到10亿个标记以上,同时不损害对较短序列的性能。具体而言,我们提出了扩张注意力(dilated attention),随着距离增加,它以指数级扩展注意力范围。LONGNET有显著的优势:1)它具有线性计算复杂性,并且序列中任意两个标记之间存在对数依赖关系;2)它可以作为分布式训练器用于非常长的序列;3)它的扩张注意力可以无缝替换标准注意力,并且可以与现有基于Transformer的优化方法无缝集成。实验结果表明,LONGNET在长序列建模和一般语言任务上表现出强大的性能。我们的工作为建模非常长的序列打开了新的可能性,例如将整个语料库甚至整个互联网视为一个序列。

LongNet的优点

LongNet具有多种优点。首先,其计算复杂度与序列长度呈线性关系,稍后将具体解释原因。其次,令牌之间存在对数依赖,也就是说,两个距离较远的令牌之间的依赖性较弱,而距离较近的令牌之间的依赖性较强。此外,它可在分布式网络中进行训练,这意味着我们可以利用分布式系统计算该注意力机制,如使用多个GPU或多台计算机。同时,LongNet可以作为标准注意力的替代品,这意味着如果我们已经有一个使用注意力机制的模型,我们只需将注意力机制替换为LongNet的机制,无需改变模型的其他部分,模型仍然能够像以前一样运行,但通过使用这种改进的注意力机制,可以处理更长的序列长度。

关于Transformer模型的梳理

自注意力机制,我们使用了被称为“Q、K、V”的矩阵。其中,“Q”矩阵代表查询,其规模为“序列长度乘以模型大小”,模型大小指的是每个词嵌入的向量表示。当我们计算查询与键(K)的乘积,或者查询与K的转置的乘积来产生此矩阵时,所需的操作次数是“序列长度的平方乘以模型大小”,因为我们需要为矩阵中的每个元素计算点积。这就是为什么自注意力的复杂度是“序列长度的平方乘以模型大小”。这个比较在相关论文中也有详细描述,常规的注意力复杂度是“序列长度的平方乘以模型大小”,然而LongNet这种新模型,其注意力机制复杂度仅为“序列长度乘以模型大小”,下文我将说明如何实现这种线性复杂度。

LongNet的注意力分配原理

LongNet的核心原理是,令牌间的注意力分配会随着它们之间距离的增加而呈指数级地减小。让我们参照图表来理解它的运作方式。在传统方式中,我们计算所有令牌与其他所有令牌之间的注意力,但LongNet并未如此操作。它采用了一种将序列切分为不同大小窗口的方法。首先,以4为窗口大小为例,这里的“N”是序列的令牌数,我们将其分成四个大小为4的段,并计算这个小窗口内的所有词与其他词之间的注意力。然后,我们对所有这些小段中的词执行同样的操作,接着使用更大的窗口,这次窗口大小为8。如此类推,直到覆盖整个序列长度,然后我们再以增加窗口大小的方式进行操作,同时我们也增加了跳过的令牌数,即“R”。例如,我们可以先计算大小为8的窗口,然后将“R”设为2,这意味着我们会跳过一个令牌,然后计算注意力,再跳过一个令牌,继续计算注意力。以这种方式,随着窗口大小和跳过的令牌数的增加,计算的复杂度变得更小,因为我们并不是计算每个令牌与所有其他令牌之间的注意力,而是只计算在有限范围内的注意力。这样,LongNet的注意力分配就遵循了对数依赖的原则,即,距离较远的令牌之间的依赖性较弱,而距离较近的令牌之间的依赖性较强。这是LongNet能在更大序列长度上进行工作的关键。
扩展注意力由一系列用于建模短程和长程依赖关系的注意力模式组成,注意力模式的数量可以根据序列长度进行扩展。在每个注意力模式中,查询向量和键向量之间的点积被分解为多个子点积,每个子点积仅涉及到一小部分的键向量。这种分解方式可以减少计算复杂度,同时也可以使模型更好地处理长序列。具体如下图所示:

扩张注意力还引入了“多头”机制,可以在不同的头之间分别计算注意力。每个头都有自己的偏移量,这样就可以在不同的位置上计算注意力,从而更好地捕捉序列中的信息。通过这种方式,扩张注意力可以更好地处理长序列,同时保持较短序列的性能。具体如下图所示:

计算复杂度的优化

现在我们来看看为什么LongNet的计算复杂度是线性的。在传统的自注意力机制中,我们需要执行序列长度平方次数的点积操作,而LongNet通过使用窗口和跳过的方式,将计算的复杂度降低到了线性。如果我们假设窗口大小是固定的,例如为4,然后“R”也是固定的,例如为2,那么计算复杂度将是“O(N)”。当然,在实际操作中,窗口大小和跳过的令牌数可能会根据实际情况进行调整,但是它们是常数,不随序列长度增加而增加。这就是为什么LongNet的计算复杂度是线性的。

Token扩展10亿+

分布式训练方法,利用LONGNET的线性计算复杂度,将序列维度分布式地进行训练。具体而言,算法首先将输入序列沿着序列维度进行切分,每个序列片段被分配到不同的设备上进行计算。然后,每个设备将序列片段投影为查询、键和值,并使用本地计算得到局部的注意力权重。对于超出本地设备序列长度的部分,键和值将被发送到其他设备上进行计算。最后,所有设备将局部的注意力权重进行汇总,得到全局的注意力权重,并使用全局的注意力权重计算每个标记的表示。具体如下图所示:

该算法可以在任意数量的设备上进行扩展,并且可以通过并行计算来加速训练过程。由于LONGNET具有线性计算复杂度,因此该算法可以有效地处理超长序列,而不会牺牲训练速度和模型性能。此外,该算法还支持标准Transformer的优化技术,例如内核融合、量化和分布式训练,从而使得LONGNET可以无缝地与现有的深度学习框架进行集成。

LongNet的应用前景

LongNet的发布为自然语言处理领域带来了诸多潜在的应用前景。首先,它可以应用于更长文本的生成任务,如生成长篇小说或长篇新闻报道。其次,它可以应用于更复杂的对话任务,因为在对话中,我们往往需要处理大量的历史信息和上下文。另外,它还可能在翻译任务中发挥更大的作用,因为翻译往往涉及到处理长句子或长段落的情况。总的来说,LongNet的发布为我们提供了处理更长文本的新工具和可能性。

结论

总的来说,LongNet是一个基于Transformer的模型,它成功地将自注意力机制扩展到了10亿个令牌,实现了处理更长文本的能力。它的优势包括计算复杂度是线性的、遵循对数依赖原则,以及可以在分布式系统上进行训练。通过LongNet,我们可以探索更多自然语言处理任务,并处理那些过去由于序列长度限制而难以处理的任务。这个新模型的发布为我们带来了更多可能性。

References