Upgrade to Pro — share decks privately, control downloads, hide ads and more …

20250515_SeedEverythingな人を救いたい

 20250515_SeedEverythingな人を救いたい

内部の勉強会で使った、Pytorchのseedに関する備忘スライドです
RNGに関してミスしやすいポイントをまとめています

Avatar for yuyagi

yuyagi

May 19, 2025
Tweet

Other Decks in Technology

Transcript

  1. 疑似乱数とは 簡単に言うと、 「見かけはランダムだけど、中身は決まった計算で作られる数列」 疑似乱数(Pseudo Random Number)とは、アルゴリズムによって生成された乱数 モデルの初期化 学習データの準備 / 前処理

    モデルの学習 モデルの評価 テストデータへの予測 ◆ 機械学習・深層学習での使われ方 ※重み初期化、データのシャッフル・分割、データ拡張、ドロップアウトなど 第46回 市村学術賞 功績賞 -01:一様疑似乱数発生法の高機能化 決まった計算で 生成するので、 実際には一様乱数にならない。 メルセンヌツイスタによる一様乱数
  2. 疑似乱数の仕組み ① 疑似乱数生成器は、シード(seed)と呼ばれる初期状態からスタートする ② 生成器は内部状態(state)を持ち、乱数の出力を持ち、その状態を更新 ② 同じシードから開始すれば常に同じ乱数列が再現される 種(seed) 内部状態 疑似乱数列

    疑似乱数生成器(RNG) ① 内部状態の初期化 ② 乱数生成 ③ 内部状態の更新 参考:【初学者向け】暗号基本技術まとめ その1 実験の再現性確保のためには、 シードを固定する必要がある
  3. 疑似乱数の仕組み 種(seed) 内部状態 疑似乱数列 疑似乱数生成器(RNG) ① 内部状態の初期化 ② 乱数生成 しかし、コードの実行順序や、

    マルチプロセス等の処理で、 想定外の挙動が発生することがよくある ③ 内部状態の更新 ① 疑似乱数生成器は、シード(seed)と呼ばれる初期状態からスタートする ② 生成器は内部状態(state)を持ち、乱数の出力を持ち、その状態を更新 ② 同じシードから開始すれば常に同じ乱数列が再現される 参考:【初学者向け】暗号基本技術まとめ その1
  4. よくあるミスの例:途中Epochから再開したら結果が違う • 学習途中の重みやオプティマイザの状態を保存、再読み込みしたのに、全エポック通し で学習した場合と結果が異なる モデルの初期化 学習データの準備 / 前処理 モデルの学習 モデルの評価

    テストデータへの予測 このタイミングでseedを固定 コードの実行開始 20エポック中、10エポックで学習を中断 (weight、optimizer等は保存) その後、重みを読み込んで、 残り10エポックで学習 ↓ 20エポックを一気に学習した場合と 結果が異なる!
  5. 原因:乱数生成器の内部状態 • Seed固定の関数は呼び出せているが、乱数生成器の内部状態が再現できていない モデルの初期化 学習データの準備 / 前処理 モデルの学習 モデルの評価 テストデータへの予測

    ◆ 例1:途中Epochから再開したい場合 種(seed) 内部状態 疑似乱数列 疑似乱数生成器(RNG) ② 乱数生成 ③ 内部状態の更新 学習コードの流れ ① 内部状態の初期化 モデルの初期化、データ拡張など、 学習が進むにつれてRNGの状態が変化 → 初期状態のRNGから再開すると、全体の結果が変わってしまう
  6. ポイントと対処方法:内部状態について ✓Python、Numpy、pytorchライブラリはそれぞれのアルゴリズムで疑似乱数を計算 ✓ライブラリごとに別のRNGをつかっているため、それぞれが内部状態を持つ ✓内部状態は取得・読み込みすることができる → 途中Epochから再開する場合は、これらの内部状態も保存して読み込むとよい Numpy Python Pytorch Seed

    固定 メソッド Random.seed(seed) np.random.seed(seed) torch.manual_ seed(seed) RNGの 取得 読み込み random.getstate() random.setstate(state) numpy.random.get_state() numpy.random.set_state(state) torch.get_rng_st ate() torch.set_rng_st ate(state) 採用 アルゴ リズム メルセンヌツイスタ PCG64(numpy 1.17+) メルセンヌ ツイスタ torch.cuda.man ual_seed(seed) torch.cuda.get_rng _state_all() torch.cuda.set_rng _state_all(state) Philox CPU GPU
  7. さらなる注意ポイント:乱数シードの有効範囲 • シードの固定はメインプロセスでのみされていて、別プロセスには自動では適用されない → 学習中にマルチプロセスな処理を行う場合は注意が必要 ※学習コードの中では、データを取り出す際に、マルチプロセスな処理を行うことが多い ◆ よくある現象: Pytorch DataLoaderのWorkerを増減したら、結果が変化

    ・並列でデータを取り出すために、num_worker > 0に設定する場合、 処理がマルチプロセスになるので注意が必要 . num_workers>0の時、 データの取り出し専用の別のプロセスが立ち上がる この時、Pytorch側でseed固定してくれるような実装だが、 仕様が少しわかりづらい・・・(次頁) モデルの初期化 学習データの準備 / 前処理 モデルの学習 モデルの評価 テストデータへの予測
  8. Pytorch DataLoaderのデフォルト挙動(エポック開始時) メインプロセスのtorch(CPU)のRNGで乱数を生成し、その乱数にworker_idを足した値で、 サブプロセスの各seedを設定(サブプロセスごとにseedが違う、かつ毎エポックの開始時に動作) Worker 0 メインプロセス(訓練スクリプト) サブプロセス(DataLoader) Seed (Python)

    RNG (Numpy) RNG (torch CPU) RNG (Python) RNG (Numpy) RNG (torch CPU) 乱数生成 base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) Worker 1 RNG (Python) RNG (Numpy) RNG (torch CPU) base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) Worker 2 RNG (Python) RNG (Numpy) RNG (torch CPU) base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) ※最新(2.6.0)のソースコードでは、 サブプロセスのnumpyのseedも設定してくれる!(下記228行目以降) https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py 新しいプロセスが作られた (≒エポック開始時) のtorchのRNGの内部状態で、 base_seedを生成する
  9. (参考)学習途中から再開する場合(マルチプロセス) 途中Epochから再開したかったら、サブプロセスのseed全部読み込む必要があるのでは・・・? →基本的にはメインプロセスのRNGだけ読み込めばOK (・・・のはず) Worker 0 メインプロセス(訓練スクリプト) サブプロセス(DataLoader) Seed (Python)

    RNG (Numpy) RNG (torch CPU) RNG (Python) RNG (Numpy) RNG (torch CPU) 乱数生成 base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) Worker 0 RNG (Python) RNG (Numpy) RNG (torch CPU) base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) Worker 0 RNG (Python) RNG (Numpy) RNG (torch CPU) base_seed + worker_id seedを設定 _generate_state (seedの桁数調整) Torchの乱数生成器が生成 した乱数が、サブプロセスの base_seedになる → 赤枠内さえ再現できればOK!