

Wasserstein gradient flows of Moreau envelopes of f-divergences in reproducing kernel Hilbert spaces

Most commonly used f-divergences of measures, e.g., the Kullback-Leibler divergence, are subject to limitations regarding the support of the involved measures. A remedy consists of regularizing the f-divergence by a squared maximum mean discrepancy (MMD) associated with a characteristic kernel K. In this paper, we use the so-called kernel mean embedding to show that the corresponding regularization can be rewritten as the Moreau envelope of some function in the reproducing kernel Hilbert space associated with K. Then, we exploit well-known results on Moreau envelopes in Hilbert spaces to prove properties of the MMD-regularized f-divergences and, in particular, their gradients. Subsequently, we use our findings to analyze Wasserstein gradient flows of MMD-regularized f-divergences. Finally, we consider Wasserstein gradient flows starting from empirical measures. We provide proof-of-concept numerical examples for f-divergences with both infinite and finite recession constant.

Joint work with Sebastian Neumayer, Gabriele Steidl, and Nicolaj Rux, see https://arxiv.org/abs/2402.04613.

You can view this talk here: https://www.youtube.com/watch?v=iuaQ1w4U-q8.

Viktor Stein

August 21, 2024

Transcript

  1. Wasserstein Gradient Flows of Moreau Envelopes of f-Divergences in Reproducing Kernel Hilbert Spaces
     Joint work with Sebastian Neumayer (TU Chemnitz), Gabriele Steidl (TU Berlin), and Nicolaj Rux (TU Berlin).
     UCLA Level Set Seminar (Stan Osher), 19.08.2024
  2. Goal. Recover ν ∈ P(R^d) from samples by minimizing the f-divergence D_{f,ν} to ν, e.g. KL(· | ν).
     Problem. We only have samples ⇝ empirical measures, but µ ̸≪ ν ⟹ D_{f,ν}(µ) = ∞.
     Our solution. Regularize D_{f,ν}: M(R^d) → [0, ∞] via
        "D_{f,ν} ∘ m^{-1}" = G_{f,ν}: H_K → [0, ∞],
        ^λG_{f,ν}(m(µ)) = min_{σ ∈ M_+(R^d)} D_{f,ν}(σ) + 1/(2λ) ‖m(σ) − m(µ)‖²_{H_K},   λ > 0,
     combining
     1. the "kernel trick" m: M(R^d) → H_K, µ ↦ ∫_{R^d} K(x, ·) dµ(x), and
     2. Moreau envelope regularization.
     We prove existence & uniqueness of W_2 gradient flows of (^λG_{f,ν}) ∘ m and simulate particle flows, i.e. W_2 gradient flows starting at an empirical measure.
  3. Literature review of prior work
     • KALE functional = MMD-regularized KL divergence [Glaser, Arbel, Gretton, NeurIPS '21]: no Moreau envelope interpretation.
     • Kernel methods of moments = f-divergence-regularized MMD [Kremer, Nemmour, Schölkopf, Zhu, ICML '23]: does not cover all f-divergences.
     • (f, Γ)-divergence = Pasch-Hausdorff envelope of f-divergences [Birrell, Dupuis, Katsoulakis, Pantazis, Rey-Bellet, JMLR '23]: yields only a Lipschitz, not a differentiable functional.
     • W_1-Moreau envelope of f-divergences [Terjék, ICML '21]: no RKHS, which is what makes the optimization finite-dimensional and hence tractable.
  4. 1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences
  5. Reproducing Kernel Hilbert Spaces
     "Kernel trick": embed data into a high-dimensional Hilbert space.
     K: R^d × R^d → R symmetric, positive definite. We consider radial kernels K(x, y) = ϕ(‖x − y‖²₂) with ϕ ∈ C^∞((0, ∞)) ∩ C²([0, ∞)) and (−1)^k ϕ^(k)(r) ≥ 0 for all k ∈ N, r > 0.
     ⇝ reproducing kernel Hilbert space (RKHS) H_K := closure of span{K(x, ·) : x ∈ R^d}. Key property: the point evaluations h ↦ h(x) are continuous.
     Fig. 1: "Kernel trick". Source: songcy.net/posts/story-of-basis-and-kernel-part-2/ (Plot of the example profiles ϕ.)
     Examples (with parameter s > 0):
     • Gaussian ϕ(r) = exp(−r/(2s))
     • inverse multiquadric ϕ(r) := (s + r)^{−1/2}
     • spline ϕ(r) = max(0, 1 − √r)^{s+2}.
     Non-examples:
     • Laplace ϕ(r) = exp(−√r/(2s)) (not smooth enough)
     • K(x, y) = ‖x‖ + ‖y‖ − ‖x − y‖ (not radial).
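
To make the kernel examples concrete, here is a minimal NumPy sketch of the three radial profiles ϕ listed above and of the resulting Gram matrix. The function names and the way the scale parameter s is passed are illustrative choices, not taken from the paper's code.

```python
import numpy as np

# Radial kernels K(x, y) = phi(||x - y||_2^2) from the slide; s > 0 is the scale parameter.

def phi_gauss(r, s=1.0):
    """Gaussian: phi(r) = exp(-r / (2s))."""
    return np.exp(-r / (2.0 * s))

def phi_imq(r, s=1.0):
    """Inverse multiquadric: phi(r) = (s + r)^(-1/2)."""
    return (s + r) ** (-0.5)

def phi_spline(r, s=1.0):
    """Spline: phi(r) = max(0, 1 - sqrt(r))^(s + 2)."""
    return np.maximum(0.0, 1.0 - np.sqrt(r)) ** (s + 2)

def kernel_matrix(phi, X, Y, s=1.0):
    """Gram matrix K[i, j] = phi(||X[i] - Y[j]||^2) for point clouds X (N, d), Y (M, d)."""
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return phi(sq_dists, s)

X = np.random.randn(5, 2)
K = kernel_matrix(phi_gauss, X, X)    # symmetric Gram matrix
print(np.allclose(K, K.T), np.all(np.linalg.eigvalsh(K) > 0))   # symmetry and positive definiteness
```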
  6. Kernel mean embedding and maximum mean discrepancy
     "Kernel trick for signed measures" µ ∈ M(R^d) (instead of points): kernel mean embedding (KME)
        m: M(R^d) → H_K,   µ ↦ ∫_{R^d} K(x, ·) dµ(x).
     (Diagram: the maps x ↦ δ_x, x ↦ K(x, ·) and m form a commuting triangle, since m(δ_x) = K(x, ·).)
     We require m to be injective (K "characteristic") ⇐⇒ H_K ⊂ C_0(R^d) dense.
     ⇝ Instead of measures, compare their embeddings in H_K: maximum mean discrepancy (MMD)
        d_K: M(R^d) × M(R^d) → [0, ∞),   (µ, ν) ↦ ‖m(µ − ν)‖_{H_K}.
     m injective ⟹ d_K is a metric, but (M(R^d), d_K) is not complete.
     Easy to evaluate, e.g. for discrete measures, since
        d_K(µ, ν)² = ∫_{R^d} ∫_{R^d} K(x, y) d(µ − ν)(x) d(µ − ν)(y)   for all µ, ν ∈ M(R^d).
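
Since d_K(µ, ν)² reduces to a double sum of kernel evaluations for discrete measures, it can be computed directly from samples. A minimal sketch for uniform empirical measures with a Gaussian kernel (the kernel and sample sizes are assumed, illustrative choices):

```python
import numpy as np

def gauss_kernel(X, Y, s=1.0):
    """K(x, y) = exp(-||x - y||^2 / (2s)) evaluated pairwise."""
    sq = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * s))

def mmd_squared(X, Y, s=1.0):
    """d_K(mu, nu)^2 for uniform empirical measures mu (deltas at rows of X) and nu (rows of Y):
    the double integral of K against (mu - nu) x (mu - nu) becomes three sample averages."""
    Kxx = gauss_kernel(X, X, s)
    Kyy = gauss_kernel(Y, Y, s)
    Kxy = gauss_kernel(X, Y, s)
    return Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()

X = np.random.randn(200, 2)            # samples from mu
Y = np.random.randn(300, 2) + 1.0      # samples from nu (shifted)
print(mmd_squared(X, Y))               # > 0; shrinks as the two clouds overlap
```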
  7. 1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences
  8. Regularization in Convex Analysis: Moreau envelopes
     Let (H, ⟨·,·⟩, ‖·‖) be a Hilbert space and f ∈ Γ₀(H), i.e. f: H → (−∞, ∞] convex, lower semicontinuous, with dom(f) := {x ∈ H : f(x) < ∞} ≠ ∅.
     For ε > 0, the ε-Moreau envelope of f,
        ^εf: H → R,   x ↦ min{ f(x′) + 1/(2ε) ‖x − x′‖² : x′ ∈ H },
     is a convex, differentiable regularization of f preserving its minimizers.
     Asymptotics: ^εf(x) ↗ f(x) for ε ↘ 0 and ^εf(x) ↘ inf(f) for ε → ∞.
     (ε, x) ↦ ^εf(x) is the viscosity solution of the Hamilton-Jacobi equation
        ∂_ε(^εf)(x) + 1/2 ‖∇(^εf)(x)‖²₂ = 0,   ^εf(x) → f(x) as ε ↘ 0,
     [Osher, Heaton, Fung, PNAS 120 (14), 2023].
     Figures: Moreau envelope of an extended-real-valued non-differentiable function (top) and of |·| for different ε (bottom). © Trygve U. Helgaker, Pontus Giselsson
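
As a sanity check of the definition, the ε-Moreau envelope of f = |·| on R has a closed (Huber) form, which a brute-force minimization over a grid reproduces. A minimal 1D sketch; the grid resolution and test points are arbitrary illustrative choices.

```python
import numpy as np

def moreau_envelope_abs(x, eps):
    """Closed form of the eps-Moreau envelope of f = |.| (the Huber function):
    x^2 / (2 eps) for |x| <= eps, and |x| - eps/2 otherwise."""
    return np.where(np.abs(x) <= eps, x**2 / (2 * eps), np.abs(x) - eps / 2)

def moreau_envelope_grid(f, x, eps, grid):
    """Brute-force minimum over a grid of the defining problem min_{x'} f(x') + |x - x'|^2 / (2 eps)."""
    return np.min(f(grid) + (x - grid) ** 2 / (2 * eps))

grid = np.linspace(-5, 5, 20001)
x = 0.3
for eps in (0.1, 1.0):
    print(moreau_envelope_abs(x, eps), moreau_envelope_grid(np.abs, x, eps, grid))
# The envelope increases to |x| as eps -> 0 and decreases to inf f = 0 as eps -> infinity.
```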
  9. 1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences
  10. Entropy functions
      We consider f ∈ Γ₀(R) with f|_(−∞,0) ≡ ∞, with unique minimizer at 1, f(1) = 0, and positive recession constant f′_∞ := lim_{t→∞} f(t)/t > 0.
      Examples. f_KL(x) := x ln(x) − x + 1 for x ≥ 0 yields the Kullback-Leibler divergence, and f_α(x) := 1/(α−1) (x^α − αx + α − 1) yields the Tsallis-α divergence T_α for α > 0. In the limit: T_1 = KL.
      Figures. Left: examples of entropy functions (x ln(x) − x + 1, |x − 1|, (x − 1) ln(x), x ln(x) − (x + 1) ln((x+1)/2), max(0, 1 − x)²), except the red one. Right: the functions f_α for α ∈ [0.1, 2.5].
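
A short numerical illustration of the two example entropy functions, also checking that f_α approaches f_KL as α → 1, i.e. T_1 = KL. The value α = 1.0001 and the test points are arbitrary illustrative choices.

```python
import numpy as np

def f_kl(x):
    """Entropy function of the Kullback-Leibler divergence: x ln x - x + 1 (with 0 ln 0 = 0)."""
    x = np.asarray(x, dtype=float)
    return np.where(x > 0, x * np.log(np.where(x > 0, x, 1.0)) - x + 1.0, 1.0)

def f_tsallis(x, alpha):
    """Entropy function of the Tsallis-alpha divergence: (x^alpha - alpha x + alpha - 1) / (alpha - 1)."""
    return (x ** alpha - alpha * x + alpha - 1.0) / (alpha - 1.0)

x = np.linspace(0.01, 3.0, 5)
print(f_kl(x))
print(f_tsallis(x, 1.0001))   # close to f_kl: T_1 = KL in the limit alpha -> 1
# Recession constant f'_inf = lim f(t)/t: infinite for f_kl, finite for f_alpha with alpha < 1.
```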
  11. f-divergences
      The f-divergence of µ = ρν + µ_s ∈ M_+(R^d) (unique Lebesgue decomposition) to ν ∈ M_+(R^d) is
         D_{f,ν}(ρν + µ_s) := ∫_{R^d} f ∘ ρ dν + f′_∞ · µ_s(R^d)   (with ∞ · 0 := 0)
                            = sup_{h ∈ C_b(R^d; dom(f*))} E_µ[h] − E_ν[f* ∘ h],   where E_σ[h] := ∫_{R^d} h(x) dσ(x).
      The convex conjugate of f is f*: R → (−∞, ∞], s ↦ sup{ st − f(t) : t ≥ 0 }.
      Theorem (Properties of D_{f,ν}). D_{f,ν}: M_+(R^d) → [0, ∞] is convex and weak* lower semicontinuous. We have: D_{f,ν}(µ) = 0 ⇐⇒ µ = ν.
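
For measures supported on a common finite set, the Lebesgue decomposition is explicit and D_{f,ν} becomes a finite sum plus a recession term for the µ-mass outside the support of ν. A minimal sketch; the weight vectors and the choice of the KL entropy function are illustrative.

```python
import numpy as np

def f_divergence(mu, nu, f, f_recession):
    """D_{f,nu}(mu) for nonnegative weight vectors mu, nu on a common finite support:
    sum of f(mu_i / nu_i) * nu_i over {nu_i > 0}, plus f'_inf times the mu-mass where nu_i = 0."""
    mu, nu = np.asarray(mu, float), np.asarray(nu, float)
    ac = nu > 0                                     # absolutely continuous part of mu w.r.t. nu
    val = np.sum(f(mu[ac] / nu[ac]) * nu[ac])
    singular_mass = np.sum(mu[~ac])
    if singular_mass > 0:
        val += f_recession * singular_mass          # = +inf for e.g. KL (infinite recession constant)
    return val

f_kl = lambda x: np.where(x > 0, x * np.log(np.where(x > 0, x, 1.0)) - x + 1.0, 1.0)

nu = np.array([0.5, 0.5, 0.0])
mu = np.array([0.3, 0.3, 0.4])                      # puts mass where nu has none
print(f_divergence(mu, nu, f_kl, np.inf))           # inf: mu is not absolutely continuous w.r.t. nu
print(f_divergence(np.array([0.5, 0.5, 0.0]), nu, f_kl, np.inf))   # 0: here mu = nu
```

This is exactly the problem from slide 2: for empirical measures with disjoint supports the divergence is infinite, which motivates the MMD regularization.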
  12. MMD-Regularized f-divergence: Moreau envelope interpretation
      We define the MMD-regularized f-divergence functional
         D^λ_{f,ν}(µ) := min{ D_{f,ν}(σ) + 1/(2λ) d_K(µ, σ)² : σ ∈ M(R^d) },   λ > 0, µ ∈ M(R^d).   (1)
      Theorem (Moreau envelope interpretation of D^λ_{f,ν} [NSSR24]). The H_K-extension of D_{f,ν},
         G_{f,ν}: H_K → [0, ∞],   h ↦ D_{f,ν}(µ) if there exists µ ∈ M_+(R^d) with h = m(µ), and ∞ else,
      is convex and lower semicontinuous, and its Moreau envelope composed with m is the MMD-regularized f-divergence:
         ^λG_{f,ν} ∘ m = D^λ_{f,ν}.
      (Diagram: D_{f,ν} and D^λ_{f,ν} on M(R^d) factor through m as G_{f,ν} and ^λG_{f,ν} on H_K.)
  13. Properties of D^λ_{f,ν}
      Theorem (Properties of D^λ_{f,ν}) [NSSR24]
      • Dual formulation:
           D^λ_{f,ν}(µ) = max{ E_µ[p] − E_ν[f* ∘ p] − λ/2 ‖p‖²_{H_K} : p ∈ H_K, p ≤ f′_∞ }.   (2)
        p̂ ∈ H_K maximizes (2) ⇐⇒ ĝ = m(µ) − λp̂ is the primal solution.
        λ/2 ‖p̂‖²_{H_K} ≤ D^λ_{f,ν}(µ) ≤ ‖p̂‖_{H_K} (‖m_µ‖_{H_K} + ‖m_ν‖_{H_K}) and ‖p̂‖_{H_K} ≤ 2/λ d_K(µ, ν).
      • D^λ_{f,ν} is Fréchet differentiable on M(R^d) and its gradient is λ-Lipschitz with respect to d_K:
           ∇D^λ_{f,ν}(µ) = argmax (2).
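
The dual problem (2) can be evaluated numerically once p is parametrized. The sketch below does this for the KL case, where f*(s) = e^s − 1 and the constraint p ≤ f′_∞ = ∞ is vacuous, restricting p to the span of kernel functions at the sample points; this parametrization is a simplification suggested by the representer-type theorem on slide 19, not the paper's implementation. Any feasible p gives a lower bound on D^λ_{f,ν}(µ).

```python
import numpy as np

def gauss_K(X, Y, s=1.0):
    return np.exp(-np.sum((X[:, None, :] - Y[None, :, :]) ** 2, -1) / (2 * s))

def dual_objective(c, X, Y, lam, s=1.0):
    """Value of (2) for KL (f*(s) = e^s - 1) at p = sum_k c_k K(z_k, .),
    where z_1, ..., z_L are the samples of mu (rows of X) and nu (rows of Y) stacked."""
    Z = np.vstack([X, Y])
    Kzz = gauss_K(Z, Z, s)
    p_at_X = gauss_K(X, Z, s) @ c          # p evaluated at the mu-samples
    p_at_Y = gauss_K(Y, Z, s) @ c          # p evaluated at the nu-samples
    return p_at_X.mean() - np.mean(np.exp(p_at_Y) - 1.0) - 0.5 * lam * c @ Kzz @ c

X = np.random.randn(50, 2)                 # samples of mu
Y = np.random.randn(60, 2) + 1.0           # samples of nu
c = np.zeros(110)
print(dual_objective(c, X, Y, lam=0.1))    # = 0 at p = 0; maximizing over c lower-bounds D^lambda_{f,nu}(mu)
```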
  14. Theorem (Properties of D^λ_{f,ν}) [NSSR24]
      • Asymptotic regimes: Mosco resp. pointwise convergence (if 0 ∈ int(dom(f*)) resp. f* differentiable at 0):
           D^λ_{f,ν} → D_{f,ν} as λ ↘ 0   and   (1 + λ) D^λ_{f,ν} → 1/2 d_K(·, ν)² as λ → ∞.
      • Divergence property: D^λ_{f,ν}(µ) = 0 ⇐⇒ µ = ν.
      • If f* is differentiable at 0, then (µ, ν) ↦ D^λ_{f,ν}(µ) metrizes weak convergence on balls in M_+(R^d).
  15. 1. RKHS & MMD  2. Moreau envelopes  3. f-divergences  4. MMD-Moreau envelopes of f-divergences  5. Wasserstein gradient flow  6. WGF of MMD-Moreau envelopes of f-divergences
  16. Wasserstein space and generalized geodesics
      P_2(R^d) := {µ ∈ P(R^d) : ∫_{R^d} ‖x‖²₂ dµ(x) < ∞},   ‖·‖₂ the Euclidean norm.
      W_2(µ, ν)² = min_{π ∈ Γ(µ,ν)} ∫_{R^d × R^d} ‖x − y‖²₂ dπ(x, y),   µ, ν ∈ P_2(R^d).
      Fig. 2: Vertical (L²) vs. horizontal (W_2) mass displacement. © A. Korba
      Fig. 3: Generalized geodesic from µ_2 to µ_3 with base µ_1 [AGS08].
      Definition (Generalized geodesic convexity). A function F: P_2(R^d) → (−∞, ∞] is M-convex along generalized geodesics if, for every σ, µ, ν ∈ dom(F), there exists α ∈ P_2(R^{3d}) with (P_{1,2})_# α ∈ Γ_opt(σ, µ) and (P_{1,3})_# α ∈ Γ_opt(σ, ν) such that
         F(((1−t)P_2 + tP_3)_# α) ≤ (1−t) F(µ) + t F(ν) − M/2 t(1−t) ∫_{R^d × R^d × R^d} ‖y − z‖²₂ dα(x, y, z)   for all t ∈ [0, 1].
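
For two empirical measures, W_2² is the value of a finite linear program over couplings, which can be solved directly for small point clouds. A minimal sketch with scipy's LP solver; uniform weights are assumed, and dedicated OT libraries would be the practical choice for anything larger.

```python
import numpy as np
from scipy.optimize import linprog

def w2_squared(X, Y):
    """W_2(mu, nu)^2 for uniform empirical measures on the rows of X (N, d) and Y (M, d),
    solving the Kantorovich linear program min_pi <C, pi> over couplings pi."""
    N, M = len(X), len(Y)
    C = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1).ravel()   # squared costs, row-major
    # Marginal constraints: row sums = 1/N, column sums = 1/M.
    A_eq = np.zeros((N + M, N * M))
    for i in range(N):
        A_eq[i, i * M:(i + 1) * M] = 1.0
    for j in range(M):
        A_eq[N + j, j::M] = 1.0
    b_eq = np.concatenate([np.full(N, 1.0 / N), np.full(M, 1.0 / M)])
    res = linprog(C, A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return res.fun

X = np.random.randn(20, 2)
print(w2_squared(X, X + np.array([3.0, 0.0])))   # approx 9: a pure translation costs ||shift||^2
```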
  17. Wasserstein gradient flows
      Definition (Fréchet subdifferential in Wasserstein space). The (reduced) Fréchet subdifferential of F: P_2(R^d) → (−∞, ∞] at µ ∈ dom(F) is
         ∂F(µ) := { ξ ∈ L²(R^d; µ) : F(ν) − F(µ) ≥ inf_{π ∈ Γ_opt(µ,ν)} ∫_{R^d × R^d} ⟨ξ(x₁), x₂ − x₁⟩ dπ(x₁, x₂) + o(W_2(µ, ν)) }.
      A curve γ: (0, ∞) → P_2(R^d) is absolutely continuous if there exists an L²-Borel velocity field v: R^d × (0, ∞) → R^d such that
         ∂_t γ_t + ∇ · (v_t γ_t) = 0   on (0, ∞) × R^d, weakly.   (Continuity equation)
      Definition (Wasserstein gradient flow). A locally absolutely continuous curve γ: (0, ∞) → P_2(R^d) with velocity field v_t ∈ T_{γ_t} P_2(R^d) is a Wasserstein gradient flow with respect to F: P_2(R^d) → (−∞, ∞] if
         v_t ∈ −∂F(γ_t)   for a.e. t > 0.
      (Illustration © Petr Mokrov.)
  18. Wasserstein Gradient Flow with respect to D^λ_{f,ν}
      Theorem (Convexity and gradient of D^λ_{f,ν} [NSSR24]). Since K is radial and smooth, D^λ_{f,ν} is M-convex along generalized geodesics with M := −8 λ^{−1} (d + 2) ϕ″(0) ϕ(0), and its (reduced) Fréchet subdifferential is ∂D^λ_{f,ν}(µ) = {∇ argmax (2)}.
      Remark. M seems non-optimal: for λ → 0 we have D^λ_{f,ν} → D_{f,ν}, which is 0-convex, yet M → −∞.
      Corollary. There exists a unique Wasserstein gradient flow (γ_t)_{t>0} of D^λ_{f,ν} starting at µ_0 ∈ P_2(R^d), fulfilling the continuity equation
         ∂_t γ_t = ∇ · (γ_t ∂D^λ_{f,ν}(γ_t)),   γ_0 = µ_0.
      Lemma (Particle flows are W_2 gradient flows). If µ_0 is empirical, then so is γ_t for all t > 0.
  19. Numerical Experiments - Particle Descent Algorithm
      Take i.i.d. samples (x^(0)_j)_{j=1}^N ∼ µ_0 and (y_j)_{j=1}^M ∼ ν. Forward Euler discretization in time with step size τ > 0 yields
         γ^{n+1} := (id − τ∇p̂_n)_# γ^n,   p̂_n = argmax in (2) for D^λ_{f,ν}(γ^n),
      so γ^n = 1/N Σ_{j=1}^N δ_{x^(n)_j} with gradient step
         x^(n+1)_j = x^(n)_j − τ ∇p̂_n(x^(n)_j),   j ∈ {1, …, N}, n ∈ N.
      Theorem (Representer-type theorem [NSSR24]). If f′_∞ = ∞ or if λ > 2 d_K(γ^n, ν) ϕ(0) / f′_∞, then finding p̂_n is a finite-dimensional strongly convex problem.
      To find p̂_n, we use L-BFGS-B, a quasi-Newton method. We use an annealing strategy for λ if f′_∞ < ∞.
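
Below is a simplified, self-contained sketch of this particle descent loop for the KL case (f′_∞ = ∞, so neither the constraint in (2) nor annealing is needed): at each step the dual (2) is maximized with L-BFGS-B over coefficients of p in the span of kernel functions at the current particles and target samples, and the particles are moved by an explicit Euler step x ← x − τ∇p̂(x). This is not the authors' implementation; the Gaussian kernel, its scale, the step size, λ and the iteration count are illustrative assumptions and may need tuning.

```python
import numpy as np
from scipy.optimize import minimize

s = 2.0                                      # Gaussian kernel scale (illustrative)

def K_mat(X, Y):
    """Gram matrix of the Gaussian kernel K(x, y) = exp(-||x - y||^2 / (2s))."""
    return np.exp(-np.sum((X[:, None, :] - Y[None, :, :]) ** 2, -1) / (2 * s))

def solve_dual(X, Y, lam):
    """Maximize the dual (2) for KL (f*(t) = e^t - 1), with p = sum_k c_k K(z_k, .)
    and z_1, ..., z_L the stacked current particles X and target samples Y."""
    Z = np.vstack([X, Y])
    Kxz, Kyz, Kzz = K_mat(X, Z), K_mat(Y, Z), K_mat(Z, Z)
    N, M = len(X), len(Y)

    def neg_dual(c):
        pX, pY = Kxz @ c, Kyz @ c
        val = pX.mean() - np.mean(np.exp(pY) - 1.0) - 0.5 * lam * c @ Kzz @ c
        grad = Kxz.T @ np.full(N, 1.0 / N) - Kyz.T @ (np.exp(pY) / M) - lam * Kzz @ c
        return -val, -grad                   # minimize the negative dual

    res = minimize(neg_dual, np.zeros(len(Z)), jac=True, method="L-BFGS-B")
    return Z, res.x

def grad_p(X_query, Z, c):
    """Gradient of p = sum_k c_k K(z_k, .) at each query point (Gaussian kernel)."""
    Kqz = K_mat(X_query, Z)                            # (Q, L)
    diff = Z[None, :, :] - X_query[:, None, :]         # (Q, L, d), entries z_k - x_q
    return np.einsum("ql,qld->qd", Kqz * c[None, :], diff) / s

# Particle descent: explicit Euler step x <- x - tau * grad p_hat(x).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))                                  # particles ~ mu_0
Y = rng.normal(size=(150, 2)) + np.array([2.0, 0.0])           # target samples ~ nu
lam, tau = 0.05, 0.02
for n in range(300):
    Z, c = solve_dual(X, Y, lam)
    X = X - tau * grad_p(X, Z, c)
print(X.mean(axis=0))   # compare with the target mean (about [2, 0]); with suitable lam, tau the cloud drifts toward nu
```

For f-divergences with finite recession constant the dual maximization would additionally need the constraint p ≤ f′_∞ (and, per the slide, an annealing schedule for λ), which this sketch omits.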
  20. Numerical experiments
      Fig. 4: IMQ kernel, λ = 1/100, τ = 1/1000. Top: Tsallis-3 divergence; bottom: Tsallis-1/2 divergence, with annealing.
      Fig. 5: When the number of starting particles N is less than the number of target samples M ⇝ quantization.
  21. Further work
      • Non-differentiable (e.g. Laplace = 1/2-Matérn) and unbounded (e.g. Riesz, Coulomb) kernels.
      • Convergence rates in a suitable metric.
      • Prove consistency bounds [Leclerc, Mérigot, Santambrogio, Stra 2020] and better M-convexity estimates.
      • Convergence of the annealing strategy?
      • Different domains, e.g. compact subsets of R^d (manifolds like the sphere or torus), groups, infinite-dimensional spaces.
      • Regularize other divergences, e.g. Rényi divergences, Bregman divergences.
      • Gradient flows of D^λ_{f,ν} with respect to other metrics, like Hellinger-Kantorovich (related to unbalanced OT), MMD, Fisher-Rao, or Wasserstein-p for p ∈ [1, ∞].
      • More elaborate time discretizations, variable step sizes.
  22. Conclusion
      • We created a novel objective; minimizing it allows sampling from a target measure of which only samples are known.
      • Clear, rigorous interpretation using convex analysis and RKHS theory.
      • The theory covers (almost) all f-divergences.
      • Best of both worlds: D^λ_{f,ν} interpolates between D_{f,ν} and d_K(·, ν)².
      • Effective algorithms due to a (modified) representer theorem & GPU / PyTorch.
  23. Thank you for your attention! I am happy to take any questions.
      Paper link: arxiv.org/abs/2402.04613
      My website: viktorajstein.github.io
      [AGS08, BDK+22, GAG21, HWAH24, KYSZ23, LMSS20, LMS17, Ter21]
  24. References I
      [AGS08] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré, Gradient flows: in metric spaces and in the space of probability measures, 2nd ed., Springer Science & Business Media, 2008.
      [BDK+22] Jeremiah Birrell, Paul Dupuis, Markos A. Katsoulakis, Yannis Pantazis, and Luc Rey-Bellet, (f, Γ)-divergences: Interpolating between f-divergences and integral probability metrics, J. Mach. Learn. Res. 23 (2022), no. 39, 1–70.
      [GAG21] Pierre Glaser, Michael Arbel, and Arthur Gretton, KALE flow: A relaxed KL gradient flow for probabilities with disjoint support, Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8018–8031.
      [HWAH24] J. Hertrich, C. Wald, F. Altekrüger, and P. Hagemann, Generative sliced MMD flows with Riesz kernels, International Conference on Learning Representations (ICLR), 2024.
  25. References II
      [KYSZ23] H. Kremer, Y. Nemmour, B. Schölkopf, and J.-J. Zhu, Estimation beyond data reweighting: kernel methods of moments, Proceedings of the 40th International Conference on Machine Learning (ICML), PMLR, vol. 202, 2023, pp. 17745–17783.
      [LMS17] Matthias Liero, Alexander Mielke, and Giuseppe Savaré, Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures, Invent. Math. 211 (2017), no. 3, 969–1117.
      [LMSS20] Hugo Leclerc, Quentin Mérigot, Filippo Santambrogio, and Federico Stra, Lagrangian discretization of crowd motion and linear diffusion, SIAM J. Numer. Anal. 58 (2020), no. 4, 2093–2118.
      [Ter21] Dávid Terjék, Moreau-Yosida f-divergences, International Conference on Machine Learning (ICML), PMLR, 2021, pp. 10214–10224.
  26. Shameless plug: other works
      Interpolating between OT and KL-regularized OT using Rényi divergences. The Rényi divergence (neither an f-divergence nor a Bregman divergence), for α ∈ (0, 1), is
         R_α(µ | ν) := 1/(α − 1) ln ∫_X (dµ/dτ)^α (dν/dτ)^{1−α} dτ,
      and
         OT_{ε,α}(µ, ν) := min_{π ∈ Π(µ,ν)} ⟨c, π⟩ + ε R_α(π | µ ⊗ ν)
      is a metric, where ε > 0, µ, ν ∈ P(X), X compact. Limits:
         OT_{ε,α}(µ, ν) → OT(µ, ν) as α ↘ 0 or ε → 0,   OT_{ε,α}(µ, ν) → OT^KL_ε(µ, ν) as α ↗ 1.
      In the works: debiased Rényi-Sinkhorn divergence OT_{ε,α}(µ, ν) − 1/2 OT_{ε,α}(µ, µ) − 1/2 OT_{ε,α}(ν, ν).
      W_2 gradient flows of d_K(·, ν)² with K(x, y) := −|x − y| in 1D. Reformulation as a maximal monotone inclusion Cauchy problem in L²(0, 1) via quantile functions. Comprehensive description of the solutions' behavior, instantaneous measure-to-L^∞ regularization; implicit Euler is simple.
      (Figures: initial measure µ_0; iteration 0 with initial, target, explicit and implicit schemes.)
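
For discrete measures on a common finite support (taking τ to be the counting measure), R_α reduces to a single weighted power sum. A minimal sketch, also checking that the α ↗ 1 limit recovers KL; the weight vectors are illustrative.

```python
import numpy as np

def renyi(mu, nu, alpha):
    """R_alpha(mu | nu) = 1/(alpha - 1) * ln sum_i mu_i^alpha * nu_i^(1 - alpha)
    for probability vectors mu, nu with full common support and alpha in (0, 1)."""
    return np.log(np.sum(mu ** alpha * nu ** (1.0 - alpha))) / (alpha - 1.0)

mu = np.array([0.2, 0.5, 0.3])
nu = np.array([0.4, 0.4, 0.2])
print(renyi(mu, nu, 0.5))
print(renyi(mu, nu, 0.999), np.sum(mu * np.log(mu / nu)))   # alpha -> 1 recovers KL(mu | nu)
```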