Upgrade to Pro — share decks privately, control downloads, hide ads and more …

BurnでDeep Learningやってみる

Avatar for Shirokuma Shirokuma
February 24, 2023

BurnでDeep Learningやってみる

RustハイブリッドLTでの発表

Avatar for Shirokuma

Shirokuma

February 24, 2023
Tweet

More Decks by Shirokuma

Other Decks in Technology

Transcript

  1. なぜRustなのか?深層学習といえばPythonなのでは? • Burnのブログで熱弁 ◦ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning • Pythonが使われてきた理由 ◦ シンプルで学びやすい

    ◦ 研究のサイクルを回しやすい • Pythonが使われていることの課題 ◦ フレームワーク(PyTorchなど)内部はC++ ◦ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり
  2. なぜRustなのか?深層学習といえばPythonなのでは? • Burnのブログで熱弁 ◦ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning • Pythonが使われてきた理由 ◦ シンプルで学びやすい

    ◦ 研究のサイクルを回しやすい • Pythonが使われていることの課題 ◦ フレームワーク(PyTorchなど)内部はC++ ◦ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり • Rustで解決できること ◦ 1つの言語で低レベルから抽象レイヤまで扱 える ◦ エンジニアと研究者の隔たりを無くす
  3. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる
  4. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる • モデル作成 ◦ 深層学習で使用するモデルのネットワーク構造を作る ◦ Lossの設定
  5. 一般的な深層学習フレームワークに含まれる機能 • テンソル計算 ◦ 一般的なテンソルを用いた計算 ◦ 自動微分 ◦ CPU・GPUなどの使用するリソースの切り替え •

    データ準備 ◦ 保存されている画像データなどを読み取ってテンソルに変換 ◦ 一般的なデータセットはダウンロードで取ってくる • モデル作成 ◦ 深層学習で使用するモデルのネットワーク構造を作る ◦ Lossの設定 • 学習 ◦ パラメタ設定 ◦ どのような手法で最適化するか選択 ◦ イテレーション回数やモデルの保存方法などを決める
  6. Burnのモジュール一覧 • テンソル計算(Backend) ◦ テンソルのバックエンドトレイト(Tch, NdArray) • データ準備(Dataset) ◦ PyTorchのDataLoaderに近い

    ◦ mnistやhuggingface hubなどのデータセットを使用できる • モデル作成(Module) ◦ NNのレイヤー(Linear, Convolution, Pooling, Activation, …) ◦ Loss関数(CrossEntropyLossのみ) • 学習(Config) ◦ デフォルト値設定やJSONシリアライズできる機械学習用パラメタ • 学習(Learner) ◦ 最適化ソルバの選択(SGD, Adam) ◦ 学習時のメトリクスやプロット、モデル保存の設定
  7. MNISTのコードをPyTorchと比較する • モデルの定義(PyTorchの場合) ◦ モデルとなる構造体を定義 ◦ コンストラクタでモデルに含まれる各層の初期化 ▪ 主にモデルパラメタを含むものが初期化される class

    Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout(0.25) self.dropout2 = nn.Dropout(0.5) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax
  8. MNISTのコードをPyTorchと比較する • モデルの定義(PyTorchの場合) ◦ forward関数に実際の計算を記述 ◦ backward(微分計算)は計算グラフを用いて自動微分よってに行われる def forward(self, x):

    x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) output = F.log_softmax(x, dim=1) return output Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax
  9. MNISTのコードをPyTorchと比較する • モデルの定義(Burnの場合) ◦ PyTorchの書き方にかなり近い ◦ モデルとなる構造体にMuduleトレイトを継承させる #[derive(Module, Debug)] pub

    struct Model<B: Backend> { conv1: Param<Conv2d<B>>, conv2: Param<Conv2d<B>>, dropout1: Dropout, dropout2: Dropout, linear1: Param<Linear<B>>, linear2: Param<Linear<B>>, max_pool: MaxPool2d, } pub fn new() -> Self { Self { conv1: Param::new(Conv2d::new( & Conv2dConfig ::new([1, 32], [3, 3]), )), conv2: Param::new(Conv2d::new( & Conv2dConfig ::new([32, 64], [3, 3]), )), dropout1: Dropout::new(&DropoutConfig ::new(0.25)), dropout2: Dropout::new(&DropoutConfig ::new(0.5)), linear1: Param::new(Linear::new(&LinearConfig ::new(9216, 128))), linear2: Param::new(Linear::new(&LinearConfig ::new(128, 10))), max_pool: MaxPool2d::new( & MaxPool2dConfig ::new(64, [2, 2]).with_strides ([2, 2] )), } }
  10. MNISTのコードをPyTorchと比較する • モデルの定義(Burnの場合) ◦ PyTorchの書き方にかなり近い ◦ In/Outのテンソルの次元を定義しながら書けるのが嬉しい ▪ rust-analyzerで途中の出力の次元も分かる pub

    fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> { let [batch_size, heigth, width] = input.dims(); let x = input.reshape([batch_size, 1, heigth, width]).detach(); let x = self.conv1.forward(x); let x = relu(&x); let x = self.conv2.forward(x); let x = relu(&x); let x = self.max_pool.forward(x); let x = self.dropout1.forward(x); let x = x.reshape([batch_size, 9216]); let x = self.linear1.forward(x); let x = relu(&x); let x = self.dropout2.forward(x); let out = self.linear2.forward(x); out }
  11. MNISTのコードをPyTorchと比較する • パラメタ設定 ◦ Configトレイトを使うことで、デフォルト値の設定が可能 • 学習 ◦ 最適化はAdamを使用 ◦

    プロットやモデルのチェックポイント保存設定を行ってfitを実行 #[derive(Config)] pub struct MnistConfig { #[config(default = 2)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 8)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, pub optimizer: AdamConfig, } let learner = LearnerBuilder::new(ARTIFACT_DIR) .metric_train_plot(AccuracyMetric::new()) .metric_valid_plot(AccuracyMetric::new()) .metric_train_plot(LossMetric::new()) .metric_valid_plot(LossMetric::new()) .with_file_checkpointer::<f32>(2) .devices(vec![device]) .num_epochs(config.num_epochs) .build(model, optim); let _model_trained = learner.fit(dataloader_train, dataloader_test);