Learning Bug Context for PyTorch-to-JAX Translation with LLMs
Large language models (LLMs) have shown strong performance on code translation between widely used programming languages. However, translation becomes much less reliable for domain-specific code, where correctness depends on framework-specific APIs and execution semantics. One example is translating deep-learning code from PyTorch to JAX, where LLM outputs often contain subtle bugs or non-idiomatic usage that prevents execution or changes behavior. Prior work suggests that curated bug-fix data from LLM-generated code can improve code generation quality, but such resources remain scarce for PyTorch-to-JAX translation. In this work, we introduce T2J, a benchmark of LLM translation bugs paired with developer-written fixes for PyTorch-to-JAX code. We start from 20 kernels in the TorchLeet dataset, translate them to JAX with the weak LLM gpt-4o-mini, and hire software developers to debug and repair the generated JAX implementations. We then use T2J to improve gpt-4o-mini's PyTorch-to-JAX translation via in-context learning. Our evaluation shows that T2J improves our proposed metric, T2J-CodeTrans-Score, by up to 20%.
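The in-context learning step can be pictured as assembling a few-shot prompt from bug/fix pairs. The sketch below is a hypothetical illustration, not the authors' prompt or the T2J data format: the `BugFixPair` record, the identifier-overlap (Jaccard) retrieval rule, and the prompt layout are all assumptions. The two example pairs illustrate well-known PyTorch-to-JAX pitfalls (in-place item assignment on immutable JAX arrays, and `relu` living in `jax.nn` rather than `jax.numpy`) and are not taken from T2J.

```python
from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass(frozen=True)
class BugFixPair:
    """Hypothetical record: a PyTorch snippet, a buggy LLM translation to JAX,
    and a developer-written fix. Not the actual T2J schema."""
    pytorch_src: str
    buggy_jax: str
    fixed_jax: str
    note: str = ""  # one-line description of the bug (assumed field)


def _identifiers(code: str) -> set[str]:
    # Identifier-like tokens only; numbers and punctuation are ignored.
    return set(re.findall(r"[A-Za-z_][A-Za-z0-9_]*", code))


def select_examples(pairs: list[BugFixPair], query: str, k: int) -> list[BugFixPair]:
    """Return the k pairs whose PyTorch source has the highest Jaccard overlap
    of identifiers with the query. This retrieval rule is an assumption."""
    q = _identifiers(query)

    def score(pair: BugFixPair) -> float:
        t = _identifiers(pair.pytorch_src)
        union = q | t
        return len(q & t) / len(union) if union else 0.0

    # sorted() is stable, so ties keep their original order.
    return sorted(pairs, key=score, reverse=True)[:k]


def build_icl_prompt(pairs: list[BugFixPair], query: str, k: int = 2) -> str:
    """Assemble a few-shot prompt: instruction, k bug/fix examples, then the task."""
    parts = [
        "Translate the PyTorch code to JAX. Each example shows a translation "
        "bug and its developer-written fix."
    ]
    for i, p in enumerate(select_examples(pairs, query, k), start=1):
        bug = f"Bug: {p.note}\n" if p.note else ""
        parts.append(
            f"### Example {i}\n"
            f"PyTorch:\n{p.pytorch_src}\n"
            f"Buggy JAX:\n{p.buggy_jax}\n"
            f"{bug}"
            f"Fixed JAX:\n{p.fixed_jax}"
        )
    parts.append(f"### Task\nPyTorch:\n{query}\nJAX:")
    return "\n\n".join(parts)


# Illustrative pairs based on well-known PyTorch-to-JAX pitfalls (not T2J data).
EXAMPLE_PAIRS = [
    BugFixPair(
        pytorch_src="x[0] = 1.0",
        buggy_jax="x[0] = 1.0",
        fixed_jax="x = x.at[0].set(1.0)",
        note="JAX arrays are immutable; item assignment raises TypeError.",
    ),
    BugFixPair(
        pytorch_src="y = torch.relu(x)",
        buggy_jax="y = jnp.relu(x)",
        fixed_jax="y = jax.nn.relu(x)",
        note="relu is provided by jax.nn, not jax.numpy.",
    ),
]

if __name__ == "__main__":
    print(build_icl_prompt(EXAMPLE_PAIRS, "out = torch.relu(inp)", k=1))
```

With a ReLU query and `k=1`, the sketch selects the ReLU pair because it shares the most identifiers with the query. Lexical overlap is only a cheap placeholder; embedding-based retrieval or a fixed example set would be equally plausible choices.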
- Year: 2025
- Hosting: Full text hosted (CC-BY-4.0)
Attribution
- Abstract & full text: arxiv.org/abs/2510.09898 (CC-BY-4.0)