

Jia-Jie Zhu
August 18, 2025

Convergence theory and application of distribution optimization: Non-convexity, particle approximation, and diffusionΒ models

Taiji Suzuki

ICSP 2025 invited session


Transcript

  1. Convergence theory and application of distribution optimization: Non-convexity, particle approximation, and diffusion models. Taiji Suzuki (The University of Tokyo / AIP-RIKEN, Deep learning theory team). 28th July 2025, ICSP2025@Paris.
  2. Probability measure optimization: convex / non-convex.
     Applications:
     - Training a neural network in the mean-field regime
     - Training a transformer for in-context learning
     - Fine-tuning a generative model
     [Figure: mean-field NN]
     • Part 1: Convex
     • Part 2: Non-convex
  3. Presentation overview.
     Part 1: Convex (propagation of chaos). Mean-field Langevin dynamics (F(μ) + λ Ent(μ)):
     • Optimization of probability measures by WG-flow
     • Particle approximation
     • Defective log-Sobolev inequality
     ➢ [Propagation of chaos] Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.
     Part 2: Non-convex (strict-saddle objective). Optimization of a probability measure on a strict-saddle objective:
     1. Gaussian process perturbation: a polynomial-time method to avoid a saddle point by a "random perturbation" of a probability measure.
     2. In-context learning: the objective for training a two-layer transformer is strict-saddle.
     ➢ Kim, Suzuki: Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML2024, oral.
     ➢ Yamamoto, Kim, Suzuki: Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025.
  4. Mean field Langevin.
     • Linear convergence of mean field Langevin dynamics
       ➢ Nitanda, Wu, Suzuki (AISTATS2022); Chizat (TMLR2022)
     • Uniform-in-time propagation of chaos:
       ➢ Super log-Sobolev inequality: Suzuki, Nitanda, Wu (ICLR2023)
       ➢ Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. 2022. (arXiv:2212.03050)
       ➢ Convergence analysis with finite-particle / discrete-time algorithms (particle approximation): Suzuki, Nitanda, Wu (NeurIPS2023)
       ➢ Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.
  5. Distribution optimization.
     Objective: convex optimization on the probability measure space, i.e. minimization of a nonlinear convex functional F(μ) over probability measures μ (convex + strictly convex regularizer = strictly convex).
     Mean field neural network: the network output is linear with respect to μ, so a convex loss of the output is convex w.r.t. μ (a minimal numerical sketch follows this item).
     Mean field limit: M → ∞.
     [Nitanda&Suzuki, 2017][Chizat&Bach, 2018][Mei, Montanari&Nguyen, 2018][Rotskoff&Vanden-Eijnden, 2018]
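     For concreteness, here is a minimal numerical sketch of the mean-field neural network viewpoint: a two-layer network is the average of M neurons, i.e. the integral of a single neuron h(x; θ) against the empirical measure over parameters, so the output is linear in that measure. The tanh neuron, the data, and all names below are illustrative placeholders, not the talk's specific model.

        import numpy as np

        def neuron(x, theta):
            # h(x; theta): one neuron, tanh(w^T x + b) with theta = (w, b).
            w, b = theta[:-1], theta[-1]
            return np.tanh(x @ w + b)

        def mean_field_output(x, thetas):
            # f_{mu_hat}(x) = (1/M) sum_i h(x; theta_i)
            #               = integral of h(x; .) d(mu_hat): linear in the measure mu_hat.
            return np.mean([neuron(x, th) for th in thetas], axis=0)

        rng = np.random.default_rng(0)
        d, M, n = 5, 200, 32
        X, y = rng.standard_normal((n, d)), rng.standard_normal(n)
        thetas = rng.standard_normal((M, d + 1))        # M particles representing mu_hat
        loss = 0.5 * np.mean((mean_field_output(X, thetas) - y) ** 2)   # F(mu_hat)
        print(loss)

     Because the output is linear in the measure, any convex loss of the output is a convex functional of the measure; this is the convexity exploited in Part 1.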
  6. Mean field Langevin.
     Objective: F(μ) + λ Ent(μ) (convex + strictly convex = strictly convex).
     Def (first variation of F): δF/δμ.
     Mean field Langevin dynamics: a distribution-dependent dynamics driven by the gradient ∇(δF/δμ)(μ_t) (the standard formulas are restated below this item).
     [Hu et al. 2019][Nitanda, Wu, Suzuki, 2022][Chizat, 2022]
     [Nitanda&Suzuki, 2017][Chizat&Bach, 2018][Mei, Montanari&Nguyen, 2018][Rotskoff&Vanden-Eijnden, 2018]
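     The displayed formulas on this slide do not survive in the transcript; the following is the standard way to write them, consistent with the cited papers (δF/δμ denotes the first variation and λ the entropy regularization strength):

        \min_{\mu \in \mathcal{P}_2(\mathbb{R}^d)} \ \mathcal{L}(\mu) := F(\mu) + \lambda\, \mathrm{Ent}(\mu)

        \text{First variation: } \int \frac{\delta F}{\delta \mu}(\mu)(x)\, d(\nu - \mu)(x)
            = \lim_{\epsilon \to 0} \frac{F(\mu + \epsilon(\nu - \mu)) - F(\mu)}{\epsilon}

        \text{Mean-field Langevin dynamics: }
            dX_t = -\nabla \frac{\delta F}{\delta \mu}(\mu_t)(X_t)\, dt + \sqrt{2\lambda}\, dB_t,
            \qquad \mu_t = \mathrm{Law}(X_t)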
  7. Entropy sandwich.
     Proximal Gibbs measure p_μ.
     Theorem (Entropy sandwich) [Nitanda, Wu, Suzuki (AISTATS2022)][Chizat (2022)].
     Assumption: log-Sobolev inequality of p_{μ_t} along the dynamics (a commonly stated form of these objects follows this item).
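     The definitions and the theorem are images in the original deck; a commonly stated form (cf. Nitanda, Wu, Suzuki 2022; Chizat 2022), which the constants in the talk may refine, is:

        p_{\mu}(x) \ \propto\ \exp\!\Big(-\tfrac{1}{\lambda}\, \frac{\delta F}{\delta \mu}(\mu)(x)\Big)
            \qquad \text{(proximal Gibbs measure)}

        \lambda\, \mathrm{KL}(\mu \,\|\, \mu^*) \ \le\ \mathcal{L}(\mu) - \mathcal{L}(\mu^*)
            \ \le\ \lambda\, \mathrm{KL}(\mu \,\|\, p_{\mu})
            \qquad \text{(entropy sandwich)}

        \text{If } p_{\mu_t} \text{ satisfies an LSI with constant } \alpha \text{ uniformly in } t:\quad
            \mathcal{L}(\mu_t) - \mathcal{L}(\mu^*) \ \le\ e^{-2\alpha\lambda t}\,
            \big(\mathcal{L}(\mu_0) - \mathcal{L}(\mu^*)\big)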
  8. Log-Sobolev inequality.
     Theorem (Bakry-Émery criterion) [Bakry and Émery, 1985]: if L(x) is μ-strongly convex, then p ∝ exp(−L) satisfies an LSI (with constant given by the strong convexity).
     Theorem (Holley-Stroock bounded perturbation lemma) [Holley and Stroock, 1987]: if p satisfies an LSI with constant α and ‖h‖_∞ ≤ B, then the bounded perturbation q ∝ e^{−h} p satisfies an LSI with some constant α′.
     Issue: α can easily be exp(−O(d)) (e.g., a Gaussian mixture).
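     For reference, the standard statements behind this slide, in one common normalization (the exact constants on the slide are not recoverable from the transcript):

        \text{LSI with constant } \alpha:\quad
            \mathrm{Ent}_p(f^2) \ \le\ \frac{2}{\alpha}\, \mathbb{E}_p\big[\|\nabla f\|^2\big]
            \quad \text{for all smooth } f

        \text{Bakry-Émery: } p \propto e^{-L},\ \nabla^2 L \succeq c\, I \ (c > 0)
            \ \Rightarrow\ p \text{ satisfies LSI with } \alpha = c

        \text{Holley-Stroock: } p \text{ satisfies LSI with } \alpha,\ q \propto e^{-h} p,\ \|h\|_\infty \le B
            \ \Rightarrow\ q \text{ satisfies LSI with } \alpha' \ge \alpha\, e^{-2B}

     The e^{-2B} factor is why α easily degrades to exp(−O(d)) for multi-modal targets such as Gaussian mixtures.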
  9. Practical algorithm.
     • Space discretization: μ_t is approximated by M particles {X_t^(i)}_{i=1}^M, i.e. μ_t ≃ μ̂_t = (1/M) Σ_i δ_{X_t^(i)} (a runnable particle sketch follows this item).
     • Naïve application of Grönwall's inequality yields Error = Ω(exp(t)/M) (not uniform in time).
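     A minimal particle sketch of the space-discretized mean-field Langevin dynamics, under illustrative assumptions: F(μ) is taken to be the squared loss of the mean-field network from the earlier sketch, so ∇(δF/δμ)(μ̂)(θ) can be written in closed form; the data, step size η, and temperature λ are placeholders.

        import numpy as np

        rng = np.random.default_rng(0)
        d, M, n = 5, 200, 64
        lam, eta, steps = 0.01, 0.05, 500

        X = rng.standard_normal((n, d))
        y = np.sin(X[:, 0])                        # toy regression target (placeholder)
        thetas = rng.standard_normal((M, d + 1))   # particles X_t^(i) representing mu_hat_t

        def neuron_outputs(thetas):
            # h(x_j; theta_i) for all particles and data points, shape (M, n)
            return np.tanh(thetas[:, :d] @ X.T + thetas[:, d:])

        for _ in range(steps):
            H = neuron_outputs(thetas)
            resid = H.mean(axis=0) - y             # f_{mu_hat}(x_j) - y_j
            dh = 1.0 - H ** 2                      # tanh'
            # gradient of the first variation: nabla_theta (dF/dmu)(mu_hat)(theta_i)
            gw = (dh * resid) @ X / n
            gb = (dh * resid).sum(axis=1, keepdims=True) / n
            grad = np.hstack([gw, gb])
            # noisy particle update = discretized mean-field Langevin dynamics
            thetas += -eta * grad + np.sqrt(2 * eta * lam) * rng.standard_normal(thetas.shape)

        print("final loss:", 0.5 * np.mean((neuron_outputs(thetas).mean(axis=0) - y) ** 2))

     The uniform-in-time propagation-of-chaos results on the next slides control how far this M-particle system stays from the ideal mean-field law μ_t, with an error that does not blow up with t.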
  10. Propagation of chaos (space discretization).
     Assumption: (smoothness) ‖∇(δL/δμ)(μ)(x) − ∇(δL/δμ)(ν)(y)‖ ≤ C(W_2(μ, ν) + ‖x − y‖), and (boundedness) ‖∇(δL/δμ)(μ)(x)‖ ≤ R.
     Prop [Chen, Ren, Wang, 22][Suzuki, Wu, Nitanda, 23] (existing result): Suppose that p_μ satisfies a log-Sobolev inequality with a constant α. Under the smoothness and boundedness of the loss function above, the finite-particle approximation error is uniform in time! (μ^(M): a joint distribution of the M particles.)
     ➢ However, α can be like exp(−O(d)).
     [Suzuki, Wu, Nitanda: Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction. arXiv:2306.07221]
     [Chen, Ren, Wang. Uniform-in-time propagation of chaos for mean field Langevin dynamics. arXiv:2212.03050, 2022.]
  11. Defective LSI and entropy sandwich.
     Theorem (Defective entropy sandwich) [Nitanda, Lee, Kai, Sakaguchi, Suzuki (ICML2025)].
     Ingredients annotated on the display: an LSI constant of the conditional distribution (no LSI constant of p_μ is needed), a Bregman-divergence term, and an O(1/M²) term.
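     The theorem itself is displayed as an image; for orientation, a defective log-Sobolev inequality in its generic form (not the paper's exact statement) is an LSI with an additive defect term,

        \mathrm{Ent}_p(f^2) \ \le\ \frac{2}{\alpha}\, \mathbb{E}_p\big[\|\nabla f\|^2\big]
            \ +\ \varepsilon\, \mathbb{E}_p\big[f^2\big],

     and the labels on the slide suggest that the defect-type terms (Bregman divergence, O(1/M²)) are controlled without invoking a global LSI constant of p_μ, only an LSI of a conditional distribution.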
  12. Defective LSI.
     Theorem (our result) [Nitanda, Lee, Kai, Sakaguchi, Suzuki, ICML2025]: under some smoothness condition, the new bound (stated via the Fisher divergence) holds for any M.
     New bound: the number of particles M is independent of α ≃ exp(−Ω(d)).
     Existing bound: 1/(λ²αM) [Chen, Ren, Wang, 2022].
  13. Escaping from a saddle point. [Yamamoto, Kim, Suzuki: Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025]
  14. Non-convex objective (no entropy term).
     Wasserstein gradient flow: continuous-time and discrete-time dynamics (standard forms are restated below this item). Convergence?
     • The Wasserstein GF converges to a critical point.
     • [Second-order optimality] It can get stuck at a saddle point.
     • How to escape the saddle point?
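     The dynamics on this slide are shown only as images; the standard forms, in the notation used elsewhere in the talk, are:

        \partial_t \mu_t \ =\ \nabla \!\cdot\! \Big(\mu_t\, \nabla \frac{\delta F}{\delta \mu}(\mu_t)\Big)
            \qquad \text{(Wasserstein gradient flow, continuous time)}

        \frac{dX_t}{dt} \ =\ -\,\nabla \frac{\delta F}{\delta \mu}(\mu_t)(X_t), \qquad \mu_t = \mathrm{Law}(X_t)

        X_{k+1}^{(i)} \ =\ X_k^{(i)} - \eta\, \nabla \frac{\delta F}{\delta \mu}(\hat\mu_k)\big(X_k^{(i)}\big),
            \qquad \hat\mu_k = \tfrac{1}{M} \textstyle\sum_{i} \delta_{X_k^{(i)}}
            \qquad \text{(discrete time, particles)}

     Without the entropy term there is no Gibbs-type stationary measure to exploit, which is why only convergence to a critical point, not a global optimum, is available in general.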
  15. 2nd-order stationary point / saddle point.
     Second-order derivative: H̃_μ. Assumption: ‖H̃_μ‖_∞ ≤ R.
     Def ((ε, δ)-second-order stationary point): stated in terms of ‖∇(δF/δμ)‖_{L²(μ)} and the spectrum of the second-order derivative.
     Def ((ε, δ)-saddle point).
     (Note that, when ∇(δF/δμ) = 0, then H̃_μ = 0.)
  16. Escape from saddle.
     • The W-GF converges to a critical point:
       ➢ If this is an (ε, δ)-second-order stationary point, we may finish.
       ➢ If not, how to escape the saddle point?
     Finite-dimensional case:
     - Move along the minimum-eigenvalue direction of the Hessian. [Agarwal et al. 2016; Carmon et al. 2016; Nesterov&Polyak 2006]
     - Random perturbation. [Jin et al. 2017; Li 2019]
     How to perturb probability measures? (infinite-dimensional objective)
  17. Gaussian process perturbation.
     Let the "kernel" function be built from the Hessian. Generate a Gaussian process vector field ξ with this kernel, and perturb the distribution as the pushforward (Id + η_p ξ)_# μ (η_p > 0 is a small step size). This induces a small random perturbation of the distribution (a particle-level sketch follows this item).
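     A minimal particle-level sketch of the perturbation step, under illustrative assumptions: the transcript does not preserve the kernel defined on this slide, so a plain RBF kernel stands in for the Hessian-guided kernel K_μ of the paper; the perturbation itself is the pushforward X_i ↦ X_i + η_p ξ(X_i) applied to the particles.

        import numpy as np

        def gp_perturbation(particles, eta_p=0.05, lengthscale=1.0, seed=None):
            """Sample a GP vector field xi ~ G(0, K) at the particle locations and
            push the particles forward: X_i -> X_i + eta_p * xi(X_i).
            NOTE: the RBF kernel below is a placeholder; the method builds K from
            the Hessian at the current measure."""
            rng = np.random.default_rng(seed)
            M, d = particles.shape
            sq = ((particles[:, None, :] - particles[None, :, :]) ** 2).sum(-1)
            K = np.exp(-sq / (2 * lengthscale ** 2)) + 1e-8 * np.eye(M)   # Gram matrix
            L = np.linalg.cholesky(K)
            xi = L @ rng.standard_normal((M, d))    # one GP sample per output coordinate
            return particles + eta_p * xi

        # Usage: perturb the particle approximation of mu_+ at a stationary point.
        mu_plus = np.random.default_rng(0).standard_normal((200, 5))
        mu_0 = gp_perturbation(mu_plus, eta_p=0.05, seed=1)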
  18. Escape from the saddle point.
     Proposition: Suppose that μ_+ is an (ε, δ)-saddle point, where λ_0 ≔ λ_min(H_{μ+}) < −δ and ε = O(δ²). Then, for the GP perturbation ξ ∼ G(0, K_{μ+}), let μ_0 = (Id + η_p ξ)_# μ_+ be the initial point of the W-GF; then, with probability 1 − ζ, the flow escapes from the saddle.
     [Proof overview]
     • The GP perturbation ξ has a component along the negative-eigenvalue direction with positive probability.
     • The negative-curvature direction is exponentially amplified.
     ⇒ Escape from the saddle.
  19. Algorithm. At a first-order stationary point, apply the Gaussian process perturbation and check whether it was a saddle point. If not, halt the algorithm. (A schematic loop follows this item.)
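     A schematic of this procedure (perturbed Wasserstein gradient flow), assuming problem-specific helpers wgf_step, grad_norm, and objective and the gp_perturbation sketch above; all thresholds are illustrative.

        def perturbed_wgf(particles, wgf_step, grad_norm, objective,
                          eps=1e-3, drop=1e-4, escape_steps=200, max_iter=10_000):
            """Run WGF; at a first-order stationary point apply the GP perturbation.
            If the objective does not drop afterwards, treat the point as
            second-order stationary and halt (schematic of the slide's algorithm)."""
            for _ in range(max_iter):
                particles = wgf_step(particles)
                if grad_norm(particles) > eps:
                    continue                        # not yet first-order stationary
                f_before = objective(particles)
                trial = gp_perturbation(particles)  # random perturbation of the measure
                for _ in range(escape_steps):       # let WGF amplify negative curvature
                    trial = wgf_step(trial)
                if objective(trial) < f_before - drop:
                    particles = trial               # it was a saddle point: escaped
                else:
                    return particles                # no escape: halt
            return particles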
  20. Proof outline.
     Let the Hessian at μ be H_μ (cf. Otto calculus).
     Lemma: The Wasserstein GF μ_t around a critical point μ_+ can be written as (id + ε v_t)_# μ_+, where the velocity field v_t follows a linearized evolution governed by the Hessian.
     The negative-curvature direction grows exponentially if the initial perturbation v_0 contains a component toward the minimal-eigenvalue direction. The Gaussian process perturbation ensures such a negative-curvature component.
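     The lemma's displayed equation is not preserved in the transcript; one reading consistent with the eigen-decomposition on the next slide is that, to first order, the velocity field evolves linearly under the Hessian at the critical point, so that

        \partial_t v_t \ \approx\ -\,H_{\mu_+}[\, v_t \,],
            \qquad
            v_t \ =\ \sum_j e^{-\lambda_j t}\, \langle v_0, e_j\rangle_{L^2(\mu_+)}\, e_j ,

     and any component of v_0 along an eigenfunction with eigenvalue λ_0 < −δ is amplified like e^{|λ_0| t}.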
  21. (Proof outline, continued.)
     KL (Karhunen-Loève) decomposition of the perturbation: expand it in an orthonormal system of L²(μ) with independent coefficients Z_j ∼ N(0,1).
     Let a "shifted" initial point be a coupled copy of the perturbation. Then the two trajectories diverge, and one of them must leave the neighborhood of the saddle point.
  22. Global optimality for strict saddle.
     Def ((ε, δ, α)-strict saddle): F: 𝒫₂(ℝ^d) → ℝ is (ε, δ, α)-strict saddle if, for any μ ∈ 𝒫₂(ℝ^d), one of the conditions (1), (2), (3) holds.
     Examples:
     • In-context learning (Kim&Suzuki, 2024)
     • Matrix decomposition (recommendation systems)
  23. Global optimality for strict saddle.
     Def ((ε, δ, α)-strict saddle): F: 𝒫₂(ℝ^d) → ℝ is (ε, δ, α)-strict saddle if, for any μ ∈ 𝒫₂(ℝ^d), one of the conditions (1), (2), (3) holds (a plausible reading of the three conditions follows this item).
     Theorem: Suppose that F: 𝒫₂(ℝ^d) → ℝ is (ε, δ, α)-strict saddle. Then, after T = Õ(1/ε² + 1/δ⁴) time, the solution achieves W₂(μ_T, μ*) ≤ α for a global optimum μ*.
  24. Numerical experiment (in-context learning).
     We compare 3 models with d = 20, k = 5, and 500 neurons with sigmoid activation. All models are pre-trained using SGD on 10K prompts of 1K token pairs.
     1. attention: jointly optimizes ℒ(μ, Γ).
     2. static: directly minimizes ℒ(μ).
     3. modified: static model implementing birth-death & GP perturbation.
     → Verify global convergence, as well as improvement for a misaligned model (k_true = 7) and nonlinear test tasks g(x) = max_{j≤k} h_{μ∘}(x)_j or g(x) = h_{μ∘}(x)².
  25. Presentation overview (recap).
     Part 1: Convex (propagation of chaos). Mean-field Langevin dynamics (F(μ) + λ Ent(μ)):
     • Optimization of probability measures by WG-flow
     • Particle approximation
     • Defective log-Sobolev inequality
     ➢ [Propagation of chaos] Nitanda, Lee, Kai, Sakaguchi, Suzuki: Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble. ICML2025.
     Part 2: Non-convex (strict-saddle objective). Optimization of a probability measure on a strict-saddle objective:
     1. Gaussian process perturbation: a polynomial-time method to avoid a saddle point by a "random perturbation" of a probability measure.
     2. In-context learning: the objective for training a two-layer transformer is strict-saddle.
     ➢ Kim, Suzuki: Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML2024, oral.
     ➢ Yamamoto, Kim, Suzuki: Hessian-guided Perturbed Wasserstein Gradient Flows for Escaping Saddle Points. 2025.