加州大学伯克利分校等机构联合研究成果,通过统计建模实现线性复杂度,为长序列任务带来曙光。
旧金山 – Transformer 架构在人工智能领域取得了显著成就,尤其在计算机视觉、自然语言处理和长序列任务中表现出色。然而,其核心的自注意力机制面临着计算复杂度随输入 token 数量呈二次方增长的挑战,这限制了其在更长序列和更大模型上的应用。近日,来自加州大学伯克利分校、宾夕法尼亚大学、密歇根大学、清华大学、忆生科技、香港大学、约翰·霍普金斯大学等机构的研究团队,共同推出了一种全新的注意力机制——Token Statistics Transformer (ToST),并在 ICLR 2025 大会上以 Spotlight 论文的形式亮相。
该研究成果由加州大学伯克利分校三年级博士生吴梓阳(导师为马毅教授)作为第一作者完成。吴梓阳的主要研究方向为表征学习与多模态学习。据悉,马毅教授已受邀在今年四月的 ICLR 大会上就和此项成果相关的一系列白盒神经网络相关工作,进行为时一小时的主题报告(Keynote)。
ToST 的核心创新在于其基于统计学的线性注意力机制。与传统自注意力机制依赖于 token 两两相似性计算不同,ToST 通过对序列特征进行统计建模,显著降低了计算复杂度,实现了线性时间复杂度 O(n)。这意味着 ToST 在处理长序列时,资源消耗将大大减少,从而能够扩展到更大的模型和更长的序列。
该研究团队通过变分编码率缩减(Variational Rate Reduction, VRR)框架,对 ToST 进行了深入探讨,并通过实验验证了其在不同任务中的卓越性能。实验结果表明,ToST 在提高效率的同时,也保持了甚至超越了传统注意力机制的性能。
ToST 的核心方法主要包括:
- 统计特征提取: 对序列中的每个 token 提取其统计特征。
- 变分编码率缩减: 利用 VRR 框架对特征进行压缩,减少信息冗余。
- 线性复杂度实现: 通过一系列优化,其计算复杂度从 O(n²) 降低为 O(n)。
研究团队通过扩展先前的 CRATE 工作推导出网络架构,并表明通过对 MCR² 目标进行展开梯度下降所得到的架构会引入一种新的注意力模块,称为 Token Statistics Self-Attention (TSSA)。TSSA 拥有线性的计算和内存复杂度,并从根本上不同于典型的注意力架构,其后者通过计算 token 之间的两两相似性来实现。
ToST 的主要技术细节包括:
- 线性时间注意力机制 (TSSA): 通过白盒设计方法,从最大编码率减少 (MCR²) 的变分形式中推导而来。
- 创新性的网络结构: 通过将 TSSA 替代标准的自注意力模块,不仅实现了显著的效率提升,还增强了模型的可解释性。
该研究的开源代码和项目主页已发布,方便研究者和开发者进一步探索和应用 ToST。
- 论文地址: https://arxiv.org/abs/2412.17810
- 项目主页: https://robinwu218.github.io/ToST/
- 开源代码: https://github.com/RobinWu218/ToST
ToST 的出现,为解决 Transformer 架构的效率瓶颈带来了新的思路和方法。它不仅在理论上具有重要意义,也为实际应用提供了强大的技术支持。随着人工智能技术的不断发展,ToST 有望在长序列建模、大规模数据处理等领域发挥更大的作用,推动人工智能技术的进步。
参考文献:
- 吴梓阳 et al. Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction. arXiv preprint arXiv:2412.17810 (2024).
- 机器之心. 首个基于统计学的线性注意力机制ToST,高分拿下ICLR Spotlight. 机器之心, 2025年2月17日, https://www.jiqizhixin.com/.
Views: 0