Upgrade to Pro — share decks privately, control downloads, hide ads and more …

JAXとFlaxを使って、ナウい機械学習をしたい

Sponsored · Your Podcast. Everywhere. Effortlessly. Share. Educate. Inspire. Entertain. You do you. We'll handle the rest.

 JAXとFlaxを使って、ナウい機械学習をしたい

JAXとFlaxの基本と、深層学習フレームワークの流れなど

Avatar for Moriyama Naoto

Moriyama Naoto

February 27, 2021
Tweet

Other Decks in Technology

Transcript

  1. 自己紹介 
 - 森山直人(Twitter: @vimmode) - みらい翻訳株式会社でリサーチエンジニア - 日本語<->中国語の言語間の機械翻訳がメイン -

    pythonと自然言語処理が好き - 深層学習はPyTorchを使うことが多いです - allennlp, flairあたりが好き - 発表内容は組織を代表するものではありません
  2. 深層学習フレームワークの機能 
 - 必要最低限の機能 - 学習データからミニバッチを作成 - ニューラルネットワークの定義 - 予測値を計算し、誤差から自動微分でパラメータを更新

    - 学習済みのモデルのシリアライズ - GPU, TPUなどのハードウェアアクセラレータ対応 - ニューラルネットワークの記述は主に2つのパラダイムがある - 以降のページで説明していきます
  3. Define and run
 - Caffe, TensorFlow1などが該当 - 静的な計算グラフを作ってから、データを流し込む - 内部構造は直感的であり、理解しやすい

    - pythonで計算グラフを定義に使うが、実行時はpythonは必要ない - 定義されたネットワークは実行時に変わることはないので、 デプロイと運用は安心・安全 - モバイルやエッジコンピューティングにも強い! - コーディングは深層学習への深い理解がないと直感的には書けない
  4. Define by run
 - Chainer, PyTorchなどが該当 - 変数に計算元の情報を保持させ、それを辿っていくとネットワークが出来る(計算グ ラフの概念を意識させない) -

    これにより、記述のしやすさが格段に向上 - ネットワークは入力が来て初めて作られる、永続化はパラメータの辞書 - (初期は)製品化でもpythonのruntimeが必要なため、言語由来の制約は多い - ネットワークが動的に変わることがあり、実運用で問題が発生し得る - 動作環境やメモリ、互換性など、デプロイはデータを流してみないとわからない
  5. Define by run VS Define and run
 - 研究ではDefine by

    runが支持され、製品運用ではDefine and runが支持 される構図に - Define and runであるTensorFlowは研究者から避けられることが多いが、 実務運用ではきわめて優秀 - Define by runであるPyTorchは、モデルをdefine and run スタイルである Caffe2に変換する機能を早期に採用したことで製品運用の課題を一定カ バー - Chainerとの強い差別化 - とはいえ、TensorFlowほど簡単ではない
  6. 異なるフレームワークの規格を統一したい
 - 記述が得意なフレームワークと、実装に優れたフレームワークの相互運用の ために、学習済みモデルの規格を統一させる => ONNX - PyTorch -> MXNetなど

    - 一方で、フレームワーク間で数値表現に違いが存在する場合があり、ONNX を交えた変換で計算結果が同じにならない事がある! - 平均や分散などの統計計算は注意が必要 - ONNX専用のruntimeを利用する話もあるが、時間の都合でここでは割愛し ます
  7. 現在の二大勢力の課題(個人感)
 Tensorflow - TensorFlow2ではdefine by run形式でコーディングできるようになったものの、TensorFlow1の 基本設計を考えると、かなり無理な拡張をしたと察する - kerasやeagerなど、抽象化機能が多くて書き方が多様すぎる PyTorch

    - 初期からCaffeに変換する設計だったこともあり、内部は複雑に - 細かいところはC++なので、内部実装把握はそこで力づきる - モデルとパラメータが密接に紐付いており、かつネットワークは計算時に確立されるため、量子化と いったパラメータ操作や、モデルの確実なシリアライズが複雑
  8. JAX
 Googleが開発した行列演算+自動微分+XLAのライブラリ (もともとはautogradというライブラリを拡張して設計されたもの) - 行列演算 - NumPyのAPIと完全互換(ただし非同期処理) - 自動微分 -

    自動微分をサポートすることで、JAXだけで簡単なニューラルネットワークが書ける - XLA - pythonで記載された線形代数関連の命令郡をまとめてハードウェアアクセラレータ向け にJITコンパイルし、一度で実行できるようにする。
  9. JAXの好きなところ(個人感)
 - pure python! - デバックや内部実装の把握がしやすい - とにかく早い - ミニバッチ内の処理など関数をすべて

    JITコンパイルすることで、全体の処理が高速化 - データのCPU -> GPU(TPU)間の移動がシームレスに出来る - 設計は関数型指向 - 行列のデータは基本的に変更不可 - インデックス/スライス経由の値変更やインプレース演算ができない設計 - 乱数生成はグローバルの乱数状態を参考にするのではなく、都度状態を生成
  10. Flax
 GoogleによるJAXをベースに実装された深層学習フレームワーク - JAX開発者と近い距離で開発されており、一枚岩感がある - JAX以上に、強い関数型指向の性格を持つ 🌟 - 各種深層学習フレームワークの負債を研究しており、設計思想がアツい 設計思想(抜粋&意訳)

    - 悪い抽象化や関数のオプションを増やすよりも、コードの複製を - ドキュメンテーションやテストが難しい部分は、設計を見直そう - 関数型スタイルは一部のユーザーを混乱させるが、高い利益をもたらす - 役に立たないエラーメッセージはバグ同然
  11. JAXとFlaxの所感- 悩むところ
 - 関数型指向な設計により、フレームワーク設計としての美しさは十分だが、入 門者にとっての学習コストは高い - とはいえ、慣れれば可読性と生産性はかなり高い - 既存の資産は簡単には転用できない -

    PyTorchとTensorFlow2間はある程度簡単だが、Flaxは少し複雑 - コミュニティがどれだけ大きくなるかは読めない - 世間一般ではPyTorchとTensorFlowはさほど強い不満は持たれていない - 実務観点で、既存のフレームワークからリプレイスするROIは難しいと思う