The pursuit of computational efficiency has driven the adoption of low-precision formats for training transformer models. However, this progress is often hindered by notorious training instabilities. This paper provides the first mechanistic explanation for a long-standing, unresolved failure case in which training with flash attention in low-precision settings leads to catastrophic loss explosions. Our in-depth analysis reveals that the failure is not a random artifact but is caused by two intertwined phenomena: the emergence of similar low-rank representations within the attention mechanism, and the compounding effect of biased rounding errors inherent in low-precision arithmetic. We demonstrate how these factors create a vicious cycle of error accumulation that corrupts weight updates and ultimately derails the training dynamics. To validate our findings, we introduce a minimal modification to flash attention that mitigates the bias in rounding errors. This simple change stabilizes the training process, confirming our analysis and offering a practical solution to this persistent problem.
Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
Low-precision training of transformer models with flash attention suffers from catastrophic loss explosions due to low-rank representations and biased rounding errors, which are addressed by a minimal modification to the flash attention mechanism.
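The abstract turns on one distinction: rounding errors that are *biased* (they keep the same sign) compound, while unbiased ones tend to cancel. The snippet below is a generic illustration of that distinction, not the paper's analysis or its flash attention fix. It assumes NumPy's `float16` as a stand-in for a low-precision training format such as bf16. It repeatedly adds a small constant to a running total, rounding to nearest after every step. Once the total is large enough that the increment falls below half the spacing between representable values, each add rounds back down. Every rounding error then has the same negative sign, and the total stalls far below the exact value. The helper name `accumulate` is hypothetical.

```python
import numpy as np

def accumulate(step: float, n: int, dtype) -> float:
    """Add `step` to a running total n times, rounding the total to `dtype` after each add."""
    total = dtype(0)
    inc = dtype(step)
    for _ in range(n):
        # Round-to-nearest in `dtype`; in float16 this eventually discards the
        # increment entirely, so the per-step error is always negative (biased).
        total = dtype(total + inc)
    return float(total)

n = 10_000
step = float(np.float16(1e-4))  # use the exactly representable fp16 value in both runs
exact = n * step                # float64 reference value

fp64_total = accumulate(step, n, np.float64)
fp16_total = accumulate(step, n, np.float16)

print(f"exact       : {exact:.6f}")
print(f"float64 sum : {fp64_total:.6f}")
print(f"float16 sum : {fp16_total:.6f}")  # stalls well below the exact value
```

Each individual rounding error is tiny. The damage comes from their shared sign, which is the kind of systematic drift the abstract describes as accumulating until it corrupts weight updates.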
- Year
- 2025
- Venue
- arXiv 2025
- Authors
- 2
- Hosting
- Abstract only
Attribution
- Abstract & full text
- arxiv.org/abs/2510.04212
- TL;DR
- Semantic Scholar