Upgrade to Pro — share decks privately, control downloads, hide ads and more …

何でも微分する

 何でも微分する

IBIS 2023 企画セッション『最適輸送』 https://ibisml.org/ibis2023/os/#os3 で発表した内容です。

講演概要: 最適輸送が機械学習コミュニティーで人気を博している要因として、最適輸送には微分可能な変種が存在することが挙げられる。微分可能な最適輸送は様々な機械学習モデルに構成要素として簡単に組み入れることができる点が便利である。本講演では、最適輸送の微分可能な変種とその求め方であるシンクホーンアルゴリズムを紹介する。また、この考え方を応用し、ソーティングなどの操作や他の最適化問題を微分可能にする方法を紹介するとともに、これらの微分可能な操作が機械学習においてどのように役立つかを議論する。

シンクホーンアルゴリズムのソースコード:https://colab.research.google.com/drive/1RrQhsS52B-Q8ZvBeo57vKVjAARI2SMwM?usp=sharing
点群の配置の最適化のソースコード:https://colab.research.google.com/drive/1u8lu0I7GwzR48BQqoGqOp2A_7mTHxzrk?usp=sharing
最短経路の最適化のソースコード:https://colab.research.google.com/drive/1yB_tcEA2OppiyaInzM1GKmGAlDKw6VNL?usp=sharing

『最適輸送の理論とアルゴリズム』:https://www.amazon.co.jp/dp/4065305144
『最適輸送の解き方』:https://speakerdeck.com/joisino/zui-shi-shu-song-nojie-kifang

連絡先: @joisino_ (Twitter) / https://joisino.net/

佐藤竜馬 (Ryoma Sato)

October 31, 2023
Tweet

More Decks by 佐藤竜馬 (Ryoma Sato)

Other Decks in Research

Transcript

  1. 3 KYOTO UNIVERSITY 様々な操作を微分し連続的に最適化する方法を学ぶ ◼ 様々な離散的な操作や最適化を微分する方法を学びます。 ⚫ 最適輸送 ⚫ ソート、ランキング

    ⚫ 最短経路問題 など ◼ これらが微分できるようになると ⚫ 輸送コストが最小となる配置を求める ⚫ 真のラベルが top-K に入る確率を最大化する を勾配法ベースの連続最適化により解くことができます。
  2. 4 KYOTO UNIVERSITY 最適輸送は重み付き点群を比較するツール ◼ 最適輸送:重み付き点群を輸送コストを基に比較するツール ◼ 入力: 𝑎 ∈

    ℝ+ 𝑛: 点群 A の各点の重み 𝑋 ∈ ℝ𝑛×𝑑: 点群 A の各点の位置 𝑏 ∈ ℝ+ 𝑚: 点群 B の各点の重み 𝑌 ∈ ℝ𝑚×𝑑: 点群 B の各点の位置 ◼ 出力: 𝑑 𝑎, 𝑋, 𝑏, 𝑌 ∈ ℝ: 点群 A と点群 B の距離(スカラー) 𝑃 𝑎, 𝑋, 𝑏, 𝑌 ∈ ℝ𝑛×m: 点群 A と点群 B の割り当て 例:ソースデータの集合 (一つの点が 1 データ) 例:ターゲットデータの集合 (一つの点が 1 データ) 例:ソースとターゲットの乖離度
  3. 5 KYOTO UNIVERSITY 最適輸送は最適な輸送における移動コストを測る ◼ 最適輸送のイメージ図 点群 A と点群 B

    の距離(違いの大きさ)を測りたい。 最適な輸送 このときの移動コストで距離を測る 最適でない輸送
  4. 6 KYOTO UNIVERSITY 最適輸送の入力例 ◼ 入力例: 𝑎 = 0.2, 0.3,

    0.4, 0.1 𝑏 = (0.1, 0.6, 0.3) 𝑎1 = 0.2 𝑏1 = 0.1 点の大きさ(重み) 𝑥1 = (1.5, 2.4) 𝑦1 = (1.8, 1.4) 位置
  5. 7 KYOTO UNIVERSITY 最適輸送の出力例 ◼ 出力例: 総移動コスト: 𝑃1,1 = 0.1

    𝑃1,2 = 0.1 𝑃2,2 = 0.3 𝑑 𝑎, 𝑋, 𝑏, 𝑌 = 0.83 輸送量
  6. 9 KYOTO UNIVERSITY 線形計画としての定式化 ◼ 最適輸送を最適化問題として定式化する minimize ෍ 𝑖=1 𝑛

    ෍ 𝑗=1 𝑚 𝑃𝑖𝑗 𝐶𝑖𝑗 𝑃 ∈ ℝ𝑛×𝑚 s.t. 𝑃𝑖𝑗 ≥ 0 ∀𝑖, 𝑗 ෍ 𝑗=1 𝑚 𝑃𝑖𝑗 = 𝑎𝑖 ∀𝑖 ෍ 𝑖=1 𝑛 𝑃𝑖𝑗 = 𝑏𝑗 ∀𝑗 輸送量は非負 点群 A の点 i から 出ていく量の合計は 𝑎𝑖 点群 B の点 j から 出ていく量の合計は 𝑏𝑖 は移動コストを並べた行列 例: 𝐶 ∈ ℝ𝑛×m 𝐶𝑖𝑗 = || 𝑥𝑖 − 𝑦𝑗 || 2 2 移動コストを 最小化する 移動方法 P を求める
  7. 10 KYOTO UNIVERSITY a, b, X, Y を入れると d, P

    が出てくる 𝑎 𝑏 𝑋 𝑌 𝐶 最適輸送問題 𝑃 𝑑
  8. 11 KYOTO UNIVERSITY 入力についての勾配を求めて最適化をしたい 𝑎 𝑏 𝑋 𝑌 𝐶 最適輸送問題

    𝑃 𝑑 損失 誤差逆伝播 近似誤差が最小となる サンプル重みづけを求めたい 輸送コストが最小となるような 点の配置を求めたい
  9. 12 KYOTO UNIVERSITY 入力についての勾配を求めて最適化をしたい 𝑎 𝑏 𝑋 𝑌 𝐶 最適輸送問題

    𝑃 𝑑 損失 誤差逆伝播 近似誤差が最小となる サンプル重みづけを求めたい 輸送コストが最小となるような 点の配置を求めたい
  10. 13 KYOTO UNIVERSITY 入力についての勾配を求めて最適化をしたい 𝑎 𝑏 𝑋 𝑌 𝐶 最適輸送問題

    𝑃 𝑑 損失 誤差逆伝播 近似誤差が最小となる サンプル重みづけを求めたい 輸送コストが最小となるような 点の配置を求めたい
  11. 14 KYOTO UNIVERSITY 正則化を追加して滑らかにする ◼ 悲報:最適輸送は微分できない 朗報:ちょっと変えればできる minimize ෍ 𝑖=1

    𝑛 ෍ 𝑗=1 𝑚 𝑃𝑖𝑗 𝐶𝑖𝑗 + 𝜀 ෍ 𝑖=1 𝑛 ෍ 𝑗=1 𝑚 𝑃𝑖𝑗 (log 𝑃𝑖𝑗 − 1) 𝑃 ∈ ℝ𝑛×𝑚 s.t. 𝑃𝑖𝑗 ≥ 0 ෍ 𝑗=1 𝑚 𝑃𝑖𝑗 = 𝑎𝑖 ෍ 𝑖=1 𝑛 𝑃𝑖𝑗 = 𝑏𝑗 問題を滑らかにするための エントロピー正則化項 一様に近い輸送を優遇する 𝜀 ∈ ℝ はハイパーパラメータ
  12. 15 KYOTO UNIVERSITY 正則化を追加して滑らかにする ◼ シンクホーンアルゴリズム:正則化付き最適輸送を解く 導出は 『最適輸送の理論とアルゴリズム』 第三章や 『最適輸送の解き方』

    p.198- を参照してください。 https://speakerdeck.com/joisino/zui-shi-shu-song-nojie-kifang?slide=198 超シンプル! K = np.exp(- C / eps) u = np.ones(n) for i in range(100): v = b / (K.T @ u) u = a / (K @ v) P = u.reshape(n, 1) * K * v.reshape(1, m) d = (C * P).sum()
  13. 16 KYOTO UNIVERSITY 線形計画解とシンクホーン解はほぼ同じ n = m = 4 n

    = m = 100 線形計画解 シンクホーン解 ほぼ同じ → 以降同一視する 行列 𝑃 ∈ ℝ𝑛×𝑚 の図示 https://colab.research.google.com/drive/1RrQhsS52B-Q8ZvBeo57vKVjAARI2SMwM?usp=sharing ソースコード
  14. 17 KYOTO UNIVERSITY 再掲:シンクホーンアルゴリズム ◼ シンクホーンアルゴリズム:正則化付き最適輸送を解く 超シンプル! K = np.exp(-

    C / eps) u = np.ones(n) for i in range(100): v = b / (K.T @ u) u = a / (K @ v) P = u.reshape(n, 1) * K * v.reshape(1, m) d = (C * P).sum()
  15. 18 KYOTO UNIVERSITY シンクホーンアルゴリズムは自動微分できる ◼ 四則計算と exp だけからなるので自動微分が可能 a.requires_grad =

    True K = torch.exp(- C / eps) u = torch.ones(n) for i in range(100): v = b / (K.T @ u) u = a / (K @ v) P = u.reshape(n, 1) * K * v.reshape(1, m) d = (C * P).sum() d.backward() print(a.grad)
  16. 19 KYOTO UNIVERSITY シンクホーンアルゴリズムは自動微分できる ◼ 他のニューラルネットワークと組み合わせてもオーケー C = net1(z) K

    = torch.exp(- C / eps) u = torch.ones(n) for i in range(100): v = b / (K.T @ u) u = a / (K @ v) P = u.reshape(n, 1) * K * v.reshape(1, m) d = (C * P).sum() loss = net2(P, d) loss.backward() 何かしらのニューラルネットワーク
  17. 20 KYOTO UNIVERSITY 自動微分を使って配置を最適化する例 ◼ 数値例:点群 A を点群 B に近づける

    パラメータは位置 X Adam で最適化 https://drive.google.com/file/d/19XNtttaSr-Kc8yfv1VKRz0O8dUpcxSZM/view?usp=sharing https://colab.research.google.com/drive/1u8lu0I7GwzR48BQqoGqOp2A_7mTHxzrk?usp=sharing 動画 ソースコード
  18. 21 KYOTO UNIVERSITY 応用例:転移学習 ◼ 予測誤差 + 赤と青の最適輸送コストを最小化 ◼ ニューラル

    ネットワーク 入力 予測ヘッド 予測 1600 サンプルの埋め込み 赤:シミュレーションデータについての埋め込み 青:本番環境データについての埋め込み
  19. 22 KYOTO UNIVERSITY ランキング問題を考える ◼ ランキング問題 ◼ 入力: 𝑥 ∈

    ℝ𝑛: 配列 出力: 𝑟 ∈ ℕ𝑛: ランク(𝑟𝑖 = 𝑘 ⇔ 𝑥𝑖 は k 番目に大きい) ◼ 入力例: 𝑥 = 6.2, 1.4, 1.5, 3.9, 2.2 出力例: 𝑟 = (1, 5, 4, 2, 3)
  20. 23 KYOTO UNIVERSITY 分類問題では正解率を最大化したい ◼ 分類問題において本当にやりたいことは正解率の最大化。 二値分類問題においてクラス 1 のデータの予測確率が (0.6,

    0.4) だろうが (0.99, 0.01) だろうが正解なら十分。 ◼ 正解率を最適化するのが難しいので、クロスエントロピーを使う ことが多い。 ◼ しかし、クロスエントロピーは (0.99, 0.01) を優遇する。 もう正解できているデータの損失を無駄に下げるために、 際どいデータが不正解に転じることがある。
  21. 24 KYOTO UNIVERSITY 正解率や top-K 正解率を直接最大化したい ニューラルネットワーク logit = 6.2,

    1.4, 1.5, 3.9, 2.2 ランキング 𝑟 = (1, 5, 4, 2, 3) 𝑦 = 4 𝑟𝑦 = 2 教師ラベル 「猫」の予測順位は 2 位 1 位にして正解率を上げるには? 誤差逆伝播 (をやりたい) 正解率や top-K 正解率 を直接最適化したい 例:豹、鳥、犬、猫、猿の五クラス分類
  22. 25 KYOTO UNIVERSITY ランキング問題は最適輸送の特殊例 ◼ ランキング問題は 𝑑 = 1 次元の最適輸送問題の特殊例

    → シンクホーンアルゴリズムで計算すればランクも微分できる 𝑥1 = 6.2 𝑥2 = 1.4 𝑥3 = 1.5 𝑥5 = 2.2 𝑥4 = 3.9 𝑦1 = 1 𝑦2 = 2 𝑦3 = 3 𝑦4 = 4 𝑦5 =5 𝑃 = 0 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 1 0 0 最も小さいものは 1 に、二番に小さいものは 2 に … と 輸送するのが最適 𝑟 = 𝑃 5 4 3 2 1 = 1 5 4 2 3 順列行列をランクに変換 y = 1, 2, … , n ⊤ 𝑥 は入力
  23. 26 KYOTO UNIVERSITY 正解率や top-K 正解率を直接最大化できる ニューラルネットワーク logit = 6.2,

    1.4, 1.5, 3.9, 2.2 シンクホーンによるランク計算 𝑟 = (1.01, 4.84, 4.13, 2.02, 3.04) 𝑦 = 4 𝑟𝑦 = 2.02 教師ラベル 「猫」の予測順位は 2.02 位 誤差逆伝播 正解率や top-K 正解率 を直接最適化できる
  24. 27 KYOTO UNIVERSITY ビームサーチなど、様々な過程全体を微分可能にできる ◼ 同様の考えから、ランキング・ソートなどを end-to-end 学習 パイプラインの中に組み込むことができる。 ◼

    言語モデルの訓練においてビームサーチを微分する [1] 訓練時は teacher forcing して、テスト時はビームサーチを することが多いが、これだと乖離が生じる。 ビームサーチの top-K をシンクホーンで計算し、ビームサーチの 過程全体を微分可能にする。これを使って訓練する。 [1] Xie et al. Differentiable Top-k with Optimal Transport. NeurIPS 2020. [2] Goyal et al. A continuous relaxation of beam search for end-to-end training of neural sequence models. AAAI 2018.
  25. 28 KYOTO UNIVERSITY ビームサーチなど、様々な過程全体を微分可能にできる ◼ シンクホーンアルゴリズムはブレグマン法の特殊例である [1] ◼ ブレグマン法は制約あり凸計画問題のアルゴリズム。 制約なしの解からはじめて、制約に射影していく。

    ◼ シンクホーンアルゴリズムは、P1 = 𝑎 と 𝑃⊤1 = 𝑏 に交互に 射影していくことに対応する。 ◼ 一般の線形計画もブレグマン法により、 シンクホーンアルゴリズムと同様の簡単な反復アルゴリズムで 解くことができ、これにより同様に微分ができる。 [1] Benamou et al. Iterative Bregman Projections for Regularized Transportation Problems. 2015. for i in range(100): v = b / (K.T @ u) u = a / (K @ v)
  26. 30 KYOTO UNIVERSITY 最短経路問題の数値例 ◼ 例1: 最短経路を長くして邪魔をする(※最短経路問題は線形計画) パラメータ:マスのコスト(総和は一定) Adam で最適化

    この例はマスコストを生パラメータとしているが、 生成モデルでマップ生成してモデルまで逆伝播なども可能 コスト (パラメータ) 最短経路 可視化 https://drive.google.com/file/d/1_eijS6R83nTcBOMzUM1QoR74Uk4S0qvw/view?usp=sharing https://colab.research.google.com/drive/1yB_tcEA2OppiyaInzM1GKmGAlDKw6VNL?usp=sharing 動画 コード
  27. 31 KYOTO UNIVERSITY 最短経路を最適化する問題を勾配法で解くことができる ◼ 例1: 最短経路を長くして邪魔をする(※最短経路問題は線形計画) パラメータ:マスのコスト(総和は一定) Adam で最適化

    ◼ 観察1:最短経路などの組合せ的な問題も行列積を用いた 反復法により解が求まる。 ◼ 観察2:最短経路を最適化するという 2 レベルの最適化問題も Adam などの勾配法ベースの連続最適化で解ける。
  28. 32 KYOTO UNIVERSITY 最短経路問題のその他の問題例 ◼ 例2: 教師あり最短経路問題 [1] ゲーム画面 ニューラル

    ネットワーク 推定コスト 真コスト(非観測) 最短経路 微分可能最短経路 推定最短経路 真経路(観測) 損失 誤差逆伝播 [1] Vlastelica et al. Differentiation of Blackbox Combinatorial Solvers. ICLR 2020.
  29. 33 KYOTO UNIVERSITY 離散的な操作を微分可能にすることができる ◼ 最適輸送、ランキング、最短経路問題などの微分可能版を 考えることができる。 ◼ シンクホーンアルゴリズムやブレグマン法で計算できる。 ◼

    「モデルの予測の順位」「モデルの出力を基にしたビームサーチの 結果」「モデルの出力を最短経路問題に入力した結果」 などの量を直接最適化することができる。 離散的な操作を微分可能にしてニューラルネットワークの end-to-end 最適化パイプラインに組み込むことができる
  30. 34 KYOTO UNIVERSITY 参考文献 ◼ Xie et al. Differentiable Top-k

    with Optimal Transport. NeurIPS 2020. ◼ Goyal et al. A continuous relaxation of beam search for end-to-end training of neural sequence models. AAAI 2018. ◼ Benamou et al. Iterative Bregman Projections for Regularized Transportation Problems. 2015. ◼ Vlastelica et al. Differentiation of Blackbox Combinatorial Solvers. ICLR 2020. ◼ Cuturi et al. Differentiable Ranks and Sorting using Optimal Transport. NeurIPS 2019. ◼ Blondel et al. Fast Differentiable Sorting and Ranking. ICML 2020. ◼ Berthet et al. Learning with Differentiable Perturbed Optimizers. NeurIPS 2020. ◼ Weed. An explicit analysis of the entropic penalty in linear programming. COLT 2018. ◼ 『最適輸送の解き方』 https://speakerdeck.com/joisino/zui-shi-shu-song-nojie-kifang ◼ 佐藤竜馬 『最適輸送の理論とアルゴリズム』