Anna Korba (ENSAE & CREST, Palaiseau, France) Sampling through Optimization of Discrepancies

WORKSHOP ON OPTIMAL TRANSPORT
FROM THEORY TO APPLICATIONS
INTERFACING DYNAMICAL SYSTEMS, OPTIMIZATION, AND MACHINE LEARNING
Venue: Humboldt University of Berlin, Dorotheenstraße 24

Berlin, Germany. March 11th - 15th, 2024

Transcript

  1. Sampling through Optimization of Divergences. Anna Korba, ENSAE, CREST, Institut Polytechnique de Paris. OT Berlin 2024. Joint work with many people cited on the flow.
  2. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  3. Why sampling? Suppose you are interested in some target probability distribution on R^d, denoted µ*, and you have access only to partial information, e.g.: 1 its unnormalized density (as in Bayesian inference); 2 a discrete approximation (1/m) Σ_{k=1}^m δ_{x_k} ≈ µ* (e.g. i.i.d. samples, iterates of MCMC algorithms...). Problem: approximate µ* ∈ P(R^d) by a finite set of n points x_1, ..., x_n, e.g. to compute functionals ∫_{R^d} f(x) dµ*(x). The quality of the set can be measured by the integral error: |(1/n) Σ_{i=1}^n f(x_i) − ∫_{R^d} f(x) dµ*(x)|. [Figure: a Gaussian density; i.i.d. samples; particle scheme (SVGD).]
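
To make the integral error concrete, here is a minimal numerical sketch (not from the deck): f(x) = ∥x∥² under a standard Gaussian target on R², whose exact integral is d = 2, estimated with n i.i.d. samples.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 2, 1000
f = lambda x: (x ** 2).sum(axis=-1)      # test function f(x) = ||x||^2

# n i.i.d. samples x_1, ..., x_n from the target mu* = N(0, I_d)
X = rng.standard_normal((n, d))

# integral error | (1/n) sum_i f(x_i) - \int f dmu* |, with \int f dmu* = d here
print(abs(f(X).mean() - d))
```
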
  4. Example 1: Bayesian inference. We want to sample from the posterior distribution µ*(x) ∝ exp(−V(x)), with V(x) = Σ_{i=1}^m ∥y_i − g(w_i, x)∥² (loss on labeled data (w_i, y_i)_{i=1}^m) + ∥x∥²_2. Ensemble prediction for a new input w: ŷ = ∫_{R^d} g(w, x) dµ*(x) ("Bayesian model averaging"). Predictions of models parametrized by x ∈ R^d are reweighted by µ*(x).
  5. (Some, nonparametric, unconstrained) sampling methods. (1) Markov Chain Monte Carlo (MCMC) methods: generate a Markov chain in R^d whose law converges to µ* ∝ exp(−V). Example: Langevin Monte Carlo (LMC) [Roberts and Tweedie, 1996]: x_{t+1} = x_t − γ∇V(x_t) + √(2γ) ε_t, with ε_t ∼ N(0, I_d). Picture from https://chi-feng.github.io/mcmc-demo/app.html.
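
A minimal sketch of the LMC recursion above, for the illustrative potential V(x) = ∥x∥²/2 (so µ* is a standard Gaussian); the step size and iteration count are arbitrary choices, not from the deck.

```python
import numpy as np

rng = np.random.default_rng(0)
d, gamma, n_iters = 2, 0.1, 1000

grad_V = lambda x: x                    # V(x) = ||x||^2 / 2, so grad V(x) = x

x = rng.standard_normal(d)              # arbitrary initialisation
for _ in range(n_iters):
    # x_{t+1} = x_t - gamma * grad V(x_t) + sqrt(2 gamma) * eps_t,  eps_t ~ N(0, I_d)
    x = x - gamma * grad_V(x) + np.sqrt(2 * gamma) * rng.standard_normal(d)
```
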
  6. (2) Interacting particle systems, whose empirical measure at stationarity approximates µ* ∝ exp(−V). Example: Stein Variational Gradient Descent (SVGD) [Liu and Wang, 2016]: x^i_{t+1} = x^i_t − (γ/N) Σ_{j=1}^N [∇V(x^j_t) k(x^i_t, x^j_t) − ∇_2 k(x^i_t, x^j_t)], i = 1, ..., N, where k : R^d × R^d → R_+ is a kernel (e.g. k(x, y) = exp(−∥x − y∥²)) and ∇_2 k denotes the gradient of k in its second argument. Picture from https://chi-feng.github.io/mcmc-demo/app.html.
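
A minimal sketch of the SVGD update above for the same illustrative potential V(x) = ∥x∥²/2, with the kernel k(x, y) = exp(−∥x − y∥²) from the slide; the gradient of k in its second argument is written in closed form, and step size and particle count are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, gamma = 50, 2, 0.1
grad_V = lambda X: X                               # V(x) = ||x||^2 / 2

X = rng.standard_normal((N, d)) + 5.0              # particles, started away from the target

def svgd_step(X):
    diff = X[:, None, :] - X[None, :, :]           # diff[i, j] = x_i - x_j
    K = np.exp(-(diff ** 2).sum(-1))               # k(x_i, x_j) = exp(-||x_i - x_j||^2)
    grad2_K = 2 * diff * K[:, :, None]             # gradient of k in its second argument
    # x_i <- x_i - (gamma/N) * sum_j [ grad V(x_j) k(x_i, x_j) - grad_2 k(x_i, x_j) ]
    drive = K @ grad_V(X) - grad2_K.sum(axis=1)
    return X - gamma / N * drive

for _ in range(500):
    X = svgd_step(X)
```
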
  7. Difficult cases (in practice and in theory). µ*(x) ∝ exp(−V(x)), with V(x) = Σ_{i=1}^m ∥y_i − g(w_i, x)∥² (loss) + ∥x∥²_2, and µ* = argmin_µ KL(µ|µ*). If V is convex (e.g. g(w, x) = ⟨w, x⟩), these methods are known to work quite well [Durmus and Moulines, 2016, Vempala and Wibisono, 2019]; but if it is not (e.g. g(w, x) is a neural network), the situation is much more delicate. [Figure: a highly nonconvex loss surface, as is common in deep neural nets. From https://www.telesens.co/2019/01/16/neural-network-loss-visualization.]
  8. Example 2: Thinning (post-processing of MCMC output). In an ideal world we would be able to post-process the MCMC output and keep only those states that are representative of the posterior µ*. Picture from Chris Oates. This fixes problems with MCMC (automatic identification of burn-in; number of points proportional to the probability mass in a region; etc.) and gives a compressed representation of the posterior, to reduce any downstream computational load. Idea: minimize a divergence from the distribution of the states to µ* [Korba et al., 2021]: µ_n = argmin_µ KSD(µ|µ*).
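
For reference, the kernel Stein discrepancy objective only needs the score ∇ log µ* (so no normalizing constant) and a kernel. Below is a rough sketch of evaluating KSD²(µ_n|µ*) for an empirical measure, with a Gaussian kernel and the usual Langevin Stein kernel; this is a generic illustration, not the thinning algorithm of [Korba et al., 2021] itself, and the bandwidth is an arbitrary choice.

```python
import numpy as np

def ksd_squared(X, score, sigma=1.0):
    """KSD^2(mu_n | mu*) for mu_n = (1/n) sum_i delta_{x_i}, Gaussian kernel,
    Langevin Stein kernel k_{mu*}; `score` returns grad log mu* at each row of X."""
    n, d = X.shape
    S = score(X)                                            # s(x_i) = grad log mu*(x_i)
    diff = X[:, None, :] - X[None, :, :]                    # x_i - x_j
    sqd = (diff ** 2).sum(-1)
    K = np.exp(-sqd / (2 * sigma**2))
    grad_x_K = -diff / sigma**2 * K[:, :, None]             # gradient of k in its first argument
    # k_{mu*}(x_i, x_j) = div_x div_y k + <grad_x k, s(x_j)> + <grad_y k, s(x_i)> + k <s(x_i), s(x_j)>
    term1 = (d / sigma**2 - sqd / sigma**4) * K
    term2 = (grad_x_K * S[None, :, :]).sum(-1)              # <grad_x k(x_i, x_j), s(x_j)>
    term3 = (-grad_x_K * S[:, None, :]).sum(-1)             # grad_y k = -grad_x k
    term4 = K * (S @ S.T)
    return (term1 + term2 + term3 + term4).mean()

# example: standard Gaussian target, score(x) = -x
X = np.random.default_rng(0).standard_normal((200, 2))
print(ksd_squared(X, lambda X: -X))
```
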
  9. Example 3: Regression with an infinite-width shallow NN. min_{x_1,...,x_n ∈ R^d} E_{(w,y)∼P_data}[(y − ŷ)²], with network prediction ŷ = (1/n) Σ_{i=1}^n φ_{x_i}(w); as n → ∞ this becomes min_{µ∈P(R^d)} E_{(w,y)∼P_data}[(y − ∫_{R^d} φ_x(w) dµ(x))²] =: F(µ). Optimising the neural network ⟺ approximating µ* ∈ argmin F(µ) [Chizat and Bach, 2018, Mei et al., 2018]. If y(w) = (1/m) Σ_{i=1}^m φ_{x_i}(w) is generated by a neural network (as in the student-teacher network setting), then µ* = (1/m) Σ_{i=1}^m δ_{x_i} and F can be identified with an MMD [Arbel et al., 2019]: min_µ E_{w∼P_data} ∥y_{µ*}(w) − y_µ(w)∥² = MMD²(µ, µ*), with k(x, x′) = E_{w∼P_data}[φ_{x′}(w)^T φ_x(w)].
  10. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  11. Sampling as optimization over probability distributions. Assume that µ* ∈ P_2(R^d) = {µ ∈ P(R^d) : ∫ ∥x∥² dµ(x) < ∞}. The sampling task can be recast as an optimization problem: µ* = argmin_{µ∈P_2(R^d)} D(µ|µ*) := F(µ), where D is a discrepancy, for instance: an f-divergence, ∫ f(dµ/dµ*) dµ* with f convex and f(1) = 0; an integral probability metric, sup_{f∈G} ∫ f dµ − ∫ f dµ*; an optimal transport distance... Starting from an initial distribution µ_0 ∈ P_2(R^d), one can then consider the Wasserstein-2* gradient flow of F over P_2(R^d) to transport µ_0 to µ*. *W²_2(ν, µ) = inf_{s∈Γ(ν,µ)} ∫_{R^d×R^d} ∥x − y∥² ds(x, y), where Γ(ν, µ) = couplings between ν and µ.
  12. Wasserstein gradient flows (WGF) [Ambrosio et al., 2008]. The first variation of µ ↦ F(µ) evaluated at µ ∈ P(R^d) is the unique function ∂F(µ)/∂µ : R^d → R such that, for any ν ∈ P(R^d) (perturbation ν − µ): lim_{ε→0} (1/ε)(F(µ + ε(ν − µ)) − F(µ)) = ∫_{R^d} (∂F(µ)/∂µ)(x) d(ν − µ)(x). The family µ : [0, ∞) → P_2(R^d), t ↦ µ_t, is a Wasserstein gradient flow of F if ∂µ_t/∂t = ∇·(µ_t ∇_{W_2} F(µ_t)), where ∇_{W_2} F(µ) := ∇(∂F(µ)/∂µ) denotes the Wasserstein gradient of F. It can be implemented by the deterministic process dx_t/dt = −∇_{W_2} F(µ_t)(x_t), where x_t ∼ µ_t.
  13. Particle system / gradient descent approximating the WGF. Space/time discretization: introduce a particle system x^1_0, ..., x^n_0 ∼ µ_0 and a step size γ, and at each step: x^i_{l+1} = x^i_l − γ ∇_{W_2} F(µ̂_l)(x^i_l) for i = 1, ..., n, where µ̂_l = (1/n) Σ_{i=1}^n δ_{x^i_l}. In particular, if F is well-defined for discrete measures, the algorithm above simply corresponds to gradient descent of F : R^{N×d} → R, F(x^1, ..., x^N) := F(µ^N) where µ^N = (1/N) Σ_{i=1}^N δ_{x^i}. We consider several questions: what can we say as time goes to infinity? (optimization error) ⟹ heavily linked with the geometry (convexity, smoothness in the Wasserstein sense) of the loss; and, for minimizers, what can we say as the number of particles grows? (quantization error)
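
A minimal sketch of the discretized scheme above, instantiated on the simple potential-energy functional F(µ) = ∫ V dµ, whose Wasserstein gradient is just ∇V. This is an illustrative choice only: for the divergences discussed in this talk, the Wasserstein gradient must itself be estimated from the particles, as in the SVGD, MMD and DMMD sketches elsewhere in this transcript.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, gamma = 100, 2, 0.1

# F(mu) = \int V dmu with V(x) = ||x - a||^2 / 2; its Wasserstein gradient is grad V
a = np.array([3.0, -1.0])
wasserstein_grad_F = lambda X: X - a          # evaluated at each particle

X = rng.standard_normal((n, d))               # x_0^1, ..., x_0^n ~ mu_0 = N(0, I)
for _ in range(200):
    # x_{l+1}^i = x_l^i - gamma * grad_{W2} F(mu_hat_l)(x_l^i)
    X = X - gamma * wasserstein_grad_F(X)
```
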
  14. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  15. Loss function for unnormalized densities: the KL. Recall that we want to minimize F = D(·|µ*). Which D can we choose? For instance, D could be the Kullback-Leibler divergence: KL(µ|µ*) = ∫_{R^d} log(µ/µ*)(x) dµ(x) if µ ≪ µ*, and +∞ otherwise. The KL as an objective is convenient when the unnormalized density of µ* is known, since it does not depend on the normalization constant! Indeed, writing µ*(x) = e^{−V(x)}/Z we have: KL(µ|µ*) = ∫_{R^d} log(µ(x)/e^{−V(x)}) dµ(x) + log(Z). But it is not convenient when µ or µ* are discrete, because the KL is +∞ unless supp(µ) ⊂ supp(µ*).
  16. KL gradient flow in practice. The gradient flow of the KL can be implemented via the probability flow (ODE): dx̃_t = −∇ log(µ_t/µ*)(x̃_t) dt (1), or the Langevin diffusion (SDE): dx_t = ∇ log µ*(x_t) dt + √2 dB_t (2); they share the same marginals (µ_t)_{t≥0}. (2) can be discretized in time as Langevin Monte Carlo (LMC): x_{m+1} = x_m + γ∇ log µ*(x_m) + √(2γ) ε_m, ε_m ∼ N(0, I_d). (1) can be approximated by a particle system (e.g. Stein Variational Gradient Descent [Liu, 2017, He et al., 2022]). However, MCMC methods suffer an integral approximation error of order O(n^{−1/2}) if we use µ_n = (1/n) Σ_{i=1}^n δ_{x_i} (x_i iterates of MCMC) [Latuszyński et al., 2013], and for SVGD we don't know [Xu et al., 2022].
  17. Another f-divergence? Consider the chi-square (CS) divergence: χ²(µ|µ*) := ∫ (dµ/dµ* − 1)² dµ* if µ ≪ µ*, and +∞ otherwise. It is not convenient either when µ, µ* are discrete, and the χ²-gradient requires the normalizing constant of µ* (it involves ∇(µ/µ*)). However, the gradient flow of χ² has interesting properties: the KL decreases exponentially fast along the CS flow / χ² decreases exponentially fast along the KL flow if µ* satisfies a Poincaré inequality; and we have χ²(µ|µ*) ≥ KL(µ|µ*). ⟹ Distinguishing whether the KL or χ² gradient flow is more favorable is an active area of research†. †See [Chewi et al., 2020, Craig et al., 2022] for a discussion, results from [Matthes et al., 2009, Dolbeault et al., 2007].
  19. Losses for the discrete case. When we have a discrete approximation of µ*, it is convenient to choose D as an integral probability metric (to approximate integrals). For instance, D could be the MMD (Maximum Mean Discrepancy): MMD²(µ, µ*) = (sup_{f∈H_k, ∥f∥_{H_k}≤1} ∫ f dµ − ∫ f dµ*)² = ∥m_µ − m_{µ*}∥²_{H_k} = ∫∫ k(x, y) dµ(x)dµ(y) + ∫∫ k(x, y) dµ*(x)dµ*(y) − 2 ∫∫ k(x, y) dµ(x)dµ*(y), where m_µ = ∫ k(x, ·) dµ(x), k : R^d × R^d → R is a p.s.d. kernel (e.g. k(x, y) = e^{−∥x−y∥²}) and H_k is the RKHS associated to k‡. ‡H_k = {Σ_{i=1}^m α_i k(·, x_i) ; m ∈ N ; α_1, ..., α_m ∈ R ; x_1, ..., x_m ∈ R^d}.
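
A rough sketch of the squared MMD between two empirical measures, using the double-sum expansion above (V-statistic form, diagonal terms included); a Gaussian kernel with a bandwidth parameter is an arbitrary choice here.

```python
import numpy as np

def mmd_squared(X, Y, sigma=1.0):
    """MMD^2 between (1/N) sum_i delta_{x_i} and (1/M) sum_j delta_{y_j},
    k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); V-statistic (biased) estimator."""
    k = lambda A, B: np.exp(-((A[:, None, :] - B[None, :, :]) ** 2).sum(-1) / (2 * sigma**2))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2))          # samples from mu
Y = rng.standard_normal((300, 2)) + 1.0    # samples from mu*
print(mmd_squared(X, Y))
```
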
  20. Why we care about the loss. [Figure: SVGD vs. MMD flow. Toy example with a 2D standard Gaussian; the green points represent the initial positions of the particles, the light grey curves their trajectories.] The gradient flow of the KL to a Gaussian µ*(x) ∝ e^{−∥x∥²/2} is well-behaved, but not that of the MMD. Question: is there an IPM (integral probability metric) that enjoys better behavior?
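
To connect with the particle scheme of slide 13, here is a rough sketch of the mechanics of MMD gradient descent on particles toward a fixed sample of µ*, in the spirit of [Arbel et al., 2019]: the Wasserstein gradient of (1/2)MMD²(µ, µ*) at x is ∫ ∇_1 k(x, ·) dµ − ∫ ∇_1 k(x, ·) dµ*. All numerical settings are arbitrary, and this says nothing about the (poor) behavior shown in the figure above.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, gamma = 1.0, 1.0

def grad1_k(A, B, sigma=1.0):
    """grad_x k(a_i, b_j) for the Gaussian kernel, shape (len(A), len(B), d)."""
    diff = A[:, None, :] - B[None, :, :]
    K = np.exp(-(diff ** 2).sum(-1) / (2 * sigma**2))
    return -diff / sigma**2 * K[:, :, None]

Y = rng.standard_normal((300, 2))                  # fixed sample of the target mu*
X = rng.standard_normal((100, 2)) + 2.0            # particles representing mu

for _ in range(1000):
    # Wasserstein gradient of (1/2) MMD^2 at each particle:
    # \int grad_1 k(x, .) dmu - \int grad_1 k(x, .) dmu*
    g = grad1_k(X, X, sigma).mean(axis=1) - grad1_k(X, Y, sigma).mean(axis=1)
    X = X - gamma * g
```
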
  21. Variational formula of f-divergences. Recall that f-divergences write D(µ|µ*) = ∫ f(dµ/dµ*) dµ*, with f convex and f(1) = 0. They admit a variational form [Nguyen et al., 2010]: D(µ|µ*) = sup_{h : R^d → R} ∫ h dµ − ∫ f*(h) dµ*, where f*(y) = sup_x ⟨x, y⟩ − f(x) is the convex conjugate (or Legendre transform) of f and h is measurable. Examples: KL(µ|µ*): f(x) = x log(x) − x + 1, f*(y) = e^y − 1. χ²(µ|µ*): f(x) = (x − 1)², f*(y) = y + y²/4.
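
As a worked check of the first example above, here is the convex conjugate of the KL generator f(x) = x log(x) − x + 1, confirming f*(y) = e^y − 1:

```latex
\begin{aligned}
f^*(y) &= \sup_{x>0}\,\big\{\, xy - (x\log x - x + 1) \,\big\}
&&\text{(definition of the conjugate)}\\
0 &= y - \log x \;\Longrightarrow\; x = e^{y}
&&\text{(first-order condition in } x\text{)}\\
f^*(y) &= e^{y} y - \big(e^{y} y - e^{y} + 1\big) = e^{y} - 1.
\end{aligned}
```
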
  22. A proposal§: interpolate between MMD and χ². The "De-Regularized MMD" leverages the variational formulation of χ²: DMMD(µ∥µ*) = (1 + λ) max_{h∈H_k} [ ∫ h dµ − ∫ (h + h²/4) dµ* − (λ/4)∥h∥²_{H_k} ] (3). It is a divergence for any λ, recovers χ² for λ = 0 and MMD for λ = +∞. DMMD and its gradient can be written in closed form, in particular if µ, µ* are discrete (they depend on λ and kernel matrices over samples of µ, µ*): DMMD(µ∥µ*) = (1 + λ) ∥(Σ_{µ*} + λ Id)^{−1/2}(m_µ − m_{µ*})∥²_{H_k}, ∇DMMD(µ∥µ*) = ∇h*_{µ,µ*}, where Σ_{µ*} = ∫ k(·, x) ⊗ k(·, x) dµ*(x), with (a ⊗ b)c = ⟨b, c⟩_{H_k} a, and h*_{µ,µ*} solves (3). Complexity: O(M³ + NM) for µ*, µ supported on M, N atoms; it can be decreased to O(M + N) with random features. A similar idea was proposed for the KL, yielding the KALE divergence [Glaser et al., 2021], but it is not closed-form. §With H. Chen, A. Gretton, P. Glaser (UCL), A. Mustafi, B. Sriperumbudur (CMU).
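
A rough numerical sketch of DMMD for empirical µ and µ*, derived directly from the variational form (3) by restricting the witness h to the span of the kernel functions at the observed points (a representer-style assumption); this finite-sample reduction and the function names are my own and not the authors' implementation.

```python
import numpy as np

def gauss_k(A, B, sigma=1.0):
    return np.exp(-((A[:, None, :] - B[None, :, :]) ** 2).sum(-1) / (2 * sigma**2))

def dmmd(X, Y, lam=1e-2, sigma=1.0):
    """Sketch of DMMD(mu || mu*) for mu = (1/N) sum_i delta_{x_i}, mu* = (1/M) sum_j delta_{y_j}:
    objective (3) maximised over h in span{k(., z_p)}, z = (x_1..x_N, y_1..y_M).
    This gives (1 + lam) a^T (K_ZY K_YZ / M + lam K)^{-1} a with a = K_ZX 1/N - K_ZY 1/M."""
    N, M = len(X), len(Y)
    Z = np.vstack([X, Y])
    K, KZY = gauss_k(Z, Z, sigma), gauss_k(Z, Y, sigma)
    a = gauss_k(Z, X, sigma).mean(1) - KZY.mean(1)        # linear term: int h dmu - int h dmu*
    A = KZY @ KZY.T / M + lam * K                         # from int h^2/4 dmu* and (lam/4)||h||^2
    A += 1e-10 * np.eye(N + M)                            # small jitter for numerical stability
    return (1 + lam) * a @ np.linalg.solve(A, a)

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 2)) + 1.0
Y = rng.standard_normal((200, 2))
print(dmmd(X, Y, lam=1e-2))   # large lam ~ (biased) MMD^2 regime, small lam de-regularizes toward chi^2
```
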
  23. Several interpretations of DMMD. DMMD can be seen as: (i) a reweighted χ²-divergence: for µ ≪ π, DMMD(µ∥π) = (1 + λ) Σ_{i≥1} [ϱ_i/(ϱ_i + λ)] ⟨dµ/dπ − 1, e_i⟩²_{L²(π)}, where (ϱ_i, e_i) is the eigendecomposition of T_π : f ∈ L²(π) ↦ ∫ k(x, ·) f(x) dπ(x) ∈ L²(π); (ii) an MMD with the kernel k̃(x, x′) = Σ_{i≥1} [ϱ_i/(ϱ_i + λ)] e_i(x) e_i(x′), which is a regularized version of the original kernel k(x, x′) = Σ_{i≥1} ϱ_i e_i(x) e_i(x′).
  24. Ring experiment. [Figure: particle snapshots at T = 0, 2, 30, 99 for the MMD, KALE and DMMD flows with λ ∈ {10^{−5}, ..., 10², ∞}, and W_2(·|π) vs. iteration.]
  25. Student-teacher networks experiment¶. [Figure: validation MMD² vs. iteration for DMMD, MMD, DMMD (Noise) and MMD (Noise), across λ ∈ {10^{−5}, ..., 10², ∞}.] The teacher network w ↦ y_{µ*}(w) is given by M particles (ξ_1, ..., ξ_M) which are fixed during training ⟹ µ* = (1/M) Σ_{j=1}^M δ_{ξ_j}. The student network w ↦ y_µ(w) has n particles (x_1, ..., x_n) that are initialized randomly ⟹ µ = (1/n) Σ_{i=1}^n δ_{x_i}. min_µ E_{w∼P_data}(y_{µ*}(w) − y_µ(w))² ⟺ min_µ MMD(µ, µ*) with k(x, x′) = E_{w∼P_data}[φ_{x′}(w) φ_x(w)]. ¶Same setting as [Arbel et al., 2019].
  26. Another idea: "mollified" discrepancies [Li et al., 2022a]‖. What if we don't have access to samples of µ*? (Recall that DMMD involves an integral over µ*.) Choose a mollifier/kernel (Gaussian, Laplace, Riesz-s): k^g_ε(x) := exp(−∥x∥²_2/(2ε²))/Z_g(ε), k^l_ε(x) := exp(−∥x∥_2/ε)/Z_l(ε), k^s_ε(x) := 1/((∥x∥²_2 + ε²)^{s/2} Z_r(s, ε)). Mollified chi-square (differs from χ²(k_ε ⋆ µ|µ*) as in [Craig et al., 2022]): E_ε(µ) = ∫∫ k_ε(x − y) (µ*(x)µ*(y))^{−1/2} µ(x)µ(y) dx dy = ∫ (k_ε ∗ (µ/√µ*))(x) (µ/√µ*)(x) dx → χ²(µ|µ*) + 1 as ε → 0. It can be written as an interaction energy, which allows µ to be discrete and µ* to have a density. ‖Sampling with mollified interaction energy descent. Li, L., Liu, Q., Korba, A., Yurochkin, M., and Solomon, J. (ICLR 2023).
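
A rough sketch of evaluating E_ε(µ_n) for an empirical measure, with a Gaussian mollifier and a target known only through its potential V (µ* ∝ e^{−V}, so (µ*)^{−1/2} ∝ e^{V/2}); the dropped normalizing constants only rescale the objective. One could then run gradient descent directly on the particle locations (e.g. with autodiff); the actual MIED implementation may differ (for instance by working in the log domain for numerical stability).

```python
import numpy as np

def mollified_energy(X, V, eps=0.1):
    """E_eps(mu_n) for mu_n = (1/n) sum_i delta_{x_i}, Gaussian mollifier k_eps,
    target mu* proportional to exp(-V); computed up to the constants Z_g(eps) and Z."""
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    k = np.exp(-d2 / (2 * eps**2))                 # unnormalised Gaussian mollifier k_eps(x_i - x_j)
    w = np.exp(0.5 * V(X))                         # (mu*(x_i))^{-1/2} up to a constant
    return (k * np.outer(w, w)).mean()             # (1/n^2) sum_{i,j} k_eps(x_i - x_j) w_i w_j

# example: standard Gaussian target, V(x) = ||x||^2 / 2
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 2))
print(mollified_energy(X, lambda X: 0.5 * (X ** 2).sum(-1)))
```
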
  27. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  28. Background on convexity and smoothness in R^d. Recall that if f : R^d → R is twice differentiable: f is λ-convex ⟺ f(tx + (1 − t)y) ≤ tf(x) + (1 − t)f(y) − (λ/2) t(1 − t)∥x − y∥² for all x, y ∈ R^d, t ∈ [0, 1] ⟺ vᵀ∇²f(x)v ≥ λ∥v∥²_2 for all x, v ∈ R^d. f is M-smooth ⟺ ∥∇f(x) − ∇f(y)∥ ≤ M∥x − y∥ for all x, y ∈ R^d ⟺ vᵀ∇²f(x)v ≤ M∥v∥²_2 for all x, v ∈ R^d.
  29. (Geodesically-)convex and smooth losses. F is said to be λ-displacement convex if along W_2 geodesics (ρ_t)_{t∈[0,1]}: F(ρ_t) ≤ (1 − t)F(ρ_0) + tF(ρ_1) − (λ/2) t(1 − t) W²_2(ρ_0, ρ_1) for all t ∈ [0, 1]. The Wasserstein Hessian of a functional F : P_2(R^d) → R at µ is defined for any ψ ∈ C^∞_c(R^d) as Hess_µ F(ψ, ψ) := (d²/dt²)|_{t=0} F(µ_t), where (µ_t, v_t)_{t∈[0,1]} is a Wasserstein geodesic with µ_0 = µ and v_0 = ∇ψ. F is λ-displacement convex ⟺ Hess_µ F(ψ, ψ) ≥ λ∥∇ψ∥²_{L²(µ)} (see [Villani, 2009, Proposition 16.2]). In an analogous manner we can define smooth functionals as functionals with upper-bounded Hessians.
  30. Guarantees for Wasserstein gradient descent. Consider Wasserstein gradient descent (Euler discretization of the Wasserstein gradient flow): µ_{l+1} = (Id − γ∇F′(µ_l))_# µ_l. Assume F is M-smooth. Then we have a descent lemma (if γ < 2/M): F(µ_{l+1}) − F(µ_l) ≤ −γ(1 − γM/2)∥∇F′(µ_l)∥²_{L²(µ_l)}. Moreover, if F is λ-convex, we have the global rate F(µ_L) ≤ W²_2(µ_0, µ*)/(2γL) − (λ/L) Σ_{l=0}^L W²_2(µ_l, µ*) (so the barrier term degrades with λ).
  31. Some examples. Let µ* ∝ e^{−V}. We have [Wibisono, 2018]: Hess_µ KL(ψ, ψ) = ∫ [⟨H_V(x)∇ψ(x), ∇ψ(x)⟩ + ∥H_ψ(x)∥²_{HS}] dµ(x). If V is m-strongly convex, then the KL is m-geodesically convex; however it is not smooth (its Hessian is unbounded w.r.t. ∥∇ψ∥²_{L²(µ)}). Similar story for the χ² divergence [Ohta and Takatsu, 2011]. For an M-smooth kernel k [Arbel et al., 2019]: Hess_µ MMD²(ψ, ψ) = ∫∫ ∇ψ(x)ᵀ ∇_1∇_2 k(x, y) ∇ψ(y) dµ(x)dµ(y) + 2 ∫ ∇ψ(x)ᵀ [∫ H¹_k(x, z) dµ(z) − ∫ H¹_k(x, z) dµ*(z)] ∇ψ(x) dµ(x). It is M-smooth but not geodesically convex (its Hessian is only lower bounded by a large negative constant).
  32. For DMMD (interpolating between χ² and MMD), with µ* ∝ e^{−V}: if V is m-strongly convex, then for λ small enough we can lower bound Hess_µ DMMD(·∥µ*)(ψ, ψ) by a positive constant times ∥∇ψ∥²_{L²(µ)}, and obtain: Th. 1, informal: for a step size γ small enough (depending on λ, k) we get an O(1/L) rate. Th. 2, informal: we can obtain a linear O(e^{−L}) rate if we have a lower bound on the density ratios and a source condition (dµ/dπ ∈ Ran(T^r_π), 0 < r ≤ 1/2). Idea: 1 We can write the Hessian of χ²: Hess_µ χ²(·∥µ*)(ψ, ψ) = ∫ (µ(x)²/µ*(x)) (L_{µ*}ψ(x))² dx + ∫ (µ(x)²/µ*(x)) ⟨H_V(x)∇ψ(x), ∇ψ(x)⟩ dx + ∫ (µ(x)²/µ*(x)) ∥H_ψ(x)∥²_{HS} dx, where L_{µ*} is the Langevin operator L_{µ*}ψ = ⟨∇V(x), ∇ψ(x)⟩ − Δψ(x). 2 DMMD(µ∥π) = (1 + λ) Σ_{i≥1} [ϱ_i/(ϱ_i + λ)] ⟨dµ/dπ − 1, e_i⟩²_{L²(π)}, where (ϱ_i, e_i) is the eigendecomposition of T_π : f ∈ L²(π) ↦ ∫ k(x, ·) f(x) dπ(x) ∈ L²(π).
  33. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  34. Recent results. For smooth and bounded kernels and µ* with exponential tails, [Xu et al., 2022] obtain, using the Koksma-Hlawka inequality: min_{µ_n} MMD(µ_n, µ*) ≤ C_d (log n)^{(5d+1)/2} / n. This bounds the integral error for f ∈ H_k (by Cauchy-Schwarz): |∫_{R^d} f(x) dµ*(x) − ∫_{R^d} f(x) dµ(x)| ≤ ∥f∥_{H_k} MMD(µ, µ*). We can apply these results to DMMD, which is a regularized MMD with kernel k̃, replacing C_d by C_d/λ.
  35. Outline: 1 Introduction. 2 Sampling as Optimization. 3 Choice of the Divergence. 4 Optimization error. 5 Quantization error. 6 Further connections with Optimization.
  36. More ideas can be borrowed from optimization (but there are limitations). Sampling with inequality constraints [Liu et al., 2021, Li et al., 2022b]: min_{µ∈P_2(R^d)} KL(µ∥µ*) subject to E_{x∼µ} g(x) ≤ 0. Bilevel sampling**: min_{θ∈R^p} ℓ(θ) := min_{θ∈R^p} F(µ*(θ)), where for instance µ*(θ) is a Gibbs distribution, minimizing the KL: µ*(θ)[x] = exp(−V(x, θ))/Z_θ; or µ*(θ) is the output of a diffusion model parametrized by θ (this does not minimize a divergence on P(R^d)). **With P. Marion, Q. Berthet, P. Bartlett, M. Blondel, V. Bortoli, A. Doucet, F. Llinares-Lopez, C. Paquette.
  37. A numerical example from [Li et al., 2022a]. [Figure: sampling from the von Mises-Fisher distribution obtained by constraining a 3-dimensional Gaussian to the unit sphere. The unit-sphere constraint is enforced using the dynamic barrier method and the shown results are obtained using MIED with the Riesz kernel and s = 3. The six plots are views from six evenly spaced angles.]
  38. A numerical example from [Li et al., 2022a]. [Figure: uniform sampling of the region {(x, y) ∈ [−1, 1]² : (cos(3πx) + cos(3πy))² < 0.3} using MIED with a Riesz mollifier (s = 3), where the constraint is enforced using the dynamic barrier method.]
  39. IV - Fairness Bayesian neural network. Given a dataset D = {w^(i), y^(i), z^(i)}_{i=1}^{|D|} consisting of features w^(i), labels y^(i) (whether the income is ≥ $50,000), and genders z^(i) (protected attribute), we set the target density to be the posterior of a logistic regression using a 2-layer Bayesian neural network ŷ(·; x). Given t > 0, the fairness constraint is g(x) = (cov_{(w,y,z)∼D}[z, ŷ(w; x)])² − t ≤ 0. [Figure: demographic parity vs. test accuracy for MIED (Ours), Control+Langevin, Primal-Dual+Langevin, Primal-Dual+SVGD.] Other methods come from [Liu et al., 2021].
  40. Open questions, directions. Finite-particle/quantization guarantees are still missing for many losses (e.g. the mollified chi-square): D(µ_n∥µ*) ≤ error(n, µ*)? How to improve the performance of the algorithms for highly non-log-concave targets? E.g. through a sequence of targets (µ*_t)_{t∈[0,1]} interpolating between µ_0 and µ*? Shape of the trajectories? Change the underlying metric and consider W_c gradient flows.
  41. Main references (code available): Maximum Mean Discrepancy Gradient Flow. Arbel, M., Korba, A., Salim, A., and Gretton, A. (NeurIPS 2019). Accurate quantization of measures via interacting particle-based optimization. Xu, L., Korba, A., and Slepčev, D. (ICML 2022). Sampling with mollified interaction energy descent. Li, L., Liu, Q., Korba, A., Yurochkin, M., and Solomon, J. (ICLR 2023). (De)-regularized Maximum Mean Discrepancy Gradient Flow. Chen, H., Mustafi, A., Glaser, P., Korba, A., Gretton, A., Sriperumbudur, B. (Submitted 2024).
  42. References I. Ambrosio, L., Gigli, N., and Savaré, G. (2008). Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media. Arbel, M., Korba, A., Salim, A., and Gretton, A. (2019). Maximum mean discrepancy gradient flow. In Advances in Neural Information Processing Systems, pages 6481–6491. Chewi, S., Le Gouic, T., Lu, C., Maunu, T., and Rigollet, P. (2020). SVGD as a kernelized Wasserstein gradient flow of the chi-squared divergence. Advances in Neural Information Processing Systems, 33:2098–2109. Chizat, L. and Bach, F. (2018). On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in Neural Information Processing Systems, 31. Chopin, N., Crucinio, F. R., and Korba, A. (2023). A connection between tempering and entropic mirror descent. arXiv preprint arXiv:2310.11914. Craig, K., Elamvazhuthi, K., Haberland, M., and Turanova, O. (2022). A blob method for inhomogeneous diffusion with applications to multi-agent control and sampling. arXiv preprint arXiv:2202.12927. Dolbeault, J., Gentil, I., Guillin, A., and Wang, F.-Y. (2007). Lq-functional inequalities and weighted porous media equations. arXiv preprint math/0701037. Durmus, A. and Moulines, E. (2016). Sampling from strongly log-concave distributions with the unadjusted Langevin algorithm. arXiv preprint arXiv:1605.01559, 5.
  43. References II. Glaser, P., Arbel, M., and Gretton, A. (2021). KALE flow: A relaxed KL gradient flow for probabilities with disjoint support. Advances in Neural Information Processing Systems, 34:8018–8031. He, Y., Balasubramanian, K., Sriperumbudur, B. K., and Lu, J. (2022). Regularized Stein variational gradient flow. arXiv preprint arXiv:2211.07861. Kloeckner, B. (2012). Approximation by finitely supported measures. ESAIM: Control, Optimisation and Calculus of Variations, 18(2):343–359. Korba, A., Aubin-Frankowski, P.-C., Majewski, S., and Ablin, P. (2021). Kernel Stein discrepancy descent. arXiv preprint arXiv:2105.09994. Latuszyński, K., Miasojedow, B., and Niemiro, W. (2013). Nonasymptotic bounds on the estimation error of MCMC algorithms. Bernoulli, 19(5A):2033–2066. Li, J. and Barron, A. (1999). Mixture density estimation. Advances in Neural Information Processing Systems, 12. Li, L., Liu, Q., Korba, A., Yurochkin, M., and Solomon, J. (2022a). Sampling with mollified interaction energy descent. arXiv preprint arXiv:2210.13400.
  44. References III. Li, R., Tao, M., Vempala, S. S., and Wibisono, A. (2022b). The mirror Langevin algorithm converges with vanishing bias. In International Conference on Algorithmic Learning Theory, pages 718–742. PMLR. Liu, Q. (2017). Stein variational gradient descent as gradient flow. In Advances in Neural Information Processing Systems, pages 3115–3123. Liu, Q. and Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. In Advances in Neural Information Processing Systems, pages 2378–2386. Liu, X., Tong, X., and Liu, Q. (2021). Sampling with trusthworthy constraints: A variational gradient framework. Advances in Neural Information Processing Systems, 34:23557–23568. Matthes, D., McCann, R. J., and Savaré, G. (2009). A family of nonlinear fourth order equations of gradient flow type. Communications in Partial Differential Equations, 34(11):1352–1397. Mei, S., Montanari, A., and Nguyen, P.-M. (2018). A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671. Mérigot, Q., Santambrogio, F., and Sarrazin, C. (2021). Non-asymptotic convergence bounds for Wasserstein approximation using point clouds. Advances in Neural Information Processing Systems, 34:12810–12821.
  45. References IV. Nguyen, X., Wainwright, M. J., and Jordan, M. I. (2010). Estimating divergence functionals and the likelihood ratio by convex risk minimization. IEEE Transactions on Information Theory, 56(11):5847–5861. Ohta, S.-i. and Takatsu, A. (2011). Displacement convexity of generalized relative entropies. Advances in Mathematics, 228(3):1742–1787. Roberts, G. O. and Tweedie, R. L. (1996). Exponential convergence of Langevin distributions and their discrete approximations. Bernoulli, pages 341–363. Vempala, S. and Wibisono, A. (2019). Rapid convergence of the unadjusted Langevin algorithm: Isoperimetry suffices. Advances in Neural Information Processing Systems, 32. Villani, C. (2009). Optimal transport: old and new, volume 338. Springer. Wibisono, A. (2018). Sampling as optimization in the space of measures: The Langevin dynamics as a composite optimization problem. In Conference on Learning Theory, pages 2093–3027. PMLR. Xu, L., Korba, A., and Slepčev, D. (2022). Accurate quantization of measures via interacting particle-based optimization. International Conference on Machine Learning.
  46. What is known. What can we say on inf_{x_1,...,x_n} D(µ_n|µ*), where µ_n = (1/n) Σ_{i=1}^n δ_{x_i}? Quantization rates for the Wasserstein distance [Kloeckner, 2012, Mérigot et al., 2021]: W_2(µ_n, µ*) ∼ O(n^{−1/d}). Forward KL [Li and Barron, 1999]: for every g_P = ∫ k_ε(· − w) dP(w), min_{µ_n} KL(µ*|k_ε ⋆ µ_n) ≤ KL(µ*|g_P) + C²_{µ*,P} γ/n, where C²_{µ*,P} = ∫ [∫ k_ε(x − m)² dP(m) / (∫ k_ε(x − w) dP(w))²] dµ*(x), and γ = 4 log(3√e) + a is a constant depending on ε, with a = sup_{z,z′∈R^d} log(k_ε(x − z)/k_ε(x − z′)).
  47. Recent results. For smooth and bounded kernels and µ* with exponential tails, [Xu et al., 2022] obtain, using the Koksma-Hlawka inequality: min_{µ_n} MMD(µ_n, µ*) ≤ C_d (log n)^{(5d+1)/2} / n. This bounds the integral error for f ∈ H_k (by Cauchy-Schwarz): |∫_{R^d} f(x) dµ*(x) − ∫_{R^d} f(x) dµ(x)| ≤ ∥f∥_{H_k} MMD(µ, µ*). For the reverse KL (joint work with Tom Huix) we get, in the well-specified case, adapting the proof of [Li and Barron, 1999]: min_{µ_n} KL(k_ε ⋆ µ_n|µ*) ≤ C²_{µ*} (log(n) + 1)/n. This bounds the integral error for measurable f : R^d → [−1, 1] (by Pinsker's inequality): |∫ f d(k_ε ⋆ µ_n) − ∫ f dµ*| ≤ √(C²_{µ*}(log(n) + 1)/(2n)).
  48. Mixture of Gaussians. Langevin Monte Carlo on a mixture of Gaussians does not manage to target all modes in reasonable time, even in low dimensions. Picture from O. Chehab.
  49. Annealing. One possible fix: a sequence of tempered targets such as µ*_β ∝ µ_0^β (µ*)^{1−β}, β ∈ [0, 1]. This corresponds to a discretized Fisher-Rao gradient flow [Chopin et al., 2023].
  50. Other tempered path. "Convolutional path" (β ∈ [0, +∞[), frequently used in diffusion models: µ*_β = (1/√(1 − β)) µ_0(·/√(1 − β)) ∗ (1/√β) µ*(·/√β) (vs. the "geometric path" µ*_β ∝ µ_0^β (µ*)^{1−β}).