橦言无忌

一个不想改变世界的程序媛

Convexifying Transformers

前言

文章:Convexifying Transformers: Improving optimization and understanding of transformer networks

essay link

凸优化的角度理解和优化Transformer网络~

摘要

了解Transformer网络成功背后的基本机制仍然是深度学习中一个悬而未决的问题,尽管它们的出色表现主要归功于自我注意机制,但文献仍然缺乏对这些网络的可靠分析和对它们所学函数的解释。为此,我们研究了注意力Transformer网络的训练问题,并引入了一种新颖的凸分析方法来提高对这些网络的理解和优化。特别是,我们首先引入了自注意力机制的凸替代方案,并用我们的凸注意力重新表述了Transformer网络的正则化训练问题。然后,我们将重构为一个可解释且更易于优化的凸优化问题。此外,作为我们凸分析的副产品,我们揭示了一种隐式正则化机制,它促进了token之间的稀疏性。因此,我们不仅改进了注意力或称Transformer网络的优化,而且还提供了对它们学到的函数的理论理解,并且通过几个数值实验证明了我们理论的有效性。

1,介绍

vaswani2017attention 提出的 Transformer 网络已成为各种任务中的主导架构,尤其是自然语言处理 (NLP),因为它们具有非凡的泛化特性和从海量数据中学习的高能力。 尽管有大量关于Transformer网络有效性的经验证据,但由于其高度非线性和非凸结构,揭示其成功背后的潜在理论原因仍然是一个悬而未决的研究问题。

大量研究侧重于通过实证研究分析Transformer网络的某些组件,例如 liu2021analyzingattention 等研究了注意力机制对 transformer 网络的影响。尽管这些研究一致认为注意力是 Transformer 的重要组成部分,但它们也提出了一些有关可解释性和优化的问题。特别是,voita2019analyzing 证明可以删除大多数注意力头而不影响网络性能,这是网络中大量冗余的一个指标。 attentionacrossNLP 提供了一组经验证据表明某些 NLP 任务可能不需要注意力。此外,2021dong_attention 透露,虽然注意力是Transformer网络的核心,但在没有全连接 (FCN) 层和跳跃连接的情况下训练注意力网络极具挑战性,因为没有它们,网络输出会迅速退化。类似地,takase2022layer 讨论了层归一化和跳跃连接对Transformer网络的重要性,因此即使改变它们的位置也可能显着影响Transformer网络的性能。然而,仍然缺乏对这些问题背后的潜在因素的可靠理论分析,这可能是由于Transformer网络的高度复杂和非凸结构。

一系列论文还侧重于设计自注意机制的新替代方案,这些替代方案表现相似,并可能为整体模型提供进一步的解释。一组工作利用基于多层感知器的架构,如tolstikhin2021mlp等,而另一组论文提出基于傅里叶的模型如lee2021fnet等,21adaptive等还提出用矩阵分解 geng2021attention 代替 self-attention 机制。尽管这些工作成功地应用于某些应用,但它们缺乏从优化角度进行扎实的理论分析和理解。最近,sahiner2022convex 尝试通过完全改变自注意力机制的结构并移除 FC 层,通过凸对偶分析 transformer 网络。即使那样,它们也未能为Transformer提供可靠的实际意义,因为它们的公式极具挑战性且在实践中难以解决。

最近,另一项研究侧重于理解 transformer 网络训练过程中出现的结构和模式 如power2022grokking等。 特别是,grokking 现象首先由 power2022grokking 在特定算法任务(例如模除法运算)中观察到。 具体来说,grokking 指的是验证或测试准确性突然过渡到完美泛化,并且这种泛化发生在完美训练准确性点之后。 这个有趣的行为与深度学习模型训练中早期停止的常见做法相矛盾,并且肯定需要进一步了解为什么会出现这种现象。

为了解决与标准Transformer网络相关的问题,在本文中,我们开发了一个凸优化视角来训练、分析和理解Transformer网络。 特别是,我们首先提出了自注意力机制的凸替代方案,然后在结果模型上开发了我们的凸分析框架,如图1所示。

1.1 贡献

本文贡献如下:

  • 我们提出了标准自注意力机制的替代公式,并用它研究了注意力Transformer网络的正则化训练问题。
  • 如图1所示,我们使用提出的注意层凸化了注意Transformer网络的正则化训练问题,因此能够找到全局最优解而不需要任何非凸优化启发式,例如层归一化和跳跃连接。
  • 我们还将我们的凸分析框架应用于各种架构,例如,有或没有 FCN 层的网络。因此,我们能够解释在整个训练过程中每个组件对学习模型的影响。
  • 我们揭示了一种由注意力机制引起的隐式正则化机制,之后进一步将这种正则化描述为跨token的稀疏性诱导因素。
  • 我们通过各种实验结果证明了凸重构的有效性。我们还表明,我们的重新表述显着减轻了最近论文 power2022grokking等,其研究中指出的 grokking 现象。

1.2 记号

我们分别使用小写和大写粗体字母表示向量和矩阵,用下标表示向量或矩阵的某个列元素。 例如,$w_{jk}$ 表示矩阵 $\matrix{W}$ 的第 $jk$ 项,用 $\mathbf{I}_k$ 表示大小为 $k \times k$ 的单位矩阵,使用 $\mathbf{0}$(或 $\mathbf{1}$)表示具有适当尺寸的零(或1)向量或者矩阵,还使用 $[n]$ 表示范围从 $1$ 到 $n$ 的整数集,将 Euclidean 和 Frobenius 范数分别表示为 $|\cdot|_2$ 和 $|\cdot |_{F}$,还使用 $\mathbb{1}[x\geq0]$ 来表示 0-1 值指示函数,在表1中提供了在整篇论文中使用的更多符号。

2,Transformer网络

给定一个数据样本(或句子)$\mathbf{X} \in \mathbb{R}^{h \times d}$ 作为具有嵌入维度 $d$ 的 $h$ token序列,我们将键(key)、查询(query)和值(value)矩阵定义为:

它们是自注意力机制的主要组成部分。 然后,一个基本上是自注意力堆叠的单个Transformer块、残差连接、层归一化和逐点前馈连接可以表示如下:

其中 $\sigma(\cdot)$ 表示 FCN 层的激活函数,尽管跳跃连接、层归一化和 FCN 在Transformer块中也起着至关重要的作用,但这些网络的成功主要归功于自注意力部分,表示为 $\mathbf{A}_o$ 。 因此,在下一节中,我们首先研究简化的transformer网络的训练问题,网络输出直接为 $\mathbf{A}_o$。 然后,我们将推导扩展到具有 FCN 层的Transformer网络。

3,仅注意力网络

我们首先考虑一个简化的Transformer网络,它只有一个自注意层,将输入序列 $\mathbf{X}\in \mathbb{R}^{n \times d}$ 映射到 $c$ 维输出序列 $\hat{\mathbf{Y}} \in \mathbb{R}^{n \times c}$ ,即:

我们也称模型 \eqref{eq2} 为仅注意力网络。这是一个有意义的模型,已应用于各种任务,包括机器翻译、语言建模、图像字幕和对象识别。

接下来我们考虑一个带有任意凸损失函数的标准回归框架。 给定训练集 $\{\mathbf{X}_i, \mathbf{Y}_i\}_{i=1}^N$,其中 $\mathbf{X}_i \in \mathbb{R}^{n \times d}$ 和 $\mathbf{Y}_i \in \mathbb{R}^{n \times c}$ 分别表示输入序列和标签输出,\eqref{eq2} 中仅注意力网络的权重衰减正则化训练问题如下:

其中 $\mathcal{L}(\cdot)$ 是任意凸损失函数,包括平方损失和交叉熵,$\beta >0$ 是正则化系数。

尽管 \eqref{eq2} 中的注意力模型在各种 NLP 任务中非常强大,例如,自然语言推理、神经机器翻译和文本分类,\eqref{eq3} 中相应的训练问题是一项极具挑战性的优化任务,需要对各种非凸优化启发式进行充分训练。 为了解决这些问题,在接下来的部分中,我们首先通过用替代凸层替换注意力部分来重新制定训练问题,然后将重新制定的训练问题转换为可解释的凸优化问题,从而实现全局优化网络参数。

3.1 凸注意力层

我们首先注意到,由于 $\text{softmax}(\cdot)$ 操作是高度非线性和非凸的,因此 \eqref{eq3} 中的训练问题是一个具有挑战性的非凸优化问题。 因此,人们可能无法充分训练注意力网络并在训练结束时获得微不足道的模型。 例如,dong_attention 表明注意力网络在整个训练过程中可能会退化,并且输出会收敛到秩为 1 的矩阵。 因此,他们无法学习基础任务。

为了避免与 \eqref{eq2} 中的非凸公式相关的问题,我们首先用更简单但有效的替代方法替换 $\text{softmax}$ 操作。 特别是,由于 $\text{softmax}$ 将其输入矩阵的行转换为概率分布,因此可以将其放宽为具有单位单纯形约束的线性运算,如下所示

其中 $\Delta := \{\mathbf{W} \in \mathbb{R}^{n \times n}: \mathbf{w}_i\geq 0, \matrix{1}^\top\mathbf{w}_i=1 , \forall i \in [ n]\}$ 表示约束的凸集,也称为单位单纯形约束。 因此,我们在不扰乱其结构的情况下简化和凸化了注意力机制。 基于这一观察,\eqref{eq3} 可以重新表述如下:

请注意,上面的模型使用单个头部注意力模型,因此,由于其表达能力不足,可能不具有实际意义。 因此,我们在 \eqref{eq4} 中引入 head 的概念如下:

现在,我们已准备好将凸分析工具应用于 \eqref{eq5},详见下一节。

3.2 凸优化应用到仅注意力网络

作为热身,让我们考虑标量输出预测问题,目标是一维的,即 $y_i \in \mathbb{R}$。 然后,\eqref{eq5} 简化为以下优化问题

接下来,我们首先在参数 $\mathbf{w}_{2j}$ 和 $w_{3j}$ 之间应用重缩放,使得 \eqref{eq6} 可以描述为 $\ell_1$ 正则优化问题。

引理1
\eqref{eq6} 中的问题等同于下面的 $\ell_1$ 正则化训练问题:

基于引理1中的等价公式,下一个定理引入了等价于 \eqref{eq6} 的凸优化问题。

定理1
非凸优化问题 \eqref{eq6} 可以等效地转换为以下凸优化问题:

请注意,\eqref{eq8} 中的等效凸模型需要单个参数矩阵 $\matrix{Z} \in \mathbb{R}^{n \times d}$,其中每一行是相应token的注意力分数。我们还注意到 \eqref{eq8} 中的正则化,即参数矩阵 $\matrix{Z}$ 行的 $\ell_2$ 范数之和,是一种特定类型的正则化,也称为组 $\ell_1$ 或 Lasso,由 bakin1999adaptive 引入并可促进跨参数的组稀疏性。在我们的例子中,组稀疏度跨越token索引 $k$,因此,可以将 \eqref{eq8} 中的模型解释为稀疏线性模型,其中稀疏性跨token。换句话说,\eqref{eq8} 可以解释为一个模型,它试图使用尽可能少的token来拟合训练标签 $\{y_i\}_{i=1}^N$。

与 \eqref{eq6} 中表示为 $\mathbf{w}_{1j} \in \Delta$的非负注意力得分不同,凸参数 $\matrix{Z} \in \mathbb{R}^{ n \times d}$ 不需要任何约束 . 因此,可以直接应用标准训练算法,如 SGD 和 Adam 来训练凸问题 \eqref{eq8}。 此外,可以从 \eqref{eq8} 的解中恢复 \eqref{eq6} 的一组最佳参数,如以下结果所示。

命题1
在求解 \eqref{eq8} 中的凸优化问题后,可以恢复 \eqref{eq6} 中的非凸优化问题的最优解,表示为 $\{\mathbf{w}_{1j}^,\mathbf{w}_{2j}^,w_{3j}^*\}_{j=1}^h$,如下:

其中 $\mathbf{e}_j \in \mathbb{R}^{n}$ 是第 $j$ 个普通基向量, $\mathbf{z}_j \in \mathbb{R}^{d}$ 是 $\matrix{Z}$矩阵的第 $j$ 行,我们假设由于 \eqref{eq8} 中的稀疏诱导正则化,$\matrix{Z}$ 的 $n$ 行中有 $h$ 非零行。

命题1证明了 \eqref{eq6} 中的非凸公式的参数与 \eqref{eq8} 中凸公式的参数之间存在一对一的映射关系。 因此,无需解决具有挑战性的非凸优化问题 \eqref{eq6} ,该问题还需要对多种启发式优化才能进行充分训练。 相反,可以解决凸问题 \eqref{eq8},然后使用命题1中的映射来获得 \eqref{eq6} 的最优解。

3.3 拓展到多维输出

在上一节中,我们考虑了目标变量为标量的问题,即 $y_i \in \mathbb{R}$。 然而,对于某些问题,例如多类分类,目标变量可以是多维的,因此,我们现在将分析扩展到多个向量输出的问题,如下所示

其中 $\mathbf{y}_i \in \mathbb{R}^{c}$ 和 $c$ 表示输出类别的数量。 请注意,这里我们在 $\mathbf{w}_{3j}$ 上应用 $\ell_1^2$范数,但这不会影响网络在实践中的性能。 然后,按照相同的推导在下一个结果中产生凸规划。

定理2
非凸优化问题 \eqref{eq9} 等价于以下凸优化问题

定理2表明等效的凸模型在输出索引 $l$ 上变得可分离,即,而不是 \eqref{eq8} 中的单个参数矩阵,这里我们有 $c$ 个参数矩阵由于,因为在非凸模型 \eqref{eq9} 中有 $c$ 个输出(详见表2)。 这也说明了网络中输出的数量直接控制了等效凸公式的超参数化水平。

3.4 带FCN层的注意力网络

尽管 \eqref{eq5} 中的模型在各种应用中表现出有趣的特性,但它基本上是token矩阵 $\mathbf{X}$ 的线性函数。 因此,它很可能会遇到性能不足的问题,尤其是对于 NLP 中的一些具有挑战性的问题。 一系列论文 dong_attention 等也通过广泛的经验证据证实了 FCN 的重要性。 因此,在本节中,我们在 \eqref{eq5} 中将一个 FCN 层添加到我们的注意力模型中,并为这个新模型推导出一个等效的凸公式。

在这里,我们考虑以下优化问题

其中 $\sigma(\cdot)$ 是激活函数。

定理3
具有gated ReLU 激活的非凸优化问题 \eqref{eq11} 等价于以下凸优化问题:

其中 $1_{ij} := 1\{\mathbf{u}_{1j}^\top \mathbf{X}_i \mathbf{u}_{2j} \geq 0\}$ 表示gated ReLU 激活的指示函数,这里 $\{\mathbf{u}_{1j},\mathbf{u}_{2j}\}_{j=1}^h$ 是固定向量,可以随机选择。

定理3意味着引入激活函数会进一步增加等效凸公式的超参数化水平。 准确地说,\eqref{eq12} 的参数比 \eqref{eq10} 多 $h$ 倍,如表2所示。

4,数值实验

在本节中,我们将展示实验结果,以证实我们在前几节中的理论。

BERT中的师生设置
我们首先考虑在 Hugging Face 库中使用预训练 BERT 模型的师生设置,即 bert-base-uncased。特别是,我们从 glue 数据集的 mrpc 子集中获取数据,传递给预训练的 BERT 模型,并将输入和输出激活保存在特定层中。然后,我们训练仅注意模型,即标准非凸自注意 \eqref{eq3}、替代非凸注意 \eqref{eq9} 和凸 \eqref{eq10},从头开始分别前后使用这些激活函数,用前面得到的数据作为我们的训练数据集。

本节中的所有实验都是使用 Google Colab 上的单个 GPU 执行的,还使用相同的正则化系数 $\beta$ 和优化器,即 Adam,并通过对两种算法的验证数据集执行网格搜索来调整学习率和正则化系数。但是,请注意,我们没有对所有实验中的凸模型使用任何非凸优化启发式方法,例如层归一化和跳过连接。在图2中,我们使用从预训练 BERT 模型的第六层提取的数据绘制了目标值(即训练损失 + 正则化项)和测试损失,以秒为时间单位。我们观察到,凸训练方法实现的目标值几乎比标准的非凸训练小一个数量级,后者可能停留在局部最小值,这种训练也能增强泛化能力,即我们的凸训练方法比标准的非凸训练获得更低的测试损失。为了理解每个模型学习的函数,我们还分析了图3中的注意力图。在这里,标准的非凸训练无法学习底层模型并输出跨token的统一注意力图。然而,凸训练输出了一个与地面真实注意力图非常相似的注意力图,因此我们成功地学习了训练数据中的结构。因此,这些实验清楚地说明了凸训练方法在训练和测试中的有效性。

算法数据集和 Grokking
受 power2022grokking 中观察到的 grokking 现象的启发,我们接下来使用 \eqref{eq1} 中的自我注意机制,在算法数据集上,验证凸训练方法对标准Transformer网络的有效性。特别是,我们使用在 power2022grokking 中相同的设置,并使用 $\mod 97$ 和 $\mod 15$ 评估模块化除法运算的性能,一直训练框架直至达到 $99\%$ 的测试精度。

在图4中,我们首先复制了 power2022grokking 中的结果,并确认这里确实出现了 grokking 现象,即非凸曲线(图4a中紫色)在 $10^3$ 左右时达到 $100 \%$ 的训练精度,但是图4b中需要超过 $10^5$ 次迭代才能达到完美泛化。我们还比较了非凸和凸训练方法,凸训练方法比图 4b 中的非凸训练方法收敛到完美泛化精度要快 10 倍。此外,凸模型在图 4c 中产生的测试损失也显着降低,这意味着它对测试预测具有更高的置信度,因此比标准非凸训练更稳健。

我们注意到在上一节中,我们理论上只分析了单个注意力Transformer块。然而,由于深度或层数(表示为$L$)的良性影响已经在深度学习文献中得到了经验证明,我们还建议将我们的凸模型扩展到更深的设置。我们在 \eqref{eq12} 中堆叠凸Transformer层以获得任意深度的网络。在图5 中,我们比较了双层Transformer网络与一层网络的性能。在这里,我们观察到虽然增加一层可以显着改善凸模型,尤其是在优化速度方面,但它无法对非凸模型产生任何明显的差异。此外,我们在 $\mod 15$ 操作上运行算法,由于样本数量较少,这基本上是更具挑战性的任务。在这种情况下,如图6 所示,单层模型无法完美地学习底层任务,但我们的凸模型在测试精度和测试损失方面明显更好。通过将层数增加到四层,我们使两个模型都能达到完美的泛化精度。我们的深度模型比非凸模型更快地达到这个水平并且产生更低的测试损失。

接下来,我们根据经验分析我们的凸模型和标准非凸模型上的 grokking 现象。 为此,我们绘制了图7a 中每个实验达到 $99\%$ 测试准确度所需的迭代次数。 请注意,这里我们不包括 $\mod 15$ 案例的单层模型结果,因为在该案例中两个模型都未能实现完美泛化。 图7a 清楚地表明,我们的凸训练方法比标准非凸训练更快地收敛到 $99\%$ 的准确度水平。 因此,我们还减轻了 grokking 现象的影响,如图7b 所示,我们根据迭代次数量化了 grokking 的数量。 基于这个实验,我们还推测 grokking 现象主要归因于标准Transformer模型的高度非线性和非凸结构。

5,结论

在本文中,我们研究了注意力Transformer网络的正则化训练问题,并开发了一个凸分析框架来训练这些网络。 特别是,我们首先提出了自注意力机制的凸替代方案,然后将这种替代注意力机制的训练问题重新表述为凸优化问题。 由于我们的凸重构,我们全局优化网络参数而不需要任何类型的非凸优化启发式。 此外,我们重构的训练框架学到的函数是透明和可解释的。 更重要的是,重构的问题揭示了数据中跨token的稀疏性诱导正则化机制,这也更清楚地说明了结果函数的结构及其泛化属性。 然后,我们通过几个数值实验,凭经验验证了我们的凸训练方法相对于标准非凸训练的有效性。

我们还注意到,通过凸优化理论的视角分析Transformer网络是极其重要的,因为它可能会大大改善对这些网络的理解和优化。 然而,由于网络模型固有的非凸结构,这也非常具有挑战性。 据我们所知,本文是朝着这个方向迈出的第一步,因此存在一些局限性,希望在未来的工作中消除这些局限性。 具体来说,在本文中,我们主要关注凸分析的理论方面,并在一些小规模问题实例上对理论进行了实证验证。 我们希望后续论文能对我们的理论进行全面、大规模的实证验证。

// 代码折叠