In the realm of deep learning, the Transformer model has achieved remarkable success, largely due to its attention mechanism. This mechanism lets Transformer-based models focus on the relevant parts of an input sequence, giving them strong contextual understanding. However, attention comes with a significant drawback: its computational cost grows quadratically with sequence length, which makes very long texts hard for Transformers to handle. Mamba recently broke through this barrier by scaling linearly with context length. Now, a new paper by the Mamba authors introduces a method for distilling large Transformer models into hybrid linear RNNs, potentially reshaping the field once again.
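To make the scaling argument concrete, here is a minimal PyTorch sketch (illustrative only, not from the paper) showing where the quadratic cost comes from: computing the attention scores materializes an n-by-n matrix, so time and memory grow with the square of the sequence length.

```python
# Illustrative sketch (not from the paper): the score matrix Q @ K^T has shape
# (n, n), so time and memory scale as O(n^2 * d) with sequence length n.
import torch

def naive_self_attention(x, w_q, w_k, w_v):
    q, k, v = x @ w_q, x @ w_k, x @ w_v                      # each (n, d)
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5    # (n, n): the quadratic part
    return torch.softmax(scores, dim=-1) @ v                 # (n, d)

n, d = 4096, 64
x = torch.randn(n, d)
w_q, w_k, w_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))
print(naive_self_attention(x, w_q, w_k, w_v).shape)          # torch.Size([4096, 64])
```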
Background and Context
The attention mechanism’s computational cost has been a bottleneck for Transformer models, limiting their ability to process long sequences. Mamba, a state-space model (SSM), introduced a selection mechanism that reparameterizes the SSM based on the input, allowing the model to retain relevant information indefinitely while filtering out what is irrelevant. This innovation enabled Mamba to match or even outperform Transformers at small to medium scales while scaling linearly with sequence length.
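As a rough illustration of what "reparameterizing the SSM based on the input" means, the sketch below implements a heavily simplified selective state-space recurrence in PyTorch. All parameter names and shapes are assumptions made for illustration; the real Mamba layer adds convolutions, gating, and a hardware-aware parallel scan.

```python
# Heavily simplified selective SSM (illustrative, not the actual Mamba code):
# the input-dependent step size dt and matrices B_t, C_t let the state keep
# or discard information depending on the current token.
import torch
import torch.nn.functional as F

def selective_ssm(x, a, w_b, w_c, w_dt):
    # x: (seq_len, d_model); a: (d_state,) fixed negative decay rates (assumed)
    seq_len, d_model = x.shape
    h = torch.zeros(d_model, a.shape[0])                     # per-channel hidden state
    ys = []
    for t in range(seq_len):
        xt = x[t]
        dt = F.softplus(xt @ w_dt)                           # (d_model,) input-dependent step size
        b_t = xt @ w_b                                       # (d_state,) input-dependent input map
        c_t = xt @ w_c                                       # (d_state,) input-dependent output map
        a_bar = torch.exp(dt.unsqueeze(-1) * a)              # (d_model, d_state) discretized decay
        h = a_bar * h + dt.unsqueeze(-1) * xt.unsqueeze(-1) * b_t  # selective state update
        ys.append(h @ c_t)                                   # (d_model,) read-out
    return torch.stack(ys)                                   # (seq_len, d_model), cost linear in seq_len

# Toy usage with made-up sizes.
d_model, d_state, seq_len = 16, 8, 32
x = torch.randn(seq_len, d_model)
a = -torch.rand(d_state)                                     # negative decay keeps the state stable
w_b, w_c = torch.randn(d_model, d_state), torch.randn(d_model, d_state)
w_dt = torch.randn(d_model, d_model)
print(selective_ssm(x, a, w_b, w_c, w_dt).shape)             # torch.Size([32, 16])
```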
The New Research
In a new paper titled "The Mamba in the Llama: Distilling and Accelerating Hybrid Models," the authors demonstrate that large Transformers can be distilled into large hybrid linear RNNs with minimal additional computation while preserving most of their generation quality. The paper, published on September 1, 2024 by the Mamba authors, marks a significant step forward in model compression and efficiency.
The Distillation Process
The distillation process transfers knowledge from a large Transformer (the teacher model) to a smaller hybrid linear RNN (the student model). The authors employ a multi-stage approach that combines progressive distillation, supervised fine-tuning, and direct preference optimization, and it outperforms standard distillation techniques in both perplexity and downstream evaluations.
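A hedged sketch of what the supervised distillation stage of such a pipeline can look like: a KL term pulls the student's next-token distribution toward the frozen teacher's, mixed with the usual cross-entropy on the ground-truth tokens. The function and its hyperparameters (alpha, temperature) are illustrative assumptions, not the paper's exact objective.

```python
# Illustrative distillation objective (names and hyperparameters are assumed,
# not the paper's exact loss): KL toward the frozen teacher plus cross-entropy
# on the ground-truth next tokens.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    # student_logits, teacher_logits: (batch, seq_len, vocab); labels: (batch, seq_len)
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * temperature ** 2
    ce = F.cross_entropy(student_logits.flatten(0, 1), labels.flatten())
    return alpha * kl + (1.0 - alpha) * ce

# Toy usage with random tensors standing in for real model outputs.
batch, seq_len, vocab = 2, 16, 100
student = torch.randn(batch, seq_len, vocab, requires_grad=True)
teacher = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
print(distillation_loss(student, teacher, labels).item())
```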
The authors hypothesized that most of a Transformer's knowledge lives in its MLP layers, so they keep those layers frozen and distill only the fine-tuning and alignment steps of the LLM. Mamba layers are then trained to stand in for the Transformer's attention blocks, while the original MLP layers are reused without further training.
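Conceptually, the layer surgery looks something like the sketch below. The attribute names (layer.attention, layer.mlp) and the helper make_mamba_from_attention are hypothetical stand-ins, not a real API: MLP blocks are reused and frozen, a fraction of attention blocks is kept, and the rest are replaced with trainable Mamba-style mixers initialized from the attention projections.

```python
# Conceptual sketch only: attribute names (layer.attention, layer.mlp) and the
# helper make_mamba_from_attention are hypothetical stand-ins, not a real API.
import torch.nn as nn

def build_hybrid_student(teacher_layers, make_mamba_from_attention, keep_attention_every=4):
    """Reuse and freeze the teacher's MLP blocks; swap most attention blocks for Mamba mixers."""
    student_layers = nn.ModuleList()
    for i, layer in enumerate(teacher_layers):
        if i % keep_attention_every == 0:
            mixer = layer.attention                              # keep a fraction of attention layers
        else:
            mixer = make_mamba_from_attention(layer.attention)   # trainable Mamba-style replacement
        for p in layer.mlp.parameters():                         # MLP weights stay frozen
            p.requires_grad = False
        student_layers.append(nn.ModuleDict({"mixer": mixer, "mlp": layer.mlp}))
    return student_layers
```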
Linearizing Attention
The key innovation in this research is the linearization of the attention mechanism. Removing the softmax makes the attention formula linear, but doing so naively degrades the model's capabilities. To address this, the authors design a distilled linear RNN whose parameterization closely matches the original Transformer's while effectively expanding the capacity of the linear RNN.
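The regrouping trick can be shown in a few lines. Without the softmax, the output (QKᵀ)V can be computed as a running sum over the sequence instead of an n-by-n score matrix, which is exactly a linear RNN. The sketch below contrasts the two forms; it is plain unnormalized linear attention for illustration, not the paper's final construction.

```python
# Unnormalized linear attention for illustration: dropping the softmax lets
# (Q K^T) V be regrouped as a running sum of k_t v_t^T read out by q_t, so the
# per-token cost stays constant instead of growing with the prefix length.
import torch

def causal_softmax_attention(q, k, v):
    n = q.shape[0]
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5    # (n, n), quadratic in n
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def linear_attention_recurrent(q, k, v):
    state = torch.zeros(q.shape[-1], v.shape[-1])            # running sum of k_t v_t^T
    outputs = []
    for t in range(q.shape[0]):
        state = state + torch.outer(k[t], v[t])              # O(d^2) update per step
        outputs.append(q[t] @ state)                         # read out with the current query
    return torch.stack(outputs)                              # total cost linear in sequence length
```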
Rather than attempting to capture the exact original attention function, the method uses the linearized form as a starting point for distillation. The standard Q, K, and V heads from the attention mechanism are fed directly into the Mamba discretization, and the linear RNN is then applied. This allows the model to learn richer interactions through the expanded hidden state.
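The sketch below extends the linear recurrence above with a Mamba-style input-dependent discretization, reusing the teacher's Q, K, and V projections. The single-head form, the scalar decay a_log, and the exact roles assigned to q, k, and v here are simplifying assumptions made for illustration, not a verbatim reproduction of the paper's algorithm.

```python
# Sketch of feeding the teacher's Q/K/V projections through a Mamba-style
# discretization before the linear RNN. The single-head form, the scalar decay
# a_log, and the exact roles of q, k, v here are simplifying assumptions.
import torch
import torch.nn.functional as F

def attention_initialized_rnn(x, w_q, w_k, w_v, w_dt, a_log):
    seq_len, d = x.shape
    q, k, v = x @ w_q, x @ w_k, x @ w_v                      # reuse the teacher's attention projections
    dt = F.softplus(x @ w_dt).squeeze(-1)                    # (seq_len,) input-dependent step size
    a = -torch.exp(a_log)                                    # learned negative decay
    h = torch.zeros(d, d)                                    # expanded hidden state
    outputs = []
    for t in range(seq_len):
        a_bar = torch.exp(dt[t] * a)                         # discretized decay for this step
        h = a_bar * h + dt[t] * torch.outer(k[t], v[t])      # inject k_t v_t^T, scaled by dt
        outputs.append(q[t] @ h)                             # read out with the query
    return torch.stack(outputs)                              # (seq_len, d)

# Toy usage with random weights standing in for distilled ones.
d, seq_len = 64, 128
x = torch.randn(seq_len, d)
w_q, w_k, w_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))
w_dt, a_log = torch.randn(d, 1), torch.tensor(0.0)
print(attention_initialized_rnn(x, w_q, w_k, w_v, w_dt, a_log).shape)  # torch.Size([128, 64])
```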
Performance and Efficiency
The resulting hybrid model, which keeps one-quarter of the attention layers, performs comparably to the original Transformer on chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch on trillions of tokens. The research team also introduces a hardware-aware speculative decoding algorithm that speeds up inference for Mamba and the hybrid models.
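For context on the decoding side, here is a generic speculative-decoding loop with greedy verification, assuming a batch size of one and HuggingFace-style models whose forward pass returns an object with a .logits field. The paper's hardware-aware algorithm is considerably more involved, since it must manage the recurrent Mamba state and caches efficiently on the GPU; this only shows the draft-then-verify idea.

```python
# Generic speculative decoding with greedy verification (assumes batch size 1
# and HuggingFace-style models whose forward pass returns .logits). The paper's
# hardware-aware algorithm additionally manages the recurrent Mamba state and
# caches; this only shows the basic draft-then-verify loop.
import torch

@torch.no_grad()
def speculative_decode(target_model, draft_model, prompt_ids, num_draft=4, max_new=64):
    ids = prompt_ids.clone()                                  # (1, prompt_len)
    target_len = prompt_ids.shape[-1] + max_new
    while ids.shape[-1] < target_len:
        # 1. Draft num_draft tokens autoregressively with the cheap model.
        draft = ids
        for _ in range(num_draft):
            next_tok = draft_model(draft).logits[:, -1].argmax(-1, keepdim=True)
            draft = torch.cat([draft, next_tok], dim=-1)
        # 2. Verify all drafted positions with a single target-model pass.
        target_next = target_model(draft).logits.argmax(-1)   # greedy prediction for each next position
        n_prev = ids.shape[-1]
        accepted = 0
        for j in range(num_draft):
            if draft[0, n_prev + j] == target_next[0, n_prev + j - 1]:
                accepted += 1
            else:
                break
        # 3. Keep the accepted tokens plus the target model's own next token.
        correction = target_next[:, n_prev + accepted - 1 : n_prev + accepted]
        ids = torch.cat([draft[:, : n_prev + accepted], correction], dim=-1)
    return ids
```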
The best-performing model in this study was distilled from Llama3-8B-Instruct; it achieves a length-controlled win rate of 29.61 on AlpacaEval 2 against GPT-4 and a score of 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN models.
Conclusion
The new research by the Mamba authors represents a significant advancement in the field of deep learning. By distilling large Transformers into hybrid linear RNNs, the team has opened up possibilities for more efficient and scalable models that can handle long sequences with ease. This work not only addresses the computational challenges of the attention mechanism but also paves the way for future innovations in model compression and efficiency.
For those interested in delving deeper into the technical details, the paper is available at: https://arxiv.org/pdf/2408.15237.