Upgrade to Pro — share decks privately, control downloads, hide ads and more …

BurnでDeep Learningやってみる

Shirokuma
February 24, 2023

BurnでDeep Learningやってみる

RustハイブリッドLTでの発表

Shirokuma

February 24, 2023
Tweet

More Decks by Shirokuma

Other Decks in Technology

Transcript

  1. なぜRustなのか?深層学習といえばPythonなのでは? • Burnのブログで熱弁 ◦ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning • Pythonが使われてきた理由 ◦ シンプルで学びやすい

    ◦ 研究のサイクルを回しやすい • Pythonが使われていることの課題 ◦ フレームワーク(PyTorchなど)内部はC++ ◦ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり
  2. なぜRustなのか?深層学習といえばPythonなのでは? • Burnのブログで熱弁 ◦ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning • Pythonが使われてきた理由 ◦ シンプルで学びやすい

    ◦ 研究のサイクルを回しやすい • Pythonが使われていることの課題 ◦ フレームワーク(PyTorchなど)内部はC++ ◦ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり • Rustで解決できること ◦ 1つの言語で低レベルから抽象レイヤまで扱 える ◦ エンジニアと研究者の隔たりを無くす
  3. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる
  4. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる • モデル作成 ◦ 深層学習で使用するモデルのネットワーク構造を作る ◦ Lossの設定
  5. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる • モデル作成 ◦ 深層学習で使用するモデルのネットワーク構造を作る ◦ Lossの設定 • 学習 ◦ パラメタ設定 ◦ どのような手法で最適化するか選択 ◦ イテレーション回数やモデルの保存方法などを決める
  6. Burnのモジュール一覧 • テンソル計算(Backend) ◦ テンソルのバックエンドトレイト(Tch, NdArray) • データ準備(Dataset) ◦ PyTorchのDataLoaderに近い

    ◦ mnistやhuggingface hubなどのデータセットを使用できる • モデル作成(Module) ◦ NNのレイヤー(Linear, Convolution, Pooling, Activation, …) ◦ Loss関数(CrossEntropyLossのみ) • 学習(Config) ◦ デフォルト値設定やJSONシリアライズできる機械学習用パラメタ • 学習(Learner) ◦ 最適化ソルバの選択(SGD, Adam) ◦ 学習時のメトリクスやプロット、モデル保存の設定
  7. MNISTのコードをPyTorchと比較する • モデルの定義(PyTorchの場合) ◦ モデルとなる構造体を定義 ◦ コンストラクタでモデルに含まれる各層の初期化 ▪ 主にモデルパラメタを含むものが初期化される class

    Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout(0.25) self.dropout2 = nn.Dropout(0.5) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax
  8. MNISTのコードをPyTorchと比較する • モデルの定義(PyTorchの場合) ◦ forward関数に実際の計算を記述 ◦ backward(微分計算)は計算グラフを用いて自動微分よってに行われる def forward(self, x):

    x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) output = F.log_softmax(x, dim=1) return output Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax
  9. MNISTのコードをPyTorchと比較する • モデルの定義(Burnの場合) ◦ PyTorchの書き方にかなり近い ◦ モデルとなる構造体にMuduleトレイトを継承させる #[derive(Module, Debug)] pub

    struct Model<B: Backend> { conv1: Param<Conv2d<B>>, conv2: Param<Conv2d<B>>, dropout1: Dropout, dropout2: Dropout, linear1: Param<Linear<B>>, linear2: Param<Linear<B>>, max_pool: MaxPool2d, } pub fn new() -> Self { Self { conv1: Param::new(Conv2d::new( & Conv2dConfig ::new([1, 32], [3, 3]), )), conv2: Param::new(Conv2d::new( & Conv2dConfig ::new([32, 64], [3, 3]), )), dropout1: Dropout::new(&DropoutConfig ::new(0.25)), dropout2: Dropout::new(&DropoutConfig ::new(0.5)), linear1: Param::new(Linear::new(&LinearConfig ::new(9216, 128))), linear2: Param::new(Linear::new(&LinearConfig ::new(128, 10))), max_pool: MaxPool2d::new( & MaxPool2dConfig ::new(64, [2, 2]).with_strides ([2, 2] )), } }
  10. MNISTのコードをPyTorchと比較する • モデルの定義(Burnの場合) ◦ PyTorchの書き方にかなり近い ◦ In/Outのテンソルの次元を定義しながら書けるのが嬉しい ▪ rust-analyzerで途中の出力の次元も分かる pub

    fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> { let [batch_size, heigth, width] = input.dims(); let x = input.reshape([batch_size, 1, heigth, width]).detach(); let x = self.conv1.forward(x); let x = relu(&x); let x = self.conv2.forward(x); let x = relu(&x); let x = self.max_pool.forward(x); let x = self.dropout1.forward(x); let x = x.reshape([batch_size, 9216]); let x = self.linear1.forward(x); let x = relu(&x); let x = self.dropout2.forward(x); let out = self.linear2.forward(x); out }
  11. MNISTのコードをPyTorchと比較する • パラメタ設定 ◦ Configトレイトを使うことで、デフォルト値の設定が可能 • 学習 ◦ 最適化はAdamを使用 ◦

    プロットやモデルのチェックポイント保存設定を行ってfitを実行 #[derive(Config)] pub struct MnistConfig { #[config(default = 2)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 8)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, pub optimizer: AdamConfig, } let learner = LearnerBuilder::new(ARTIFACT_DIR) .metric_train_plot(AccuracyMetric::new()) .metric_valid_plot(AccuracyMetric::new()) .metric_train_plot(LossMetric::new()) .metric_valid_plot(LossMetric::new()) .with_file_checkpointer::<f32>(2) .devices(vec![device]) .num_epochs(config.num_epochs) .build(model, optim); let _model_trained = learner.fit(dataloader_train, dataloader_test);