Upgrade to Pro — share decks privately, control downloads, hide ads and more …

[Journal club] Sigmoid Loss for Language Image Pre-Training

[Journal club] Sigmoid Loss for Language Image Pre-Training

More Decks by Semantic Machine Intelligence Lab., Keio Univ.

Other Decks in Technology

Transcript

  1. Sigmoid Loss for Language Image Pre-Training Xiaohua Zhai⋆, Basil Mustafa,

    Alexander Kolesnikov, Lucas Beyer⋆ Google DeepMind 慶應義塾大学 杉浦孔明研究室 小槻誠太郎 X. Zhai, B. Mustafa, A. Kolesnikov, and L. Beyer, “Sigmoid Loss for Language Image Pre-Training,” in ICCV, 2023, pp. 11975–11986. ICCV’23 Oral
  2. 4 関連研究 – CLIPに代表されるVision-Language Pretraining 対比損失を用いるCLIPの改良・派生手法が多数 公開データセット・モデルも複数出現 手法 概要 LAION-5B

    [Schuhmann+, ‘22] 大規模な画像-テキスト対データセットを公開 WIT [Srinivasan+, ‘21] Wikipediaベースの大規模な画像-テキスト対 データセット (マルチモーダル, 多言語, 公開) OpenCLIP [Ilharco+, ‘21] 公開データセット上で学習されたCLIP実装 BLIP [Li+, ICML’22] Captioning, filteringによってbootstrapに detasetをクリーニングしつつ学習
  3. 各 i,j-pair について独立に計算可能 9 SigLIP, SigLiT – Sigmoid関数に基づいた損失関数で学習 データの事前分布を加味 (#Negative

    >> #positive) バイアス項を導入, b=-10で初期化 別にバッチ全体で対比損失を考えなくても positive pairはcosine類似度 = 1, negative pairはcosine類似度 = -1 に近づけば良い
  4. 10 SigLIP, SigLiT – 効率的なマルチデバイス実装が可能 batch size: 12,デバイス x3 の例

    CLIPの単純な実装では一つの画像 特徴量に対して全てのテキスト特徴量 を同時にメモリに載せる必要がある. 要求メモリサイズがbatch size依存 à スケールしにくい Device 1 Device 2 Device 3 I₁ I₂ I₃ I₄ I₅ I₆ I₇ I₈ I₉ I₁₀ I₁₁ I₁₂ Device 1 T₁ T₂ T₃ T₄ Device 2 T₅ T₆ T₇ T₈ Device 3 T₉ T₁₀ T₁₁ T₁₂
  5. 11 SigLIP, SigLiT – 効率的なマルチデバイス実装が可能 batch size: 12,デバイス x3 の例

    黄色でハイライトした箇所のみ メモリに載せる まず対角線上のペアについて 3デバイスに分散させて並列計算 Device 1 Device 2 Device 3 I₁ I₂ I₃ I₄ I₅ I₆ I₇ I₈ I₉ I₁₀ I₁₁ I₁₂ Device 1 T₁ + – – – T₂ – + – – T₃ – – + – T₄ – – – + Device 2 T₅ + – – – T₆ – + – – T₇ – – + – T₈ – – – + Device 3 T₉ + – – – T₁₀ – + – – T₁₁ – – + – T₁₂ – – – + ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ loss 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% Device 1 Device 2 Device 3
  6. 12 SigLIP, SigLiT – 効率的なマルチデバイス実装が可能 batch size: 12,デバイス x3 の例

    黄色でハイライトした箇所のみ メモリに載せる 各デバイスが持つテキスト特徴量を 隣のデバイスに送り, 同様に計算 Device 1 Device 2 Device 3 I₁ I₂ I₃ I₄ I₅ I₆ I₇ I₈ I₉ I₁₀ I₁₁ I₁₂ Device 3 T₁ ✓ ✓ ✓ ✓ – – – – T₂ ✓ ✓ ✓ ✓ – – – – T₃ ✓ ✓ ✓ ✓ – – – – T₄ ✓ ✓ ✓ ✓ – – – – Device 1 T₅ – – – – ✓ ✓ ✓ ✓ T₆ – – – – ✓ ✓ ✓ ✓ T₇ – – – – ✓ ✓ ✓ ✓ T₈ – – – – ✓ ✓ ✓ ✓ Device 2 T₉ – – – – ✓ ✓ ✓ ✓ T₁₀ – – – – ✓ ✓ ✓ ✓ T₁₁ – – – – ✓ ✓ ✓ ✓ T₁₂ – – – – ✓ ✓ ✓ ✓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ loss 66% 66% 66% 66% 66% 66% 66% 66% 66% 66% 66% 66% Device 1 Device 2 Device 3
  7. 13 SigLIP, SigLiT – 効率的なマルチデバイス実装が可能 batch size: 12,デバイス x3 の例

    黄色でハイライトした箇所のみ メモリに載せる 操作を繰り返して全体を計算し, デバイスごとに計算した値の 総和を取る Device 1 Device 2 Device 3 I₁ I₂ I₃ I₄ I₅ I₆ I₇ I₈ I₉ I₁₀ I₁₁ I₁₂ Device 2 T₁ ✓ ✓ ✓ ✓ – – – – ✓ ✓ ✓ ✓ T₂ ✓ ✓ ✓ ✓ – – – – ✓ ✓ ✓ ✓ T₃ ✓ ✓ ✓ ✓ – – – – ✓ ✓ ✓ ✓ T₄ ✓ ✓ ✓ ✓ – – – – ✓ ✓ ✓ ✓ Device 3 T₅ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ – – – – T₆ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ – – – – T₇ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ – – – – T₈ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ – – – – Device 1 T₉ – – – – ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ T₁₀ – – – – ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ T₁₁ – – – – ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ T₁₂ – – – – ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ loss ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ Device 1 Device 2 Device 3 ↘ ↓ ↙ Cross Device Σ
  8. Device 1 Device 2 Device 3 I₁ I₂ I₃ I₄

    I₅ I₆ I₇ I₈ I₉ I₁₀ I₁₁ I₁₂ Device 1 T₁ + – – – T₂ – + – – T₃ – – + – T₄ – – – + Device 2 T₅ + – – – T₆ – + – – T₇ – – + – T₈ – – – + Device 3 T₉ + – – – T₁₀ – + – – T₁₁ – – + – T₁₂ – – – + ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ loss 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% 33% Device 1 Device 2 Device 3 14 SigLIP, SigLiT – 効率的なマルチデバイス実装が可能 batch size: 12,デバイス x3 の例 黄色でハイライトした箇所のみ メモリに載せる デバイスごとの要求メモリサイズ はbatch sizeに依存しない Batch sizeが増加しても デバイスを増やせばデバイスごとの メモリ使用量を落とせる
  9. SigLiT: Vision: ViT-g (pretrained+frozen) Text: From scratch LiT image-text dataset

    SigLIP: B/16 ViT B-sized transformer WebLI dataset (Eng.) mSigLIP (multilingual) B-sized ViT B-sized text models WebLI dataset (100 lang.) #sample: 900M 16 定量的結果 – 既存手法を上回る / Batch sizeは32k程度でサチる ) ) ( ) ) ) ) - 0-shot acc. on ImageNet Recall@1 on crossmodal 3600 dataset 0-shot acc. on ImageNet
  10. Strength 提案がシンプルかつ強力 実験が豊富 (Resultsのsubsectionが10個, 使用しているデータセットも複数.) Weakness Table 4の内容はバイアス項に関する主張を十分に裏付けられていない気がする. 最終的な性能を比較するだけだと本当に初期の挙動に作用したのかわからない. その他

    なんだかんだ言ってtransformerは大文字始まりなことが多かった気がするが, 小文字始まりになっている (と思いきや大文字始まりも混じっている) > B-sized transformer for text embeddings 19 おきもち
  11. バイアス項を-10で初期化することで一貫して性能向上 実験: SigLIP setup 1. Base architecture 2. 8k batch

    size 3. Trained for 900M examples 21 Ablation study – バイアス項と温度パラメータの初期化
  12. Use 2B SigLIP Vision model to obtain visual feature. Encode

    & decode visual features & text embeddings by 3B UL2. SMALLER, FASTER, STRONGER 22 PaLI-3
  13. Vocab. sizeが大きいと単語埋め込みに必要な行列が巨大化🤮 ( 特に multilingual 設定など ) [Vocab. size] x

    [embedding dimension of the text model] 2つの行列を用意して一度低次元空間に写像してから戻すことで 必要なパラメータ数を削減😄 23 Bottlenecked token embedding F2 : RK à RW F1 : RN à RK Vocab size: N Embedding dim.: W