- Training a neural network in the mean field regime
- Training a transformer for in-context learning
- Finetuning a generative model

Mean field NN
• Part 1: Convex
• Part 2: Non-convex
Part 1: Convex (propagation of chaos)
Mean-field Langevin dynamics ($F(\mu) + \lambda\,\mathrm{Ent}(\mu)$)
• Optimization of probability measures by WG-flow
• Particle approximation
• Defective log-Sobolev inequality
- Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.

Part 2: Non-convex (strict-saddle objective)
[Optimization of a probability measure on a strict-saddle objective]
1. Gaussian process perturbation: a polynomial-time method to avoid a saddle point by a "random perturbation" of a probability measure.
2. In-context learning: the objective to train a two-layer transformer is strict-saddle.
- Kim, Suzuki: Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML2024, oral.
- Yamamoto, Kim, Suzuki: Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025.
Mean-field Langevin dynamics:
• Nitanda, Wu, Suzuki (AISTATS2022); Chizat (TMLR2022)
• Super log-Sobolev inequality: Suzuki, Nitanda, Wu (ICLR2023)
• Convergence analysis with finite-particle / discrete-time algorithms: Suzuki, Nitanda, Wu (NeurIPS2023) (particle approximation)

Uniform-in-time propagation of chaos:
• Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. 2022. (arXiv:2212.03050)
• Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.
Mean field neural network [Nitanda&Suzuki, 2017][Chizat&Bach, 2018][Mei, Montanari&Nguyen, 2018][Rotskoff&Vanden-Eijnden, 2018]:
• Mean field limit: $f_\mu(x) = \int h_\theta(x)\,\mu(\mathrm{d}\theta)$, which is linear with respect to $\mu$.
• Hence the loss $F(\mu)$ is convex w.r.t. $\mu$ on the probability measure space, and $F(\mu) + \lambda\,\mathrm{Ent}(\mu)$ is strictly convex (convex + strictly convex = strictly convex).
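As a concrete picture of the particle approximation behind Part 1, here is a minimal numpy sketch of the time-discretized mean-field Langevin dynamics for a toy two-layer mean-field network $f_\mu(x) = \mathbb{E}_{\theta\sim\mu}[\tanh(\theta^\top x)]$. The data, network width, step size, and regularization strength below are illustrative assumptions, not the settings of the cited papers.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Toy setup (illustrative assumptions, not the papers' experiments) ---
n, d, N = 200, 5, 512                      # samples, input dim, number of particles
X = rng.standard_normal((n, d))
y = np.tanh(X @ rng.standard_normal(d))    # teacher signal

theta = rng.standard_normal((N, d))        # particles = neuron parameters
lam, eta, steps = 1e-3, 0.2, 1000          # entropy weight lambda, step size, iterations

def f_mu(theta, X):
    # Mean-field network: f_mu(x) = E_{theta~mu}[tanh(theta^T x)], empirical average over N particles.
    return np.tanh(X @ theta.T).mean(axis=1)

for t in range(steps):
    act = np.tanh(X @ theta.T)             # (n, N) activations
    resid = act.mean(axis=1) - y           # (n,) residuals f_mu(x_i) - y_i
    # Gradient of the first variation: (1/n) sum_i resid_i * (1 - tanh^2) * x_i  -> shape (N, d)
    grad = ((resid[:, None] * (1.0 - act**2)).T @ X) / n
    # Noisy particle update = time-discretized mean-field Langevin dynamics.
    theta += -eta * grad + np.sqrt(2.0 * eta * lam) * rng.standard_normal(theta.shape)

print("train MSE:", np.mean((f_mu(theta, X) - y) ** 2))
```

Each particle follows the gradient of the first variation of $F$ plus Gaussian noise of scale $\sqrt{2\eta\lambda}$; the noise implements the entropy term, and the interaction between particles enters only through the shared residual.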
Prop (existing result) [Chen, Ren, Wang, 22][Suzuki, Wu, Nitanda, 23]
Assumption: (smoothness) $\|\nabla\tfrac{\delta F}{\delta\mu}(\mu)(x) - \nabla\tfrac{\delta F}{\delta\mu}(\nu)(y)\| \le C\,(W_2(\mu,\nu) + \|x - y\|)$ and (boundedness) $\|\nabla\tfrac{\delta F}{\delta\mu}(\mu)(x)\| \le R$. Suppose that the proximal Gibbs measure $p_\mu$ satisfies a log-Sobolev inequality with constant $\alpha$.
Under these conditions on the loss function, a uniform-in-time propagation-of-chaos bound holds for $\mu^{(N)}$ (the joint distribution of the $N$ particles). Uniform in time!
• However, the LSI constant $\alpha$ can be exponentially small, e.g., $\alpha \simeq \exp(-O(R))$.
[Suzuki, Wu, Nitanda: Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction. arXiv:2306.07221]
[Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. arXiv:2212.03050, 2022.]
Theorem (our result) — new bound [Nitanda, Lee, Kai, Sakaguchi, Suzuki, ICML2025]:
Under the smoothness condition, the new bound holds for any $N$, and the required number of particles $N$ is independent of $\alpha \asymp \exp(-O(R))$.
Existing bound: $\tfrac{1}{\lambda^2\alpha N}$ [Chen, Ren, Wang, 2022].
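To fix the picture, the display below is a schematic of the typical shape of such uniform-in-time bounds (an optimization term plus a particle-approximation term). The exact constants and exponents in the cited theorems differ, so this should be read as an illustration of the structure being compared, not as either statement.

```latex
% Schematic structure of a uniform-in-time propagation-of-chaos bound
% (illustrative only; not the exact statement of the cited theorems).
\[
  \frac{1}{N} F^{N}\!\big(\mu^{(N)}_t\big) - F(\mu^{*})
  \;\lesssim\;
  \underbrace{e^{-2\alpha\lambda t}\,\Delta_0}_{\text{optimization error (uniform in time)}}
  \;+\;
  \underbrace{\frac{C}{\lambda^{2}\alpha N}}_{\text{particle-approximation error}},
  \qquad \Delta_0 := \text{initial objective gap}.
\]
% The second term blows up when the LSI constant alpha is exponentially small;
% the new bound removes this alpha-dependence from the particle-approximation term.
```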
Wasserstein gradient flow: non-convex case (no entropy)
• The Wasserstein GF converges to a critical point.
• It can be stuck at a saddle point. [Second-order optimality]
• How to escape the saddle point?
Random perturbation
• If this is an $(\epsilon, \delta)$-stationary point, we may finish.
• If not, how to escape the saddle point?
Finite-dimensional case:
- Move in the minimum-eigenvalue direction of the Hessian. [Agarwal et al. 2016; Carmon et al. 2016; Nesterov&Polyak 2006]
- Random perturbation (sketched below). [Jin et al. 2017; Li 2019]
How to perturb probability measures? (infinite-dimensional objective)
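For the finite-dimensional case, the random-perturbation recipe in the spirit of [Jin et al. 2017] can be sketched in a few lines. The quadratic saddle objective, thresholds, and step sizes below are illustrative choices, not the theoretically prescribed values.

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative strict-saddle objective: f(x) = 0.5*(x1^2 - x2^2); the origin is a saddle.
def grad(x):
    return np.array([x[0], -x[1]])

eta, eps, radius, steps = 0.1, 1e-3, 1e-2, 2000   # hand-picked, not the theoretical values
x = np.array([1.0, 0.0])                          # start on the stable manifold of the saddle

for t in range(steps):
    g = grad(x)
    if np.linalg.norm(g) < eps:
        # Small gradient: possibly a saddle.  Add a small random perturbation;
        # any negative-curvature component is then exponentially amplified by GD.
        x = x + radius * rng.standard_normal(x.shape)
        g = grad(x)
    x = x - eta * g
    if abs(x[1]) > 1.0:
        # Left the neighborhood of the saddle at the origin: escape succeeded.
        break

print(f"escaped after {t} steps, iterate = {x}")
```

Plain gradient descent from this initialization converges to the saddle; the perturbation injects a component along the negative-curvature direction, which the dynamics then amplifies.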
Gaussian process perturbation: draw a Gaussian process vector field $v \sim \mathcal{GP}(0, K_\mu)$, where the covariance $K_\mu$ is built from the Hessian. Then, perturb the distribution as $\mu \leftarrow (\mathrm{Id} + \eta_p v)_{\#}\mu$, where $\eta_p > 0$ is a small step size. This induces a small random perturbation of the distribution.
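At the particle level, the perturbation can be pictured as follows: represent $\mu$ by particles, sample a Gaussian vector field jointly at the particle locations, and push every particle along it. In the sketch below, the RBF kernel is only a generic stand-in for the Hessian-guided covariance $K_\mu$ of the actual method.

```python
import numpy as np

rng = np.random.default_rng(2)

# Particle representation of the measure mu (illustrative 2-D particles).
N, d, eta_p = 300, 2, 0.05
particles = rng.standard_normal((N, d))

def rbf_kernel(A, B, lengthscale=1.0):
    # Generic RBF kernel -- a stand-in for the Hessian-guided covariance K_mu.
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale**2)

# Sample a Gaussian-process vector field v ~ GP(0, K) at the particle locations:
# one jointly Gaussian draw per output coordinate, correlated across particles.
K = rbf_kernel(particles, particles) + 1e-6 * np.eye(N)
L = np.linalg.cholesky(K)
v = L @ rng.standard_normal((N, d))        # v[i] = value of the vector field at particle i

# Pushforward mu <- (Id + eta_p * v)_# mu : move every particle along the sampled field.
particles_perturbed = particles + eta_p * v

print("mean displacement:", np.linalg.norm(eta_p * v, axis=1).mean())
```

Because the field is sampled jointly, nearby particles move coherently, so the perturbed point cloud is still a small Wasserstein perturbation of the original measure rather than independent per-particle noise.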
Suppose $\mu^+$ is an $(\epsilon, \delta)$-saddle point, i.e., $\lambda_0 := \lambda_{\min}(H_{\mu^+}) < -\delta$ and $\epsilon = O(\delta^2)$. Then, for the GP-perturbation $v \sim \mathcal{GP}(0, K_{\mu^+})$, taking $\mu_0 = (\mathrm{Id} + \eta_p v)_{\#}\mu^+$ as the initial point of the W-GF, the flow escapes the saddle point in polynomial time with probability $1 - p$.
[Proof overview]
• The GP perturbation $v$ has a component in the negative-eigenvalue direction with positive probability.
• The negative-curvature direction is exponentially amplified. ⇒ Escape from the saddle.
The Wasserstein GF $\mu_t$ around a critical point $\mu^+$ can be written as $\mu_t = (\mathrm{id} + \eta v_t)_{\#}\mu^+$, where the velocity field $v_t$ follows the linearized dynamics $\partial_t v_t = -H_{\mu^+} v_t$ to leading order (cf. Otto calculus). The negative-curvature direction grows exponentially if the initial $v_0$ contains a component toward the minimal-eigenvalue direction. The Gaussian process perturbation ensures such a negative-curvature component.
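In the finite-dimensional analogue, and assuming the linearized dynamics above, the amplification argument is the following back-of-the-envelope computation.

```latex
% Heuristic, finite-dimensional analogue of the amplification argument.
% Linearize the flow around the critical point: \dot v_t = -H_{\mu^+} v_t.
% With eigenpairs (\lambda_i, e_i) of H_{\mu^+}, the solution is
\[
  v_t \;=\; \sum_i e^{-\lambda_i t}\,\langle v_0, e_i\rangle\, e_i .
\]
% If \lambda_{\min}(H_{\mu^+}) < -\delta and the GP perturbation supplies
% \langle v_0, e_{\min}\rangle \neq 0 (which happens with positive probability), then
\[
  \|v_t\| \;\ge\; e^{\delta t}\,\bigl|\langle v_0, e_{\min}\rangle\bigr| ,
\]
% i.e. the negative-curvature component grows exponentially and the flow leaves the saddle.
```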
Def ($(\epsilon, \delta, \alpha)$-strict saddle): $F$ is $(\epsilon, \delta, \alpha)$-strict saddle if, for any $\mu \in \mathcal{P}_2(\mathbb{R}^d)$, one of the following conditions holds:
(1) the Wasserstein gradient norm is at least $\epsilon$;
(2) the Hessian $H_\mu$ has an eigenvalue at most $-\delta$;
(3) $\mu$ is within $\alpha$ of a local optimum.
Examples:
• In-context learning (Kim&Suzuki, 2024)
• Matrix decomposition (recommendation system) — see the toy check below.
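To see why matrix decomposition is an example, the toy check below uses a hypothetical rank-1 factorization objective $f(u) = \tfrac14\|uu^\top - M\|_F^2$ (an illustration, not the paper's exact formulation) and numerically confirms that the origin is a critical point with a strictly negative Hessian eigenvalue, i.e., condition (2) above.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical rank-1 matrix-decomposition objective (illustrative, not the paper's exact setup):
#   f(u) = 1/4 * || u u^T - M ||_F^2  with M = b b^T symmetric PSD of rank 1.
d = 6
b = rng.standard_normal(d)
M = np.outer(b, b)

def f(u):
    return 0.25 * np.linalg.norm(np.outer(u, u) - M) ** 2

def numerical_hessian(func, u, h=1e-3):
    # Central finite-difference Hessian; accurate enough to expose the eigenvalue signs.
    n = len(u)
    H = np.zeros((n, n))
    I = np.eye(n)
    for i in range(n):
        for j in range(n):
            H[i, j] = (func(u + h*I[i] + h*I[j]) - func(u + h*I[i] - h*I[j])
                       - func(u - h*I[i] + h*I[j]) + func(u - h*I[i] - h*I[j])) / (4 * h**2)
    return H

u0 = np.zeros(d)                      # the origin is a critical point (the gradient vanishes)
eigs = np.linalg.eigvalsh(numerical_hessian(f, u0))
print("Hessian eigenvalues at u = 0:", np.round(eigs, 4))
# One eigenvalue is about -||b||^2 < 0: the origin is a saddle with strictly negative curvature.
```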
$d = 20$, $k = 5$, and 500 neurons with sigmoid activation. All models are pre-trained using SGD on 10K prompts of 1K token pairs.
1. attention: jointly optimizes $\mathcal{L}(\mu, \Gamma)$.
2. static: directly minimizes $\mathcal{L}(\mu)$.
3. modified: static model implementing birth-death & GP perturbation.
⇒ Verify global convergence as well as improvement for a misaligned model ($k_{\mathrm{true}} = 7$) and nonlinear test tasks $f(x) = \max_{j \le k} (\beta^\top x)_j$ or $f(x) = \|\beta^\top x\|^2$.