NLP Colloquium

May 21, 2025
Presentation of "Transformers Provably Solve Parity Efficiently with Chain of Thought" (ICLR 2025) at the NLP Colloquium (Japanese, online)

Juno Kim

Transcript

  1. Transformers Provably Solve Parity Efficiently with Chain of Thought
     ICLR 2025 (oral). Juno Kim∗, Taiji Suzuki (University of Tokyo, RIKEN AIP)
     ∗Incoming PhD student at UC Berkeley
  2. Chain of thought (CoT)
     • Prompting an LLM to solve complex tasks step by step, recursively generating intermediate reasoning steps to arrive at the final answer
     • Directly fine-tuning models on ground-truth CoT greatly improves multi-step reasoning
  3. Existing theory
     • Evaluating the increased expressivity via complexity theory: Feng et al. (2023); Merrill and Sabharwal (2023, 2024); Li et al. (2024b)
     • In-context learning setting with recursion: Li et al. (2023, 2024a)
     • Computational (optimization) guarantees are challenging to obtain: it is difficult to characterize the GD dynamics of a recursively applied transformer network
     • Concurrent work: Wen et al. (2025) study solving parity with SGD under a memory constraint (for constant-pass learning, a lower bound of either $\Omega(k^2)$ memory or $2^{\Omega(k)}$ samples is known) and show that CoT achieves $\mathrm{poly}(k)$ sample complexity
     How does optimizing transformers to explicitly generate CoT improve multi-step reasoning capability?
  4. Our contributions
     • We study k-parity as a model for multi-step problems
     • SQ theory: any iterative gradient-based algorithm requires $\exp(\Omega(d))$ samples or runtime
     • Model: a 1-layer softmax transformer recursively applied to its own output to generate CoT
     • When teacher forcing is used and the loss is summed over intermediate parity states, our model solves parity in one gradient step with $O(d^2)$ samples
     • When CoT must be generated end-to-end, our model still solves parity in $\log d$ steps if a self-consistency scheme is used to bound the error
     [Figure: input bits 010110 with parity 1, computed directly and via intermediate CoT bits]
  5. Problem setup
     • For $d$-bit inputs $x = (x_j)_{j=1}^d \sim \mathrm{Unif}(\{\pm 1\}^d)$ and a size-$k$ subset $p \subset \{1, \dots, d\}$, learn $y = p(x) := \prod_{j \in p} x_j$ given $n$ i.i.d. samples $(x_i, y_i)_{i=1}^n$ with $y_i = p(x_i)$
     • Heuristic (SQ lower bound): when $k \sim d$, any gradient-based algorithm requires $e^{\Omega(d)}$ queries or $e^{-\Omega(d)}$ precision to solve parity ⇒ exponential samples or runtime is needed!
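
A minimal sketch of this data-generating process, assuming NumPy and ±1-valued bits. The function name `sample_parity_data` and its arguments are illustrative, not taken from the talk.

```python
import numpy as np

def sample_parity_data(n, d, k, seed=0):
    """Draw n i.i.d. samples (x_i, y_i) for a random size-k parity on d bits.

    Returns X of shape (n, d) with entries in {-1, +1}, labels
    y = prod_{j in p} x_j, and the hidden support p (sorted, 0-based).
    """
    rng = np.random.default_rng(seed)
    p = np.sort(rng.choice(d, size=k, replace=False))  # hidden subset p
    X = rng.choice(np.array([-1, 1]), size=(n, d))      # x ~ Unif({±1}^d)
    y = np.prod(X[:, p], axis=1)                        # y = p(x)
    return X, y, p

X, y, p = sample_parity_data(n=1000, d=20, k=5)
# Empirical correlation of y with each single bit is near 0:
# no individual coordinate is informative about the label.
print(np.abs(X.T @ y / len(y)).max())
```

The near-zero single-bit correlations are the intuition behind the SQ heuristic: low-order statistics carry essentially no signal about $p$.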
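  6. SQ hardness (finite samples)
     • Let $\{f_\theta : \theta \in \Theta\}$ be any parametrized model with $\|\nabla_\theta f_\theta\| = \mathrm{poly}(d)$
     • Finite-sample loss (note that the class $\mathcal{P}$ of size-$k$ parities is not orthogonal w.r.t. the empirical inner product $\langle \cdot, \cdot \rangle_n$): $L_n(\theta) = \frac{1}{2n} \sum_{i=1}^n (y_i - f_\theta(x_i))^2 = \frac{1}{2} \|p - f_\theta\|_n^2$
     • $\varepsilon$-approximate gradient oracle $\widetilde{\nabla}$: takes a query $\theta$ and returns $\widetilde{\nabla} L_n(\theta) = v$ (potentially adversarially) such that $\|v - \nabla L_n(\theta)\|_2 \le \varepsilon$
     • The algorithm cannot directly cross-reference between samples; otherwise Gaussian elimination solves parity with $O(d)$ samples and $O(d^3)$ iterations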
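
To see why cross-referencing samples has to be ruled out, here is a sketch of the Gaussian-elimination attack mentioned on this slide. Writing $x_j = (-1)^{b_j}$ gives $p(x) = (-1)^{\sum_{j \in p} b_j}$, so recovering $p$ is a linear system over GF(2). This is plain NumPy; `solve_parity_gf2` is a hypothetical helper name.

```python
import numpy as np

def solve_parity_gf2(X, y):
    """Recover a parity's support from +-1 samples by Gaussian elimination mod 2.

    X: (n, d) array in {-1, +1}; y: (n,) labels in {-1, +1}.
    Returns a 0/1 indicator s with b @ s = c (mod 2), where b = (1 - X) / 2 and
    c = (1 - y) / 2. The solution is unique once the rows of b span GF(2)^d.
    """
    A = ((1 - X) // 2).astype(np.uint8)          # bits b_ij
    c = ((1 - y) // 2).astype(np.uint8)          # label bits
    M = np.concatenate([A, c[:, None]], axis=1)  # augmented matrix [A | c]
    n, d = A.shape
    row, pivots = 0, []
    for col in range(d):
        candidates = np.nonzero(M[row:, col])[0]
        if candidates.size == 0:
            continue                             # no pivot: free variable
        r = row + candidates[0]
        M[[row, r]] = M[[r, row]]                # swap the pivot row into place
        mask = M[:, col].astype(bool)
        mask[row] = False
        M[mask] ^= M[row]                        # XOR = row addition mod 2
        pivots.append(col)
        row += 1
        if row == n:
            break
    s = np.zeros(d, dtype=np.uint8)
    for i, col in enumerate(pivots):
        s[col] = M[i, d]                         # free variables set to 0
    return s

rng = np.random.default_rng(0)
d, k, n = 30, 15, 60                             # n = O(d) samples suffice
p = rng.choice(d, size=k, replace=False)
X = rng.choice(np.array([-1, 1]), size=(n, d))
y = np.prod(X[:, p], axis=1)
s = solve_parity_gf2(X, y)
print(sorted(np.flatnonzero(s).tolist()) == sorted(p.tolist()))
```

The elimination costs $O(d^3)$ operations for $n = O(d)$ samples, which is exactly the kind of sample-level manipulation a gradient oracle does not expose.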
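  7. SQ hardness (finite samples)
     Theorem (learning parities with empirical loss)
       1. If $n = e^{\Omega(d)}$, there exists an $e^{-\Omega(d)}$-approximate oracle $\widetilde{\nabla}$ such that any iterative algorithm $\theta^{(A)}$ which makes at most $\mathrm{poly}(d)$ queries to $\widetilde{\nabla} L_n$ satisfies, w.p. $1 - e^{-\Omega(d)}$, $\mathbb{E}_{p,x}[(p(x) - f_{\theta^{(A)}}(x))^2] \ge 1 - e^{-\Omega(d)}$
       2. If $n = \Omega(d^\nu)$ with $\nu > 4\nu_1 + 4\nu_2 + 2\nu_3 + 1$ and $\|\nabla f_\theta\| = O(d^{\nu_1})$, there exists an $O(d^{-\nu_2})$-approximate oracle $\widetilde{\nabla}$ such that any iterative algorithm $\theta^{(A)}$ which makes at most $O(d^{\nu_3})$ queries to $\widetilde{\nabla} L_n$ satisfies, w.p. $1 - e^{-\Omega(d)}$, $\mathbb{E}_{p,x}[(p(x) - f_{\theta^{(A)}}(x))^2] \ge 1 - o(1)$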
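  8. Proof sketch
     • Idea: since $\mathcal{P}$ is (nearly) orthogonal, show that $\mathrm{Var}_n(\theta; \mathcal{P}) := \mathbb{E}_{p \in \mathcal{P}} \|\nabla L_{n,p}(\theta) - \mathbb{E}_{p' \in \mathcal{P}}[\nabla L_{n,p'}(\theta)]\|^2$ is small, so that $\nabla L_{n,p}$ does not reveal any information about the target $p$
     • Correlation: $\sup_{p \ne p'} |\langle p, p' \rangle_n| \le \delta := \sqrt{4d/n}$ w.h.p.
     • By Gershgorin's circle theorem, we obtain the frame bound $\lambda_{\max}(\mathrm{Gram}(\mathcal{P})) \le 1 + |\mathcal{P}|\delta$, and hence (with $D$ the number of parameters)
       $$\mathrm{Var}_n(\theta; \mathcal{P}) \le \mathbb{E}_{p \in \mathcal{P}}\Big[\sum_{j=1}^D \langle \nabla_{\theta_j} f_\theta, p \rangle_n^2\Big] \le 2\Big(\frac{1}{|\mathcal{P}|} \vee \sqrt{\frac{4d}{n}}\Big) \sup_{\theta, x} \|\nabla f_\theta(x)\|^2$$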
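
The near-orthogonality and frame-bound steps can be checked on a toy instance. The sketch below (NumPy; the sizes are arbitrary choices, not the talk's) evaluates every size-$k$ parity on $n$ samples, forms the empirical Gram matrix $G_{pp'} = \langle p, p' \rangle_n$, and compares its largest eigenvalue with the Gershgorin bound $1 + (|\mathcal{P}| - 1)\max_{p \ne p'} |\langle p, p' \rangle_n|$, alongside the $\sqrt{4d/n}$ scale.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
d, k, n = 12, 3, 4000
X = rng.choice(np.array([-1.0, 1.0]), size=(n, d))

# Columns of F are the values p(x_i) of every size-k parity p on the samples.
supports = list(itertools.combinations(range(d), k))
F = np.stack([X[:, list(s)].prod(axis=1) for s in supports], axis=1)  # (n, |P|)

G = F.T @ F / n                        # empirical Gram matrix; diagonal is exactly 1
off = np.abs(G - np.eye(len(supports)))
delta_emp = off.max()                  # sup_{p != p'} |<p, p'>_n|
lam_max = np.linalg.eigvalsh(G).max()
gershgorin = 1 + (len(supports) - 1) * delta_emp

print(f"|P| = {len(supports)}, max off-diagonal = {delta_emp:.4f}, "
      f"sqrt(4d/n) = {np.sqrt(4 * d / n):.4f}")
print(f"lambda_max = {lam_max:.3f} <= Gershgorin bound {gershgorin:.3f}")
```

The Gershgorin inequality holds deterministically; the comparison with $\sqrt{4d/n}$ is the high-probability part, coming from a union bound over at most $4^d$ pairs of parities.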
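  9. Proof sketch
     • By Chebyshev's inequality, for any $\varepsilon > 0$ it holds that $\Pr\big(\|\nabla L_{n,p}(\theta) - \mathbb{E}_{p' \in \mathcal{P}}[\nabla L_{n,p'}(\theta)]\| > \varepsilon\big) \le \mathrm{Var}_n(\theta; \mathcal{P})/\varepsilon^2$
     • Construct $\widetilde{\nabla}$ as
       $$\widetilde{\nabla} L_{n,p}(\theta) = \begin{cases} \mathbb{E}_{p' \in \mathcal{P}}[\nabla L_{n,p'}(\theta)] & \text{if } \|\nabla L_{n,p}(\theta) - \mathbb{E}_{p' \in \mathcal{P}}[\nabla L_{n,p'}(\theta)]\| \le \varepsilon, \\ \nabla L_{n,p}(\theta) & \text{otherwise} \end{cases}$$
     • Querying $\widetilde{\nabla} L_{n,p}$ at most $O(d^{\nu_3})$ times does not reveal any information about $p$ w.h.p.
     • Conditioned on this event, for any random guess $\theta$ we have $\mathbb{E}_p \|p - f_\theta\|^2 \ge 1 - o(1)$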
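
Chaining the two proof-sketch slides gives a quick consistency check on the sample-size condition in part 2 of the theorem. The derivation below is an informal reconstruction, not the paper's proof: it assumes $\sqrt{4d/n}$ dominates $1/|\mathcal{P}|$, takes $n = d^\nu$, $\varepsilon = d^{-\nu_2}$, $\sup\|\nabla f_\theta\|^2 = d^{2\nu_1}$, applies a union bound over the $O(d^{\nu_3})$ queries, and suppresses constants.

```latex
\begin{aligned}
\Pr(\text{some query reveals } p)
  &\lesssim d^{\nu_3}\cdot\frac{\mathrm{Var}_n(\theta;\mathcal{P})}{\varepsilon^2}
   \lesssim d^{\nu_3}\cdot\sqrt{\frac{d}{n}}\; d^{2\nu_1}\cdot d^{2\nu_2} \\
  &= d^{\,\nu_3 + 2\nu_1 + 2\nu_2 + (1-\nu)/2}.
\end{aligned}
```

This bound is $o(1)$ when $\nu_3 + 2\nu_1 + 2\nu_2 + (1 - \nu)/2 < 0$, i.e. when $\nu > 4\nu_1 + 4\nu_2 + 2\nu_3 + 1$, matching the theorem's condition.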
  10. Task decomposition
     [Figure: tree over the input bits $x_1, \dots, x_{16}$ at level $h = 0$, with nodes $x_{17}, \dots, x_{23} = y$ on levels $h = 1, 2, 3$]
     • Decompose parity into intermediate 2-parity subtasks, which are easy to learn
     • Can a fixed model learn to generate the 'reasoning chain' $x_{d+1} \to \cdots \to x_{d+k-1}$ in order when recursively applied to its own output?
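
A small sketch of this decomposition, assuming (as the figure suggests) that $k$ is a power of two and the chain is a balanced binary tree: each new token is the 2-parity, i.e. the product in ±1 encoding, of two earlier tokens, level by level, so the last of the $k - 1$ tokens equals $y$. The pairing order and the name `parity_chain` are illustrative assumptions.

```python
import numpy as np

def parity_chain(X, p):
    """Build the CoT chain x_{d+1}, ..., x_{d+k-1} for the parity with support p.

    X: (n, d) array of +-1 bits; p: list of k indices (k a power of 2).
    Returns (chain, children): chain is (n, k-1), one column per intermediate
    2-parity; children[i] is the pair of 0-based token indices combined by
    chain token d + i.
    """
    n, d = X.shape
    tokens = [X[:, j] for j in range(d)]          # full token sequence so far
    level = list(p)                               # token indices on the current level
    children = []
    while len(level) > 1:
        next_level = []
        for a, b in zip(level[0::2], level[1::2]):
            tokens.append(tokens[a] * tokens[b])  # 2-parity in +-1 encoding
            children.append((a, b))
            next_level.append(len(tokens) - 1)
        level = next_level
    chain = np.stack(tokens[d:], axis=1)
    return chain, children

rng = np.random.default_rng(0)
X = rng.choice(np.array([-1, 1]), size=(5, 16))
p = [0, 2, 3, 5, 8, 11, 13, 15]                   # d = 16, k = 8: chain has 7 tokens
chain, children = parity_chain(X, p)
print(children)
# [(0, 2), (3, 5), (8, 11), (13, 15), (16, 17), (18, 19), (20, 21)]
```

With $d = 16$ and $k = 8$ the chain tokens are indices 16 to 22, i.e. $x_{17}, \dots, x_{23}$ in the slide's 1-based notation, and each is an easy 2-parity of its two children.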
  11. Transformer model
     [Figure: tokens $x_1, \dots, x_d, \hat{x}_{d+1}, \dots, \hat{x}_m, \dots, \hat{x}_{d+k-1} = \hat{y}$ with positional encodings $e_1, \dots, e_{d+k-1}$, each output passed through $\phi$]
     • We study the GD dynamics of a one-layer transformer
     • Bit tokens $x_1, \dots, x_d \in \{\pm 1\}^n$ plus dummy tokens $x_{d+1}, \dots, x_{d+k-1}$, initially set to $0_n$
     • A one-hot positional encoding $e_m$ is appended to $x_m$
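  12. Transformer model
     • Softmax layer: reparametrize as
       $$K^\top Q = \begin{pmatrix} 0_{n \times n} & 0_{n \times (d+k-1)} \\ 0_{(d+k-1) \times n} & W \end{pmatrix}, \qquad V = \begin{pmatrix} I_{n \times n} & 0_{n \times (d+k-1)} \end{pmatrix}$$
     • Feedforward layer: fixed link function $\phi : \mathbb{R} \to \mathbb{R}$ such that $\phi(0) = -1$, $\phi(\pm 1) = 1$ (converts sums to parities) and $\phi'(0) = \phi'(\pm 1) = 0$
     • The transformer computes $\mathrm{TF}(x_1, \dots, x_{d+k-1}; W) = (\hat{x}_1, \dots, \hat{x}_{d+k-1})$, where the original $d$ bits remain unchanged and $\hat{x}_m = \phi(\hat{z}_m)$ with ($\hat{p}_j$ stacking token $x_j$ and its encoding $e_j$)
       $$\hat{z}_m = \sum_{j=1}^{m-1} V\hat{p}_j \cdot \mathrm{softmax}(\hat{p}_j^\top K^\top Q \hat{p}_m) = \sum_{j=1}^{m-1} \sigma_j(w_m) x_j, \qquad \sigma_j(w_m) = \frac{e^{w_{j,m}}}{\sum_{\alpha=1}^{m-1} e^{w_{\alpha,m}}}$$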
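
Because of the block structure, the attention score from position $m$ to position $j$ is just $w_{j,m}$, assuming each $\hat{p}_j$ stacks the token with its one-hot encoding. The sketch below (NumPy) stores the $d + k - 1$ tokens as rows of an array, each a vector over the $n$ samples. The link $\phi(z) = 1 - 2(1 - z^2)^2$ is a hypothetical choice that satisfies the slide's conditions; the talk does not specify its $\phi$.

```python
import numpy as np

def phi(z):
    """Hypothetical link: phi(0) = -1, phi(+-1) = 1, phi'(0) = phi'(+-1) = 0."""
    return 1.0 - 2.0 * (1.0 - z**2) ** 2

def softmax(v):
    e = np.exp(v - v.max())                  # shift for numerical stability
    return e / e.sum()

def tf_layer(tokens, W, d):
    """One pass of the one-layer transformer.

    tokens: (T, n) array, T = d + k - 1; rows 0..d-1 are input bits and
            rows d..T-1 are the current chain tokens (0 for dummies).
    W: (T, T) array with W[j, m] = w_{j,m}.
    Returns new tokens: input bits unchanged, x_m = phi(sum_{j<m} sigma_j(w_m) x_j),
    where every position reads the tokens as they were at the start of the pass.
    """
    out = tokens.copy()
    for m in range(d, tokens.shape[0]):
        sigma = softmax(W[:m, m])            # attend only to earlier positions j < m
        out[m] = phi(sigma @ tokens[:m])     # z_m = sum_j sigma_j x_j, then link
    return out

# Equal weight on two +-1 bits gives z = (a + b)/2 in {-1, 0, 1}, and phi(z) = a * b.
a = np.array([1.0, 1.0, -1.0, -1.0])
b = np.array([1.0, -1.0, 1.0, -1.0])
print(np.array_equal(phi((a + b) / 2), a * b))  # True
```

This is the sense in which $\phi$ "converts sums to parities": a sum of two bits is $0$ exactly when they disagree.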
  13. CoT model
     [Figure: same token diagram as on slide 11]
     • Repeatedly apply TF to its own output until the chain stops autoregressively updating:
       $$\mathrm{TF}^{(k-1)}(x_1, \dots, x_d, 0_n, \dots, 0_n; W) = (x_1, \dots, x_d, \hat{x}_{d+1}, \dots, \hat{x}_{d+k-1})$$
     • $\hat{y} = \hat{x}_{d+k-1}$ is returned as the model prediction
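
Iterating the layer fills the chain in level by level. The self-contained sketch below uses the hypothetical link $\phi(z) = 1 - 2(1 - z^2)^2$ (one choice satisfying the conditions on slide 12) and hand-sets $W$ to place a large score `c` on each chain token's two children in a balanced tree. This is an idealized "solved" configuration assumed for illustration, not a trained one; `ideal_W` is a hypothetical helper.

```python
import numpy as np

def phi(z):
    return 1.0 - 2.0 * (1.0 - z**2) ** 2            # hypothetical link function

def tf_layer(tokens, W, d):
    """One pass: x_m = phi(sum_{j<m} softmax(W[:m, m])_j x_j) for m >= d."""
    out = tokens.copy()
    for m in range(d, tokens.shape[0]):
        s = np.exp(W[:m, m] - W[:m, m].max())
        out[m] = phi((s / s.sum()) @ tokens[:m])
    return out

def ideal_W(d, p, c=30.0):
    """Hypothetical 'solved' weights: score c on each chain token's two children."""
    T = d + len(p) - 1
    W = np.zeros((T, T))
    level, nxt = list(p), d
    while len(level) > 1:
        new_level = []
        for a, b in zip(level[0::2], level[1::2]):
            W[a, nxt] = W[b, nxt] = c
            new_level.append(nxt)
            nxt += 1
        level = new_level
    return W

rng = np.random.default_rng(1)
d, k, n = 16, 8, 200
p = sorted(rng.choice(d, size=k, replace=False).tolist())
X = rng.choice(np.array([-1.0, 1.0]), size=(n, d))
W = ideal_W(d, p)

tokens = np.vstack([X.T, np.zeros((k - 1, n))])     # dummy tokens start at 0_n
for _ in range(int(np.log2(k))):                    # each pass fixes one more level
    tokens = tf_layer(tokens, W, d)
y_hat = tokens[-1]                                  # prediction = last chain token
print(np.abs(y_hat - X[:, p].prod(axis=1)).max())
```

In this idealized setting, $\log_2 k$ passes already reproduce $y$, since level $h$ becomes correct after pass $h$; the $k - 1$ passes on the slide are therefore more than enough for a balanced tree.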
  14. Teacher forcing
     • Teacher forcing: ground-truth labels are used to predict each new token during training; it is widely used to train autoregressive models such as RNNs and transformers
     • Prevents error accumulation and stabilizes training: $\hat{x}_m = \mathrm{TF}(x_1, \dots, x_{m-1}, 0_n, \dots, 0_n; W)_m$
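  15. Teacher forcing
     • The loss is summed over all intermediate states: $L(W) = \frac{1}{2n} \sum_{m=d+1}^{d+k-1} \|\hat{x}_m - x_m\|^2$
     • At test time, predictions for new $y_{\mathrm{test}} = p(x_{\mathrm{test}})$ are generated by iterating TF as before
     Theorem (CoT with teacher forcing). Suppose $n = \Omega(d^{2+\epsilon})$ for $\epsilon > 0$ and $\widetilde{\nabla}$ is any $O(d^{-2-\epsilon/8})$-approximate oracle. Then for any target parity $p \in \mathcal{P}$, the one-step update $W^{(1)} = W^{(0)} - \eta \widetilde{\nabla} L(W^{(0)})$ with $W^{(0)} = 0$, $\eta = \Theta(d^{2+\epsilon/16})$ achieves, w.p. $1 - \exp(-d^{\epsilon/2})$, test loss $\|\hat{y}_{\mathrm{test}} - y_{\mathrm{test}}\|_\infty \le O(d^{-\epsilon/8})$.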
  16. Sketch of proof
     • The dynamics decompose into the individual 2-parity subtasks
     • Explicitly compute the gradient of $L$ w.r.t. $w_{j,m}$ at initialization and expand it into multilinear contraction (interaction) terms, e.g. $\frac{1}{n}\langle x_m, \hat{z}_m, \hat{z}_m \rangle = \frac{1}{n(m-1)^2} \sum_{\alpha,\beta} \langle x_m, x_\alpha, x_\beta \rangle$
     • If $\alpha, \beta$ are the child nodes of $m$, then $\langle x_m, x_\alpha, x_\beta \rangle = n$; otherwise it is $O(\sqrt{n \log d})$
     • Computing up to 4th-order interactions shows that the leading term is $\Theta(d^{-2})$ if $j$ is a child of $m$ and $O(d^{-2-\epsilon/8})$ otherwise, so the correct signal can be extracted
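
The claim that the child signal dominates at initialization can be probed numerically. The sketch below (NumPy) builds ground-truth chain tokens, applies teacher forcing (each $\hat{x}_m$ is computed from the true earlier tokens), and evaluates the exact gradient of $L(W)$ at $W = 0$ via $\partial \hat{z}_m / \partial w_{j,m} = \sigma_j(w_m)(x_j - \hat{z}_m)$. The link $\phi(z) = 1 - 2(1 - z^2)^2$, the balanced-tree pairing, and the problem sizes are assumptions for illustration. Indices are 0-based, so token `m` attends uniformly to `m` earlier positions.

```python
import numpy as np

def phi(z):
    return 1.0 - 2.0 * (1.0 - z**2) ** 2           # hypothetical link function

def dphi(z):
    return 8.0 * z * (1.0 - z**2)                   # derivative of phi

def chain_tokens(X, p):
    """Ground-truth tokens (rows) and a dict: chain token index -> its two children."""
    tokens = [row for row in X.T]
    level, children = list(p), {}
    while len(level) > 1:
        nxt = []
        for a, b in zip(level[0::2], level[1::2]):
            tokens.append(tokens[a] * tokens[b])
            children[len(tokens) - 1] = (a, b)
            nxt.append(len(tokens) - 1)
        level = nxt
    return np.array(tokens, dtype=float), children

def grad_at_zero(tokens, d):
    """Exact dL/dw_{j,m} at W = 0 with teacher forcing (uniform attention over j < m)."""
    T, n = tokens.shape
    G = np.zeros((T, T))
    for m in range(d, T):
        sigma = np.full(m, 1.0 / m)                 # softmax of zeros over j = 0..m-1
        z = sigma @ tokens[:m]                      # z_m from ground-truth tokens
        r = (phi(z) - tokens[m]) * dphi(z)          # residual times phi'(z_m)
        G[:m, m] = sigma * ((tokens[:m] - z) @ r) / n
    return G

rng = np.random.default_rng(0)
d, k, n = 16, 8, 40000
p = sorted(rng.choice(d, size=k, replace=False).tolist())
X = rng.choice(np.array([-1.0, 1.0]), size=(n, d))
tokens, children = chain_tokens(X, p)
G = grad_at_zero(tokens, d)
for m, (a, b) in children.items():
    top2 = set(np.argsort(G[:m, m])[:2].tolist())   # two most negative entries
    print(m, (a, b), top2 == {a, b})
```

In each column the two most negative gradient entries should be the children of $m$ (the leading term is about $-8/m^2$ here), so a single step $W^{(1)} = -\eta \nabla L(0)$ raises exactly those attention scores; column-constant shifts do not matter because the softmax is invariant to them.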
  17. Without teacher forcing
     • In autoregressive models, teacher forcing can induce exposure bias (distribution shift), where the model is not robust to its own errors
     • Partial (scheduled or random) teacher forcing is employed in practice to overcome this (Goyal et al., 2017; Mihaylova and Martins, 2019)
     • Without teacher forcing, the model needs to generate CoT chains end-to-end during training, causing error accumulation and complicating the dynamics
  18. Modified model
     • Quantization: round all weights to the nearest integer, $W^{(t+1)} = r[W^{(t)} - \eta \widetilde{\nabla}_W L(W^{(t)}, U)]$
     • Impose stronger autoregressivity, where each token only depends on previous levels [Figure: levels $d, d_0, d_1, d_2$]
     • Self-consistency to filter out 'faulty reasoning': if the tokens on some level are uninformative ($\approx 0$), zero out its output, since all subsequent reasoning will be wrong
     • This induces curriculum learning: each 2-parity level is 'unlocked' sequentially
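
The two modifications on this slide can be sketched in a few lines of NumPy. This assumes `np.rint` as the rounding map $r[\cdot]$ (it rounds exact halves to the nearest even integer; the slide only says "nearest integer") and a hypothetical threshold `tau` for deciding that a level is uninformative. `grad` stands in for an oracle gradient of $L(W, U)$; the paper's exact filtering rule may differ.

```python
import numpy as np

def quantized_step(W, grad, eta):
    """W^{(t+1)} = r[W^{(t)} - eta * grad], with r = rounding to the nearest integer."""
    return np.rint(W - eta * grad)

def self_consistency_filter(level_tokens, tau=0.5):
    """Zero out a level whose tokens are uninformative (close to 0).

    level_tokens: (L, n) array of the tokens on one tree level.
    If their mean absolute value is below the hypothetical threshold tau,
    return zeros so later levels are not fed faulty reasoning.
    """
    if np.abs(level_tokens).mean() < tau:
        return np.zeros_like(level_tokens)
    return level_tokens

W = np.zeros((3, 3))
g = np.array([[0.0, -0.9, 0.1], [0.0, 0.0, -0.2], [0.0, 0.0, 0.0]])
print(quantized_step(W, g, eta=2.0))   # small updates round back to 0; -0.9 becomes 2
print(self_consistency_filter(np.array([[0.05, -0.1], [0.02, 0.0]])))
```

Rounding discards small, noisy gradient components entirely, which is one way to see why a level is either "unlocked" (integer-sized attention scores on its children) or left untouched.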
  19. Self-consistency
     • $L(W) = \frac{1}{2n} \sum_{m=d+1}^{d+k-1} \|\hat{x}_m - x_m\|^2$, where $\hat{x}_m = \mathrm{TF}(\hat{x}_1, \dots, \hat{x}_{m-1}, 0_n, \dots, 0_n; W)_m = \mathrm{TF}^{(\log_2 k)}(x_1, \dots, x_d, 0_n, \dots, 0_n; W)_m$
     Theorem (CoT without teacher forcing). Suppose $n = \Omega(d^{2+\epsilon})$ for $\epsilon > 0$ and $\widetilde{\nabla}$ is any $O(d^{-2-\epsilon/8})$-approximate oracle. Then for any target parity $p \in \mathcal{P}$, the sequence of quantized updates $W^{(t+1)} = r[W^{(t)} - \eta \widetilde{\nabla} L(W^{(t)}, U)]$ with $W^{(0)} = 0$, $\eta = \Theta(d^{2+\epsilon/16})$ achieves test loss $\|\hat{y}_{\mathrm{test}} - y_{\mathrm{test}}\|_\infty \le \exp(-\Omega(d^{\epsilon/16}))$ in $\log_2 k$ iterations w.p. $1 - \exp(-d^{(\epsilon \wedge 1)/2})$.
     ⇒ Polynomial samples and logarithmic runtime suffice with CoT!
  20. Numerical experiments
     Compare one-layer transformers with $d = 64$, $k = 32$, trained with GD on 100K samples:
     • Direct: end-to-end generation, trained on the target loss $\frac{1}{2n}\|\hat{y} - y\|^2$
     • CoT: end-to-end chain generation, CoT loss
     • CoT+forcing: teacher forcing, CoT loss
     • CoT+consistency: end-to-end chain generation with self-consistency, CoT loss
  21. Takeaways
     • Training explicitly for CoT generation can provably improve performance on complex multi-step tasks
     • Controlling error accumulation via teacher forcing or self-consistency is key to ensuring effective step-by-step learning
  22. Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation
     ICML 2025. Juno Kim†, Denny Wu‡, Jason D. Lee§, Taiji Suzuki†
     †University of Tokyo, RIKEN AIP; ‡NYU, Flatiron Institute; §Princeton University
  23. Test-time scaling
     • Pretraining → post-training paradigms: SFT, RLHF, preference optimization, distillation, ...
     • Reasoning capabilities improve drastically when more compute is allocated at inference time, e.g. running search against a verifier or a trained reward model (Jaech et al., 2024; Kimi et al., 2025; Snell et al., 2024; Wu et al., 2024; Guo et al., 2025)
     • The search trace can be used to refine the pretrained model or to distill its reasoning patterns into more efficient models (Zhang et al., 2024; Busbridge et al., 2025)
     How can the benefits of test-time scaling methods be rigorously understood?
  24. Our contributions
     [Figure: reasoning states (dots) forming paths from $X_{\mathrm{in}}$ to $X_{\mathrm{out}}$, with a "distillation" label]
     • Idea: model long CoT generation as a Markov chain over abstract reasoning states
     • Distinguish between easy/trivial reasoning steps (e.g. rearranging terms in an equation) and hard/crucial reasoning steps (e.g. applying an abstract theorem)
     • Task: find a path from $X_{\mathrm{in}}$ (problem statement) to $X_{\mathrm{out}}$ (conclusion or end state, e.g. QED)
  25. Our contributions
     [Figure: same reasoning-state diagram as on slide 24]
     • We introduce a perturbed Markov chain model for CoT that differentiates easy and hard reasoning steps via a dense-sparse structure, and we study its metastable dynamics
     • Search based on intrinsic reward improves the hitting times of target states by identifying key reasoning steps, whose generation can be enhanced by fine-tuning the base model with RL
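
To make the dense-sparse picture concrete, here is a toy sketch in NumPy; it is not the paper's model. Reasoning states form clusters with dense, uniform transitions inside each cluster (easy steps), and consecutive clusters are linked by a single rare edge with probability `eps` (a hard step). Expected hitting times solve $h(x) = 1 + \sum_y P(x, y)\, h(y)$ with $h(X_{\mathrm{out}}) = 0$. Cluster sizes, `eps`, and the layout are illustrative assumptions.

```python
import numpy as np

def clustered_chain(num_clusters=3, size=5, eps=1e-2):
    """Toy dense-sparse Markov chain over reasoning states.

    Within a cluster, moves are uniform over the cluster (easy steps).
    The last state of cluster c jumps to the first state of cluster c+1
    with small probability eps (a hard, rare step).
    """
    N = num_clusters * size
    P = np.zeros((N, N))
    for c in range(num_clusters):
        block = slice(c * size, (c + 1) * size)
        P[block, block] = 1.0 / size
        if c + 1 < num_clusters:
            exit_state = (c + 1) * size - 1
            P[exit_state] *= 1 - eps                 # rescale the easy moves
            P[exit_state, (c + 1) * size] += eps     # rare inter-cluster move
    return P

def hitting_times(P, target):
    """Expected steps to reach `target`: solve (I - Q) h = 1 on the other states."""
    N = P.shape[0]
    others = [i for i in range(N) if i != target]
    Q = P[np.ix_(others, others)]
    h = np.linalg.solve(np.eye(len(others)) - Q, np.ones(len(others)))
    out = np.zeros(N)
    out[others] = h
    return out

for eps in [1e-1, 1e-2, 1e-3]:
    P = clustered_chain(eps=eps)
    target = P.shape[0] - 1                          # X_out: a state in the last cluster
    print(eps, hitting_times(P, target)[0])          # starting from X_in = state 0
```

The hitting time from $X_{\mathrm{in}}$ grows roughly like $1/\varepsilon$ per hard step, so the chain spends most of its time trapped inside clusters; this is the metastable behavior that search for the key (sparse) steps is meant to shortcut.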
  26. Our contributions
     [Figure: same reasoning-state diagram as on slide 24]
     • A compressed version of the CoT dynamics can be distilled into a smaller model by learning only the macroscopic cluster transitions
     • In contrast, we show that local search fails to solve a computational version of the path-finding task, via a stricter SQ measure
  27. References i
     Busbridge, D., Shidani, A., Weers, F., Ramapuram, J., Littwin, E., and Webb, R. (2025). Distillation scaling laws. arXiv preprint arXiv:2502.08606.
     Feng, G., Zhang, B., Gu, Y., Ye, H., He, D., and Wang, L. (2023). Towards revealing the mystery behind chain of thought: a theoretical perspective. In Advances in Neural Information Processing Systems.
     Goyal, K., Dyer, C., and Berg-Kirkpatrick, T. (2017). Differentiable scheduled sampling for credit assignment. In Association for Computational Linguistics.
     Guo, D., Yang, D., Zhang, H., Song, J., Zhang, R., Xu, R., Zhu, Q., Ma, S., Wang, P., Bi, X., et al. (2025). DeepSeek-R1: incentivizing reasoning capability in LLMs via reinforcement learning. arXiv preprint arXiv:2501.12948.
     Jaech, A., Kalai, A., Lerer, A., Richardson, A., El-Kishky, A., Low, A., Helyar, A., Madry, A., Beutel, A., Carney, A., et al. (2024). OpenAI o1 system card. arXiv preprint arXiv:2412.16720.
     Kimi, T., Du, A., Gao, B., Xing, B., Jiang, C., Chen, C., Li, C., Xiao, C., Du, C., Liao, C., et al. (2025). Kimi k1.5: scaling reinforcement learning with LLMs. arXiv preprint arXiv:2501.12599.
     Li, H., Wang, M., Lu, S., Cui, X., and Chen, P.-Y. (2024a). How do nonlinear transformers acquire generalization-guaranteed CoT ability? In High-dimensional Learning Dynamics 2024: The Emergence of Structure and Reasoning.
     Li, Y., Sreenivasan, K., Giannou, A., Papailiopoulos, D., and Oymak, S. (2023). Dissecting chain-of-thought: compositionality through in-context filtering and learning. In Advances in Neural Information Processing Systems.
     Li, Z., Liu, H., Zhou, D., and Ma, T. (2024b). Chain of thought empowers transformers to solve inherently serial problems. In International Conference on Learning Representations.
  28. References ii
     Merrill, W. and Sabharwal, A. (2023). A logic for expressing log-precision transformers. In Advances in Neural Information Processing Systems.
     Merrill, W. and Sabharwal, A. (2024). The expressive power of transformers with chain of thought. In International Conference on Learning Representations.
     Mihaylova, T. and Martins, A. F. T. (2019). Scheduled sampling for transformers. In Association for Computational Linguistics: Student Research Workshop.
     Snell, C., Lee, J., Xu, K., and Kumar, A. (2024). Scaling LLM test-time compute optimally can be more effective than scaling model parameters. arXiv preprint arXiv:2408.03314.
     Wen, K., Zhang, H., Lin, H., and Zhang, J. (2025). From sparse dependence to sparse attention: unveiling how chain-of-thought enhances transformer sample efficiency. In International Conference on Learning Representations.
     Wu, Y., Sun, Z., Li, S., Welleck, S., and Yang, Y. (2024). Inference scaling laws: an empirical analysis of compute-optimal inference for problem-solving with language models. arXiv preprint arXiv:2408.00724.
     Zhang, D., Zhoubian, S., Hu, Z., Yue, Y., Dong, Y., and Tang, J. (2024). ReST-MCTS*: LLM self-training via process reward guided tree search. arXiv preprint arXiv:2406.03816.