0

Accelerating Birkhoff Projection for Manifold-Constrained Hyper-Connections

Manifold-constrained hyper-connections (mHCs) have recently been proposed as a principled extension of hyper-connections, where the residual mixing matrices are constrained to be doubly stochastic via projection onto the Birkhoff polytope.

Preview
Year
2026
Hosting
Full text hostedCC-BY-4.0

Cite

Notes

Only stored in your browser.

Attribution

Abstract & full text
arxiv.org/abs/2606.07574CC-BY-4.0
TL;DR
Semantic Scholar
Attribution policy →

Abstract

Manifold-constrained hyper-connections (mHCs) have recently been proposed as a principled extension of hyper-connections, where the residual mixing matrices are constrained to be doubly stochastic via projection onto the Birkhoff polytope. In practical mHC implementations, this constraint is enforced by Sinkhorn-Knopp iterations, and the backward pass relies on unrolling the iterative solver. This design introduces substantial computation and memory overhead, and may also yield inaccurate projections when the algorithm converges slowly on challenging inputs, undermining the intended norm-control and stability guarantees of mHCs. In this work, we focus on the practically important 4x4 Birkhoff projection setting and develop an end-to-end acceleration framework. By leveraging the dual formulation, we reduce the problem to a three-dimensional unconstrained convex problem and solve it with Newton's method, achieving fast convergence and high accuracy. For the backward pass, we replace the unrolled differentiation with implicit differentiation, yielding exact gradients without storing intermediate states. To exploit massive parallelism, we design a warp-level CUDA kernel that uses only register-level primitives, avoiding global and shared memory I/O. Extensive experiments against representative open-source baselines demonstrate that the proposed solver yields substantially more reliable doubly stochastic projections -- especially when the input magnitude is large -- and achieves significant end-to-end speedups (including the backward pass), reaching over 20x acceleration at large batch sizes while maintaining orders of magnitude smaller marginal errors.