Exact Learning of Permutations for Nonzero Binary Inputs with Logarithmic Training Size and Quadratic Ensemble Complexity

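In the infinite-width NTK regime, two-layer neural networks trained with gradient descent can learn permutations of nonzero binary inputs: each network's output converges to a Gaussian process, and averaging an ensemble of independently trained networks followed by rounding yields exact predictions with high probability.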

Year
2025
Venue
arXiv 2025
Authors
3

Abstract & full text
arxiv.org/abs/2502.16763

Abstract

The ability of an architecture to realize permutations is quite fundamental. For example, Large Language Models need to be able to correctly copy (and perhaps rearrange) parts of the input prompt into the output. Classical universal approximation theorems guarantee the existence of parameter configurations that solve this task but offer no insights into whether gradient-based algorithms can find them. In this paper, we address this gap by focusing on two-layer fully connected feed-forward neural networks and the task of learning permutations on nonzero binary inputs. We show that in the infinite width Neural Tangent Kernel (NTK) regime, an ensemble of such networks independently trained with gradient descent on only the $k$ standard basis vectors out of $2^k - 1$ possible inputs successfully learns any fixed permutation of length $k$ with arbitrarily high probability. By analyzing the exact training dynamics, we prove that the network's output converges to a Gaussian process whose mean captures the ground truth permutation via sign-based features. We then demonstrate how averaging these runs (an "ensemble" method) and applying a simple rounding step yields an arbitrarily accurate prediction on any possible input unseen during training. Notably, the number of models needed to achieve exact learning with high probability (which we refer to as ensemble complexity) exhibits a linearithmic dependence on the input size $k$ for a single test input and a quadratic dependence when considering all test inputs simultaneously.
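
To make the training setup concrete, the following is a minimal sketch of the data described in the abstract: the $k$ standard basis vectors used for training, and the $2^k - 1$ nonzero binary inputs on which the learned permutation is evaluated. It does not implement the NTK analysis or the ensemble itself. The helper names, the example permutation, and the indexing convention (bit `i` of the input moves to position `perm[i]`) are illustrative assumptions, not taken from the paper.

```python
from itertools import product


def apply_permutation(x, perm):
    """Move bit i of x to position perm[i] (assumed convention)."""
    y = [0] * len(x)
    for i, bit in enumerate(x):
        y[perm[i]] = bit
    return y


def standard_basis(k):
    """The k one-hot vectors e_1, ..., e_k used as training inputs."""
    return [[1 if j == i else 0 for j in range(k)] for i in range(k)]


def nonzero_binary_inputs(k):
    """All 2^k - 1 binary vectors of length k except the zero vector."""
    return [list(bits) for bits in product([0, 1], repeat=k) if any(bits)]


k = 4
perm = [2, 0, 3, 1]  # hypothetical ground-truth permutation

train_x = standard_basis(k)
train_y = [apply_permutation(x, perm) for x in train_x]

all_inputs = nonzero_binary_inputs(k)
unseen = [x for x in all_inputs if x not in train_x]

print(len(train_x), len(all_inputs), len(unseen))  # 4 15 11
```

With $k = 4$, the training set covers 4 of the 15 nonzero inputs, so the remaining 11 inputs are unseen during training; this gap grows exponentially with $k$, which is why generalizing from the basis vectors alone is the nontrivial part of the result.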
