Upgrade to PRO for Only $50/Year—Limited-Time Offer! 🔥

Julia Tokyo #11 トーク: 「Juliaで歩く自動微分」

abap34
February 03, 2024

Julia Tokyo #11 トーク: 「Juliaで歩く自動微分」

abap34

February 03, 2024
Tweet

More Decks by abap34

Other Decks in Research

Transcript

  1. Juliaのこんなところが気に入ってます! 1. 綺麗な可視化・ベンチマークライブラリ i. Plotまわり, @code_... マクロ, BenchmarkTools.jl たち 2.

    パッケージ管理ツール i. 言語同梱、仮想環境すぐ作れる、パッケージ化簡単 3. すぐ書ける すぐ動く i. Jupyter サポート, 強力な REPL 4. 速い!! i. 速い正義 ii. 裏が速いライブラリの「芸人」にならなくても、素直に書いてそのまま速い 自己紹介 Introduction 4 / 142
  2. [1] 微分と連続最適化 1.1 微分のおさらい 1.2 勾配降下法 1.3 勾配降下法と機械学習 [2] 自動で微分

    2.1 自動微分の枠組み 2.2 数式微分 ─式の表現と微分 2.3 自動微分 ─式からアルゴリズムへ [3] Juliaに微分させる 3.1 FiniteDiff.jl/FiniteDifferences.jl 3.1 ForwardDiff.jl 3.2 Zygote.jl [4] 付録 おしながき Introduction 10 / 142
  3. 偏 • 微分の定義 from 大学数学 について偏微分 以外の変数を固定して微分 [定義2. 偏微分係数] 変数関数

    の の に関する偏微分係数 微分の定義を振り返る ~大学 1.1 微分のおさらい 15 / 142
  4. 勾配降下法 元ネタ: Ilya Pavlyukevich, "Levy flights, non-local search and simulated

    annealing", Journal of Computational Physics 226 (2007) 1830-1844. 25 / 142
  5. さっきは ⇨ 頑張って手で を求められた 深層学習の複雑なモデル...   ⇩ とてもつらい. 勾配の計算法を考える 画像:

    He, K., Zhang, X., Ren, S., & Sun, J. (2015). Deep Residual Learning for Image Recognition. ArXiv. /abs/1512.03385 31 / 142
  6. アイデア1. 近似によって求める? ⇨ 実際に小さい をとって計算する. function diff(f, x; h=1e-8) return

    (f(x + h) - f(x)) / h end 勾配の計算法を考える ~近似編 1.3 勾配降下法と機械学習 32 / 142
  7. これでもそれなりに近い値を得られる. 例) の における微分係数 を求める. julia> function diff(f, x; h=1e-8)

    return (f(x + h) - f(x)) / h end diff (generic function with 1 method) julia> diff(x -> x^2, 2) # おしい! 3.999999975690116 勾配の計算法を考える ~近似編 1.3 勾配降下法と機械学習 33 / 142
  8. 問題点①. 誤差が出る 1. 本来極限をとるのに、小さい を とって計算しているので誤差が出る 2. 分子が極めて近い値同士の引き算に なっていて、 桁落ちによって精度が大幅に悪化.

    問題点②. 勾配ベクトルの計算が非効率 1. 変数関数の勾配ベクトル を計算するには、 各 について「少し動かす→計算」 を繰り返すので 回 を評価する. 2. 応用では がとても大きくなり、 の評価が重くなりがちで これが 致 命的 --> 数値微分 1.3 勾配降下法と機械学習 35 / 142
  9. [定義. 自動微分] (数学的な関数を表すように定義された) 計算機上のアルゴリズム • • • • • •

    • • • • • を入力とし, その関数の任意の点の微分係数を無限精度の計算モデル上で正確に計算できる 計算機上のアルゴリズム • • • • • • • • • • • を出力するアルゴリズムを 「自動微分(Auto Differentiation, Algorithmic Differentiation)」 と呼ぶ。 「自動微分」 2.1 自動微分の枠組み 44 / 142
  10. 2. アルゴリズム アルゴリズム by 自動微分ライブラリ using AutoDiffLib # ※ 存在しないです!

    function f(x::InfinityPrecisionFloat) return x^2 end df = AutoDiffLib.differentiate(f) df(2.0) # 4.0 df(3.0) # 6.0 例: 二次関数の微分 49 / 142
  11. [定義. 自動微分] (数学的な関数を表すように定義された) 計算機上のアルゴリズム • • • • • •

    • • • • • を入力とし, その関数の任意の点の微分係数を無限精度の計算モデル上で正確に計算できる 計算機上のアルゴリズム • • • • • • • • • • • を出力するアルゴリズムを 「自動微分(Auto Differentiation, Algorithmic Differentiation)」 と呼ぶ。 数式微分 2.2 数式微分 -式の表現と微分 53 / 142
  12. + * * 1 2 ^ x 2 3 x

    単純・解析しやすい表現 ... 式をそのまま木で表す 数式微分のアイデア 55 / 142
  13. Julia なら 簡単に式の木構造による表現を得られる. julia> f = :(4x + 3) #

    or Meta.parse("4x + 3") :(4x + 3) julia> dump(f) Expr head: Symbol call args: Array{Any}((3,)) 1: Symbol + 2: Expr head: Symbol call args: Array{Any}((3,)) 1: Symbol * 2: Int64 4 3: Symbol x 3: Int64 3 Expr 型 2.2 数式微分 -式の表現と微分 58 / 142
  14. 3. 足し算に関する微分 function derivative(ex::Expr)::Expr op = ex.args[1] if op ==

    :+ return Expr( :call, :+, derivative(ex.args[2]), derivative(ex.args[3]) ) end end 数式微分の実装 2.2 数式微分 -式の表現と微分 62 / 142
  15. 4. 掛け算に関する微分 function derivative(ex::Expr)::Expr op = ex.args[1] if op ==

    :+ ... elseif op == :* return Expr( :call, :+, Expr(:call, :*, ex.args[2], derivative(ex.args[3])), Expr(:call, :*, derivative(ex.args[2]), ex.args[3]) ) end end 数式微分の実装 2.2 数式微分 -式の表現と微分 63 / 142
  16. derivative(ex::Symbol) = 1 # dx/dx = 1 derivative(ex::Int64) = 0

    # 定数の微分は 0 function derivative(ex::Expr)::Expr op = ex.args[1] if op == :+ return Expr(:call, :+, derivative(ex.args[2]), derivative(ex.args[3])) elseif op == :* return Expr( :call, :+, Expr(:call, :*, ex.args[2], derivative(ex.args[3])), Expr(:call, :*, derivative(ex.args[2]), ex.args[3]) ) end end 数式微分の実装 2.2 数式微分 -式の表現と微分 ※ Juliaは 2 * x * x のような式を、 (2 * x) * x でなく *(2, x, x) として表現するのでこのような式については上は正しい結果を返しません. (スペースが足りませんでした) このあたりもちゃんとやるやつは付録のソースコードを見てください. 基本的には二項演算の合成とみて順にやっていくだけで良いです。 64 / 142
  17. 例) の導関数 を求めて での微分係数を計算 julia> f = :(x * x

    + 3) :(x * x + 3) julia> df = derivative(f) :((x * 1 + 1x) + 0) julia> x = 2; eval(df) 4 julia> x = 10; eval(df) 20 数式微分の実装 2.2 数式微分 -式の表現と微分 65 / 142
  18. df = ((x * 1 + 1x) + 0) ...

    2x にはなっているが冗長? 数式微分の改良 ~ 複雑な表現 2.2 数式微分 -式の表現と微分 66 / 142
  19. 自明な式の簡約を行ってみる 足し算の引数から 0 を除く. 掛け算の引数から 1 を除く. function add(args) args

    = filter(x -> x != 0, args) if length(args) == 0 return 0 elseif length(args) == 1 return args[1] else return Expr(:call, :+, args...) end end 簡約化 2.2 数式微分 -式の表現と微分 67 / 142
  20. 掛け算の引数から 1 を取り除く. function mul(args) args = filter(x -> x

    != 1, args) if length(args) == 0 return 1 elseif length(args) == 1 return args[1] else return Expr(:call, :*, args...) end end 簡約化 2.2 数式微分 -式の表現と微分 68 / 142
  21. 数式微分 + 自明な簡約 derivative(ex::Symbol) = 1 derivative(ex::Int64) = 0 function

    derivative(ex::Expr) op = ex.args[1] if op == :+ return add([derivative(ex.args[2]), derivative(ex.args[3])]) elseif op == :* return add([ mul([ex.args[2], derivative(ex.args[3])]), mul([derivative(ex.args[2]), ex.args[3]]) ]) end end 簡約化 2.2 数式微分 -式の表現と微分 69 / 142
  22.  簡単な式を得られた julia> derivative(:(x * x + 3)) :(x + x)

    ⇨ ではこれでうまくいく? julia> derivative(:((1 + x) / (2 * x^2))) :((2 * x ^ 2 - (1 + x) * (2 * (2x))) / (2 * x ^ 2) ^ 2) 簡約化 2.2 数式微分 -式の表現と微分 70 / 142
  23. * * * * * x x x x *

    * x x x x julia> t1 = :(x * x) julia> t2 = :($t1 * $t1) julia> f = :($t2 * $t2) :(((x * x) * (x * x)) * ((x * x) * (x * x))) という は、木で表現すると... 式の表現法を考える 71 / 142
  24. julia> t1 = :(x * x) julia> t2 = :($t1

    * $t1) julia> f = :($t2 * $t2) :(((x * x) * (x * x)) * ((x * x) * (x * x))) 作るときは単純な関数が、なぜこんなに複雑になってしまったのか? ⇨ (木構造で表す) 式には、代入・束縛がない ので、共通のものを参照できない. ⇨ アルゴリズムを記述する言語として、数式(木構造)は貧弱 式の表現法を考える 2.2 数式微分 -式の表現と微分 72 / 142
  25. / - ^ * * 2 ^ x 2 +

    * 1 x 2 * 2 x * 2 2 ^ x 2 :((2 * x ^ 2 - (1 + x) * (2 * (2x))) / (2 * x ^ 2) ^ 2) も、 式の表現法を考える 74 / 142
  26. y_{1} y_{2} y_{3} y_{4} y_{5} y_{6} y_{7} y_{8} y_{9} y_{1}

    ^ x 2 y_{2} * 2 y_{3} + 1 x y_{4} * 2 x y_{5} * 2 y_{6} * y_{7} - y_{8} ^ 2 y_{9} / 式の表現法を考える 75 / 142
  27. [需要] 制御構文・関数呼び出し etc... 一般的なプログラミング言語によって 記述されたアルゴリズムに対しても、 微分したい x = [1, 2,

    3] y = [2, 4, 6] function linear_regression_error(coef) pred = x * coef error = 0. for i in eachindex(y) error += (y[i] - pred[i])^2 end return error end 式からアルゴリズムへ、木からDAGへ 2.2 数式微分 -式の表現と微分 76 / 142
  28. 木構造の式 から 木構造の式 ⇩ (ふつうの) プログラム から プログラム へ 式からアルゴリズムへ、木からDAGへ 2.2 数式微分

    -式の表現と微分 ヒューリスティックにやってそれなりに簡単な式を得られれば実用的には大丈夫なので与太話になりますが、簡約化を頑張れば最もシンプルな式を得られるか考えてみます。 簡単さの定義にもよるかもしれませんが、 で な は と簡約化されるべきでしょう。 ところが、 が四則演算と と有理数, で作れる式のとき、 か判定する問題は決定不能であることが知られています。(Richardson's theorem) したがって、一般の式を入力として、最も簡単な式を出力するようなアルゴリズムは存在しないとわかります。 77 / 142
  29. [計算グラフ] 計算過程をDAGで表現 計算グラフによる表現 単に計算過程を表しただけのものを Kantorovich グラフなどと呼び、 これに偏導関数などの情報を加えたものを計算グラフと呼ぶような定義もあります. (伊里, 久保田 (1998)

    に詳しく形式的な定義があります) ただ、単に計算グラフというだけで計算過程を表現するグラフを指すという用法はか なり普及していて一般的と思われます。そのためここでもそれに従って計算過程を表 現するグラフを計算グラフと呼びます. 82 / 142
  30. z Mul x y Sub u Add v 変数 に対する

    による偏微分の 計算グラフ上の表現 から への全ての経路の偏微分の総積の総和 は から への全ての経路の集合. は変数 から変数 への辺を表す. 連鎖律と計算グラフの対応 88 / 142
  31. z x6 x5 x2 x3 v u 一番簡単なやりかた を求める: graph

    = ComputationalGraph(f) ∂z_∂u = 0 for path in all_paths(graph, u, z) ∂z_∂u += prod(grad(s, t) for (s, t) in path) end キャッシュ 89 / 142
  32. z x6 x5 x2 x3 v u 続いて を求める: ∂z_∂v

    = 0 for path in all_paths(graph, v, z) ∂z_∂v += prod(grad(s, t) for (s, t) in path) end キャッシュ 90 / 142
  33. z x6 x5 x2 x3 v u 共通部がある! 独立して計算するのは非効率. ⇨

    うまく複数のノードからの経路を計算する. 自動微分とキャッシュ 91 / 142
  34. 1 x6 6 x5 30 30 x2 x3 v u

    Backward-Mode AD 95 / 142
  35. 1 x6 6 x5 30 30 x2 x3 90 60

    Backward-Mode AD 96 / 142
  36. x x6 x5 x2 x3 y z ... から辿っていくことで、共通部を共有しつつ, 複数の出力に対する偏微分係数を一

    度に計算できる ⇨ 前進型自動微分 (Forward-Mode AD) Forward-Mode AD 2.3 自動微分 ─式からアルゴリズムへ 99 / 142
  37. 逆向き自動微分 (Backward-Mode AD) に対して、 の場合に効率的 勾配を一回グラフを走査するだけで計算可能 前進型自動微分 (Forward-Mode AD) に対して、

    の場合に効率的 ヤコビ行列の一列を一回グラフを走査するだけで計算可能 実装では定数倍が軽くなりがちなため、 が小さい場合には効率的な可能性が高い Backward / Forward-Mode AD 2.3 自動微分 ─式からアルゴリズムへ 時間がないため割愛しましたが、 Forward-Mode AD の話でよく出てくる 二重数 というものがあります. かつ なる を考え、 これと実数からなる集合上の演算を素直に定義すると一見、不思議なことに微分が計算される... というような面白い数です。  興味があるかたは 「2乗してはじめて0になる数」とかあったら面白くないですか?ですよね - アジマティクス や ForwardDiff.jlのドキュメント などおすすめです。 100 / 142
  38. PyTorch / Chainer は Wengert List ではなく計算グラフを使っている. [1] No tape.

    Traditional reverse-mode differentiation records a tape (also known as a Wengert list) describing the order in which operations were originally executed; <中略> An added benefit of structuring graphs this way is that when a portion of the graph becomes dead, it is automatically freed; an important consideration when we want to free large memory chunks as quickly as possible. Zygote.jl, Tensorflow などは Wengert List を使っている. 計算グラフ vs Wengert List 2.3 自動微分 ─式からアルゴリズムへ [1] Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L. & Lerer, A. (2017). Automatic Differentiation in PyTorch. NIPS 2017 Workshop on Autodiff, . [2] 計算グラフとメモリの解放周辺で、Chainer の Aggressive Buffer Release という仕組みがとても面白いです: Aggressive buffer release #2368 102 / 142
  39. 一般的なプログラムを 直接解析 • • • • して (微分が計算できる) 計算グラフを得る プログラムを実装するのはとても難易度

    が高い. x = [1, 2, 3] y = [2, 4, 6] function linear_regression_error(coef) pred = x * coef error = 0. for i in eachindex(y) error += (y[i] - pred[i])^2 end return error end 計算グラフをどう得るか? 2.3 自動微分 ─式からアルゴリズムへ 104 / 142
  40. import Base mutable struct Scalar values creator grad generation name

    end Base.:+(x1::Scalar, x2::Scalar) = calc_and_build_graph(+, x1, x2) Base.:*(x1::Scalar, x2::Scalar) = calc_and_build_graph(*, x1, x2) ... トレースの OO による典型的な実装 2.3 自動微分 ─式からアルゴリズムへ 107 / 142
  41. 「実際そのときあった演算」 のみが 記録され問題になる ⇨ 制御構文がいくらあってもOK function crazy_function(x, y) rand() <

    0.5 ? x + y : x - y end x = Scalar(2.0, name="x") y = Scalar(3.0, name="y") z = crazy_function(x, y) JITrench.plot_graph(z, var_label=:name) トレースの利点 108 / 142
  42. どちらも数値微分のパッケージ 概ね機能は同じだが、スパースなヤコビ行列を求めたいときやメモリのアロケー ションを気にしたいときは FIniteDiff.jl を使うといいかもしれない 詳しい比較は https://github.com/JuliaDiff/FiniteDifferences.jl julia> using FiniteDifferences

    julia> central_fdm(5, 1)(sin, π / 3) 0.4999999999999536 FiniteDiff.jl/FiniteDifferences.jl FiniteDiff.jl/FiniteDifferences.jl 数値微分時代については 「付録: 数値微分」 を参照してください 113 / 142
  43. Forward-Mode の自動微分を実装したパッケージ 小規模な関数の微分を行う場合は高速なことが多い julia> using ForwardDiff julia> f(x) = x^2

    + 4x f (generic function with 1 method) julia> df(x) = ForwardDiff.derivative(f, x) # 2x + 4 df (generic function with 1 method) julia> df(2) 8 ForwardDiff.jl 115 / 142
  44. ソースコード変換ベースのAD JuliaのコードをSSA形式のIRに変換 して導関数を計算するコード (Adjoint Code) を生成 julia> f(x) = 3x^2

    f (generic function with 1 method) julia> Zygote.@code_ir f(0.) 1: (%1, %2) %3 = Main.:^ %4 = Core.apply_type(Base.Val, 2) %5 = (%4)() %6 = Base.literal_pow(%3, %2, %5) %7 = 3 * %6 return %7 Zygote.jl 118 / 142
  45. function numerical_derivative(f, x; h=1e-8) g = (f(x+h) - f(x)) /

    h return g end f(x) = sin(x) f′(x) = cos(x) x = π / 3 numerical_derivative(f, x) # 0.4999999969612645 f′(x) # 0.5000000000000001 数値微分の実装 数値微分 124 / 142
  46. 実験: なら、 をどんどん小さくすればいくらでも精度が良くなるはず? H = [0.1^i for i in 4:0.5:10]

    E = similar(H) for i in eachindex(H) d = numerical_derivative(f, x, h=H[i]) E[i] = abs(d - f′(x)) end plot(H, E) 誤差の最小化 数値微分 128 / 142
  47. 実はこれの方が精度がよい! [定理3. 中心差分の誤差] 同じようにテイラー展開をするとわかる また、簡単な計算で一般の について 点評価で の近似式を得られる 中心差分による2次精度の数値微分 数値微分

    中心差分と同様に から左右に 個ずつ点とってこれらの評価の重みつき和を考えてみます。 すると、テイラー展開の各項を足し合わせて 以外の係数を にすることを考えることで公比が各列 で初項 のヴァンデルモンド行列を として を満たす を 各成分 で割ったのが求めたい重みとわかります. あとはこれの重み付き和をとればいいです. 同様にして 階微分の近似式も得られます. 134 / 142
  48. 2. 桁落ちへの対応 Q. 打ち切り誤差と丸め誤差のトレードオフで を小さくすればいいというものじゃな いことはわかった。じゃあ、最適な は見積もれる? A. 最適な は

    の 階微分の大きさに依存するから簡単ではない. 例) 中心差分 は くらい ? ⇨ がわからないのに を使った式を使うのは現実的でない. しょうがないので に線を引いてみると... 数値微分の改良 ~ 桁落ちへの対応 数値微分 135 / 142
  49. の における勾配ベクトル を求める function numerical_gradient(f, x::Vector; h=1e-8) n = length(x)

    g = zeros(n) y = f(x...) for i in 1:n x[i] += h g[i] = (f(x...) - y) / h x[i] -= h end return g end ⇨ を 回評価する必要がある. 多変数関数への拡張 数値微分 139 / 142
  50. 1. 久保田光一, 伊里正夫 「アルゴリズムの自動微分と応用」 コロナ社 (1998) i. 自動微分そのものついて扱ったおそらく唯一の和書です. 詳しいです. ii. 形式的な定義から、計算グラフの縮小のアルゴリズムや実装例と基礎から実用まで触れられています.

    iii. サンプルコードは、FORTRAN, (昔の) C++ です. 2. 斉藤康毅 「ゼロから作るDeep Learning ③」 O'Reilly Japan (2020) i. トレースベースの Reverse AD を Python で実装します. ii. Step by step で丁寧に進んでいくので、とてもおすすめです. iii. 自動微分自体について扱った本ではないため、その辺りの説明は若干手薄かもしれません. 3. Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2015). Automatic differentiation in machine learning: A survey. ArXiv. /abs/1502.05767 i. 機械学習 x AD のサーベイですが、機械学習に限らず AD の歴史やトピックを広く取り上げてます. ii. 少し内容が古くなっているかもしれません. 4. Differentiation for Hackers i. Flux.jl や Zygote.jl の開発をしている Mike J Innes さんが書いた自動微分の解説です。 Juliaで動かしながら勉強できます. おすすめです. 5. Innes, M. (2018). Don't Unroll Adjoint: Differentiating SSA-Form Programs. ArXiv. /abs/1810.07951 i. Zygote.jl の論文です. かなりわかりやすいです. 6. Gebremedhin, A. H., & Walther, A. (2019). An introduction to algorithmic differentiation. Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery, 10(1), e1334. https://doi.org/10.1002/widm.1334 i. 実装のパラダイムやCheckpoint, 並列化などかなり広く触れられています 7. Zygote.jl のドキュメントの用語集 i. 自動微分は必要になった応用の人がやったり、コンパイラの人がやったり、数学の人がやったりで用語が乱立しまくっているのでこちらを参照して整理すると良いです 8. JuliaDiff i. Julia での微分についてまとまっています. 9. Chainer のソースコード i. Chainer は Python製の深層学習フレームワークですが、既存の巨大フレームワークと比較すると、裏も Pythonでとても読みやすいです. ii. 気になる実装があったら当たるのがおすすめです. 議論もたくさん残っているのでそれを巡回するだけでとても勉強になります. 自動微分の勉強で参考になる文献 142 / 142