Upgrade to Pro — share decks privately, control downloads, hide ads and more …

LLVM/ASMを使った有限体の高速実装

herumi
September 11, 2024

 LLVM/ASMを使った有限体の高速実装

2024/09/11 Ethereum Core Program talk

herumi

September 11, 2024
Tweet

More Decks by herumi

Other Decks in Technology

Transcript

  1. 自己紹介 @herumi サイボウズ・ラボで暗号と高速化のR&D 『暗号と認証のしくみと理論がしっかりわかる』 『Binary Hacks Rebooted』にも寄稿 Microsoft MVP C++,

    Developer Seucirty受賞 (2024/7) x86/x64用JITアセンブラ Xbyak (AArch64, RISC-V版もある) Google Open Source Peer Bonus受賞 (2024/6) UXL FoundationのoneAPI/oneDNNのx64 CPUエンジン ペアリング暗号・BLS署名ライブラリ mcl/bls 2014年zkSNARK/libsnark対応, 2016年DFINITYのBLS署名, 2018年Ethereum PoS対応, etc. mcl-wasmのGitHub network dependentsは20万repo, NPM 1900万DL Ethereum Foundation Retroactive Grants など獲得 2 / 45
  2. 背景 ペアリングと楕円曲線演算 ペアリング BLS署名やゼロ知識証明zk-SNARKなどの暗号技術で利用される要素技術 2種類の楕円曲線 , から有限体の拡大体 への写像 楕円曲線演算 楕円曲線の点の同士の加算・減算(

    )や2倍算( ) 楕円曲線の点の整数倍のスカラー倍算( ) 有限体演算 ペアリングや楕円曲線演算を実装するために必要な演算 整数同士の四則演算、特に256~384bit整数加減乗算が重要 3 / 45
  3. 有限体 定義 を素数として とする 加減乗算を普通の整数として計算した後 で割った余りを取ると定義する , for . の逆数

    は を満たす とする 加算の実装例 def add(x, y): return (x+y) % p 割り算は重たいので def add(x, y): t = x + y s = t - p return t if s < 0 else s 4 / 45
  4. select(c, x, y) 条件分岐用の関数 が true なら を返し, false なら

    を返す アセンブリ言語(以下ASM)レベルで分岐を避けるために使う 実装詳細は後述 def add(x, y): t = x + y s = t - p return select(s < 0, t, s) def sub(x, y): t = x - y s = t + p return select(t < 0, s, t) 5 / 45
  5. 整数の実装 整数の表現 長いビット長の整数はCPUのレジスタサイズ(unit=64 or 32)に合わせる 例えば64 bit CPUでは256bit整数を4個のunitで表現する x =

    [x3:x2:x1:x0] : xi はunit bit整数 キャリー(CF : Carry Flag, 繰り上がり) unit bit整数同士を足すとunit+1 bit整数になる unit bitレジスタには入らない あふれた部分は C では切り捨てられる uint64_t x, y, z; で z = x + y; は z = (x + y) mod 2^64 , CF = z>>64 加算結果が元の値よりも小さければ切り捨てられたと判定できる U z = x + y; bool CF = z < x; // z < yでも可 6 / 45
  6. C++による実装例 N 桁の配列同士の加算 // z[] = x[] + y[] and

    return CF // U : uint64_t or uint32_t // N : 配列の大きさ template<size_t N, typename U> U addT(U *z, const U *x, const U *y) { U CF = 0; // 最初はCF = 0 for (size_t i = 0; i < N; i++) { U xc = x[i] + CF; CF = xc < CF; // CF1 U yi = y[i]; xc += yi; CF += xc < yi; // CF2 z[i] = xc; } return CF; } 7 / 45
  7. x64におけるCFの扱い 加算命令 Input : x, y : 64 bit 整数レジスタ

    Output : x ← (x + y) % M, CF ← (x + y) ≥𝑀 ? 1 : 0; // 繰り上がりつき加算命令 Input : x, y and CF Output : x ← (x + y + CF) % 𝑀, CF ← (x + y + CF) ≥𝑀 ? 1 : 0 AArch64 同様のaddsとadcs命令がある RISC-V, wasm CFを扱う命令は無い 8 / 45
  8. 整数加算の例 x64 Windowsにおける例 呼び出し規約 U addN(U *z, const U *x,

    const U *y); rax rcx rdx r8 N=3の例 add3: mov rax, [rdx] ; rax ← x[0] add rax, [r8] ; rax ← rax + y[0] mov [rcx], rax ; z[0] ← rax mov rax, [rdx+8] ; rax ← x[1] adc rax, [r8+8] ; rax ← rax + y[1] + CF mov [rcx+8], rax ; z[2] ← rax mov rax, [rdx+16] ; rax ← x[2] adc rax, [r8+16] ; rax ← rax + y[2] + CF mov [rcx+16], rax ; z[2] ← rax setc al ; al ← CF movzx eax, al ; rax ← zero extension of al ret 9 / 45
  9. LLVM (Low Level Virtual Machine) LLVM-IR プラットフォーム非依存の低レベルな言語 : リファレンス LLVM-IRのソースコード

    → clang -S -O2 → x64/Aarch64/RISC-Vなど各種CPU用ASM 特徴 SSA (Static Single Assignment form) : 単一代入形式 任意の固定長整数レジスタ : %i{N} キャリー無しの任意の固定長add/subサポート %i{N} ← add(%i{N}, %i{N}) %i{N} ← sub(%i{N}, %i{N}) 最小限の乗算命令サポート 128bit ← 64bit × 64bit (64bit CPU) 64bit ← 32bit × 32bit (32bit CPU) 10 / 45
  10. 256ビット加算の例 仕様 define i1 @add4(i256* %pz, i256* %px, i256* %py)

    { %x = load i256, i256* %px ; x ← *(const uint256_t*)px; %y = load i256, i256* %py ; y ← *(const uint256_t*)py; %x2 = zext i256 %x to i257 ; x2 ← uint257_t(x); %y2 = zext i256 %y to i257 ; y2 ← uint257_t(y); %z = add i257 %x2, %y2 ; z ← x2 + y2 %z2 = trunc i257 %z to i256 ; z2 ← uint256_t(x);  store i256 %z2, i256* %pz ; *(uint256_t*)pz = z2; %z3 = lshr i257 %z, 256 ; z3 ← z >> 256 %z4 = trunc i257 %z3 to i1 ; z4 ← uint1_t(z3); ret i1 %z4 ; return z4; } 256bit整数 , のポインタ px , py を入力として pz に書き込む 繰り上がりがbool値で返る clang -S -O2 -target x86_64 や aarch64 でほぼ最適なASMが生成される 11 / 45
  11. s_xbyak_llvm.py LLVM-IRを書きやすくするためのツール(簡易DSL) 型は一度定義すればOK, 変数の再代入OK, 複数のNに対して1個のコードでOK # unit(=64)*N bit整数の有限体として*pz=*px+*pyを求める関数を生成 # p.bit_length()

    < N*unitを仮定 def gen_mcl_fp_add(N): pz = IntPtr(unit) # uint64_t *pz px = IntPtr(unit) py = IntPtr(unit) pp = IntPtr(unit) with Function('mcl_fp_add', Void, pz, px, py, pp): x = loadN(px, N) # N個のunit(=64)bit整数を読み込む y = loadN(py, N) x = add(x, y) # x = x + y p = loadN(pp, N) y = sub(x, p) # y = x - p c = trunc(lshr(y, unit*N - 1), 1) # c = y >> (unit*N-1) x = select(c, x, y) # x = c ? x : y storeN(x, pz) ret(Void) 13 / 45
  12. s_xbyak_llvmの使い方例 コード生成 % python3 gen.py > add.ll # 約80行 define

    void @mcl_fp_addNF4L(i64* noalias %r1, i64* noalias %r2, i64* noalias %r3, i64* noalias %r4) { %r5 = load i64, i64* %r2 %r6 = zext i64 %r5 to i128 %r7 = getelementptr i64, i64* %r2, i32 1 %r8 = load i64, i64* %r7 %r9 = zext i64 %r8 to i128 %r10 = shl i128 %r9, 64 %r11 = or i128 %r6, %r10 ... ... store i64 %r74, i64* %r73 %r75 = lshr i256 %r72, 64 %r76 = getelementptr i64, i64* %r1, i32 3 %r77 = trunc i256 %r75 to i64 store i64 %r77, i64* %r76 ret void } 14 / 45
  13. ASM化 clangに渡せばOK(x64では最適なコードが生成される) % clang-15 -O2 -S -masm=intel a.ll mcl_fp_add4: mcl_fp_add3:

    push rbx mov r8, qword ptr [rdx] mov r8, qword ptr [rdx] add r8, qword ptr [rsi] add r8, qword ptr [rsi] mov r9, qword ptr [rdx + 8] mov r9, qword ptr [rdx + 8] adc r9, qword ptr [rsi + 8] adc r9, qword ptr [rsi + 8] mov r10, qword ptr [rdx + 16] mov r10, qword ptr [rdx + 16] adc r10, qword ptr [rsi + 16] adc r10, qword ptr [rsi + 16] mov rsi, r8 mov r11, qword ptr [rdx + 24] sub rsi, qword ptr [rcx] adc r11, qword ptr [rsi + 24] mov rax, r9 mov rsi, r8 sbb rax, qword ptr [rcx + 8] sub rsi, qword ptr [rcx] mov rdx, r10 mov rax, r9 sbb rdx, qword ptr [rcx + 16] sbb rax, qword ptr [rcx + 8] mov rcx, rdx mov rdx, r10 sar rcx, 63 sbb rdx, qword ptr [rcx + 16] cmovs rdx, r10 mov rbx, r11 cmovs rax, r9 sbb rbx, qword ptr [rcx + 24] cmovs rsi, r8 cmovs rbx, r11 mov qword ptr [rdi], rsi cmovs rdx, r10 mov qword ptr [rdi + 8], rax cmovs rax, r9 mov qword ptr [rdi + 16], rdx cmovs rsi, r8 ret mov qword ptr [rdi], rsi mov qword ptr [rdi + 8], rax mov qword ptr [rdi + 16], rdx mov qword ptr [rdi + 24], rbx pop rbx ret 15 / 45
  14. その他のCPUにも対応 clang --targetを指定するだけ % clang-15 -O2 -S --target=aarch64 a.ll mcl_fp_add3:

    ldp x8, x12, [x2] ldp x9, x10, [x1] ldr x11, [x1, #16] ldr x13, [x2, #16] adds x8, x8, x9 ldp x14, x9, [x3] adcs x10, x12, x10 ldr x12, [x3, #16] adc x11, x13, x11 subs x13, x8, x14 sbcs x9, x10, x9 sbc x12, x11, x12 asr x14, x12, #63 cmp x14, #0 csel x8, x8, x13, lt csel x9, x10, x9, lt csel x10, x11, x12, lt stp x8, x9, [x0] str x10, [x0, #16] ret 16 / 45
  15. 乗算 (N*unit) bit整数とunit bit整数の乗算 教科書的な乗算 各 xi はunit(=64) bit整数 xi

    * y は2*unit bit整数でそれを [Hi:Li] と表記する LLVMでは xi と y をゼロ拡張して2*unit bit整数にしてから乗算する [x3:x2:x1:x0] X y ---------------- [H0:L0] [H1:L1] [H2:L2] [H3:L3] ----------------- [z4:z3:z2:z1:z0] LLVMにはCFが無いので[Hi:Li]の足し方を工夫する [H3:H2:H1:H0]+[0:L3:L2:L1]と256 bit整数として加算する 17 / 45
  16. mulx (x64) CFを変更しない乗算命令 Haswell (2013)以降のx64 CPUで利用可能 従来のmulはCFを壊すため、途中の計算結果を保存する必要があった レジスタ数の少ないx64では辛い clang/gccでは -march=bmi2

    で利用可能 古いCPUでは動かないのでCPU判別による命令の切り替えが必要 切り捨てたいところだが先日も非AVX CPUで動かないというissueがあって対応した Nehalem (2008) CPU判別 Xbyak::Cpuというクラスを持っている Intel/AMDの様々な命令セットを判別可能 18 / 45
  17. Montgomery乗算 目的 を計算するとき で割るのは重たいのでそれを避けたい 記号の準備 : CPUのビットサイズ(32 or 64), :

    を bit整数で表現したときの配列の大きさ と は互いに素なので となる整数 が存在する , という意味 , とする これらは と が決まると一意に定まる定数たち(以下 は省略) mont をMontgomery乗算と呼ぶ この効率的な計算方法がある(後述) 19 / 45
  18. 通常の乗算とMontgomery乗算とのやりとり 変換 , の代わりに , を使うと mont . 通常世界の ,

    , とMontgomery世界の , , が対応する toMont( ) = mont( , ) = . fromMont( ) = mont( , ) = . Montgomery乗算 出⼒ ⼊⼒ toMont toMont fromMont mont xZ yZ xyZ xy x y mul+mod 乗算を多用する場面ではMontgomeryの世界に入って計算後に戻ってくると効率的 20 / 45
  19. Montgomery乗算のPython実装 記号は上記の通り ip は p' を表す x, p は bit整数

    y は bit整数 個の配列で表現 def mont(x, y): MASK = 2**L - 1 t = 0 for i in range(N): t += x * y[i] # mulUnitAdd q = ((t & MASK) * ip) & MASK # L bit整数の計算 t += p * q # mulUnitAdd t >>= L if t >= p: t -= p return t 定理 : mont(x, y) = である(証明は有限体の実装参照) 21 / 45
  20. s_xbyakによる実装例 コード # z[n..0] = z[n-1..0] + px[n-1..0] * rdx

    def mulUnitAdd(z, px): n = len(z)-1 xor_(z[n], z[n]) for i in range(n): mulx(H, L, ptr(px+i*8)) # [H:L] = px[i] * y adox(z[i], L) # z[i] += L with CF if i == n-1: break adcx(z[i + 1], H) # z[i+1] += H with CF' adox(z[n], H) # z[n] += H with CF adc(z[n], 0) # z[n] += CF' 実際にはこれをサブルーチンとしてmont(x, y)を実装する AArch64ではレジスタが32個あって特殊なことをしなくてもmulUnitAddを実装できる LLVM-IRに任せても概ね大丈夫 23 / 45
  21. AVX-512を用いた有限体・楕円曲線演算の高速化 AVX-512 IFMA 1命令で を計算 ただし , は52bit整数, は64bit整数 と

    から104bitの積 を計算し、その下位52bitをL, 上位52bitをHとする L = uint52_t(uint104_t(a) * uint104_t(b)); H = uint52_t((uint104_t(a) * uint104_t(b)) >> 52); vpmadd52luq(c, a, b) : c += L; vpmadd52huq(c, a, b) : c += H; これを使って52bit進数多倍長整数の演算を実装する 24 / 45
  22. ベンチマーク 楕円曲線の点のスカラー倍算mul 演算 1個あたりのclk 倍率 Ec::mul 82.7K 1 mulEach 28.8K

    2.8 Ec::mulが従来の楕円曲線クラス mulEachが8個まとめて処理するAVX-512版楕円曲線クラス 1個あたり2.8倍の高速化 インタフェースを変えないためにmulEachは内部で従来のクラスとAVX-512版楕円曲線ク ラスの変換をやっているためオーバーヘッドがある 25 / 45
  23. 52bit進数加算 データフォーマット 381bit整数を52bitずつに分割してuint64_t (以下U) に入れていく 64bitずつなら(381+63)/64=6 U でよいが52bitずつなら(381+51)/52=8 U 必要

    加算 52bitずつの加算結果は53bit. 最上位の1bitはcarryとして繰り上げる // N = 8 // mask52 = (1<<52)-1 void rawAdd(U z[N], const U x[N], const U y[N]) { U c = 0; for (size_t i = 0; i < N; i++) { z[i] = x[i] + y[i] + c; if (i == N-1) break; c = z[i] >> 52; // 桁あふれをcに入れて z[i] &= mask52; // マスクする(ここの処理を正規化と呼ぶことにする) } } 26 / 45
  24. AVX-512による加算 先程のrawAddに対応するコード // N = 8 // Vec = __m512i

    void vrawAdd(Vec z[N], const Vec x[N], const Vec y[N]) { Vec t = vadd(x[0], y[0]); Vec c = vpsrlq(t, 52); // t>>52 z[0] = vand(t, vmask); // vmask = mask52 for (size_t i = 1; i < N; i++) { t = vadd(vadd(x[i], y[i]), t); if (i == N-1) { z[i] = t; return; } c = vpsrlq(t, 52); z[i] = vand(t, vmask); } } 28 / 45
  25. AVX-512による381bit×52bit乗算 素直な方法 IFMAを使って52bit整数同士の積の上位と下位をそれぞれ計算 まとめて順次足す void vrawMulUnit(Vec *z, const Vec *x,

    const Vec& y) { Vec L[N], H[N]; for (size_t i = 0; i < N; i++) { L[i] = vmulL(x[i], y); // vpmadd52luqのalias. 下位52bit H[i] = vmulH(x[i], y); // vpmadd52huqのalias. 上位52bit } z[0] = L[0]; for (size_t i = 1; i < N; i++) { z[i] = vadd(L[i], H[i-1]); } z[N] = H[N-1]; } 注意 : 計算結果は52bit進数にはなっていない(繰り上がりがあるので) 正規化が必要 29 / 45
  26. vralMulUnitの改善 vpmadd52luq(vmulL)はFMAなので加算と乗算を同時にできる void vrawMulUnit(Vec *z, const Vec *x, const Vec&

    y) { Vec H; z[0] = vmulL(x[0], y); H = vmulH(x[0], y); for (size_t i = 1; i < N; i++) { z[i] = vmulL(x[i], y, H); // x[i]*y + H H = vmulH(x[i], y); } z[N] = H; } vrawMulUnitAddも同様に実装 詳細は略 30 / 45
  27. 楕円曲線演算 2種類の演算 add(P, Q) : と dbl(P) : 座標の選択 アフィン座標

    : 2次元空間 と無限遠点を使い分ける 逆数(割り算)が必要で重たいので使わない 射影座標 : 3次元空間 ( ) を使う(無限遠点は ) add, dblの演算時に除算を無くせる ヤコビ座標 : 3次元空間 を使う( ) add, dblの演算時に除算を無くせる 演算コストは射影座標に比べてaddは大きい, dblは小さい スカラー倍算 多少後述するが, dblの回数がaddより多い 32 / 45
  28. 通常クラスとの型変換 通常クラスからの分岐 APIとしては, 通常の楕円クラスで扱いたい AVX-512 IFMAが使えるときのみ専用コードに移行 型変換が大変 有限体クラス Fp (通常)

    : 64bit進数(U 6個)が1個 FpM : (AVX-512) : 52bit進数(U 8個)が8個 楕円曲線クラス Ec/Fp (通常) : ヤコビ座標用にFpが3個(x, y, z) EcM/FpM (AVX-512) : 射影座標用にFpMが3個(x, y, z) これらの間の変換が大変 34 / 45
  29. メモリレイアウト Fp[8]とFpM Fpを8個単位で処理 X 00 X 01 X 02 X

    03 X 04 X 05 Ec/Fp Y 00 Y 01 Y 02 Y 03 Y 04 Y 05 Z 00 Z 01 Z 02 Z 03 Z 04 Z 05 X 10 X 11 X 12 X 13 X 14 X 15 ... Y 10 EcM/FpM X 00 X 01 X 02 ... Y 00 Y 01 Y 02 ... Z 00 Z 01 X 10 X 11 X 12 Y 10 Y 11 Y 12 Z 10 Z 11 ... ... zmm0 zmm1 zmm2 ... 35 / 45
  30. 変換関数詳細 Ec/FpからEcM/FpMへの変換 384bit整数(6U)が3個(ヤコビ座標)」が8個を「52bit整数(8U)が8個」を射影座標3個に変換 必要な手続き 1. 64bit Montgomery表現を普通384bit整数に変換(24個) 2. gather命令で6Ux8を8Ux6に変換 3.

    split52bitで6Ux8を8Ux8に変換 4. 普通の384bit整数x8を52bit Montgomery表現x8に変換(3回) 5. ヤコビ座標を射影座標に変換 AVX-512での計算後はこの逆順の操作が必要 36 / 45
  31. split52bit 64bit整数を52bit整数に分割 /* |64 |64 |64 |64 |64 |64 |

    x|52:12|40:24|28:36|16:48|4:52:8|44:20| y|52|52 |52 |52 |52 |52|52 |20| */ void split52bit(Vec y[8], const Vec x[6]) { y[0] = vand(x[0], vmask); y[1] = vand(vor(vpsrlq(x[0], 52), vpsllq(x[1], 12)), vmask); y[2] = vand(vor(vpsrlq(x[1], 40), vpsllq(x[2], 24)), vmask); y[3] = vand(vor(vpsrlq(x[2], 28), vpsllq(x[3], 36)), vmask); y[4] = vand(vor(vpsrlq(x[3], 16), vpsllq(x[4], 48)), vmask); y[5] = vand(vpsrlq(x[4], 4), vmask); y[6] = vand(vor(vpsrlq(x[4], 56), vpsllq(x[5], 8)), vmask); y[7] = vpsrlq(x[5], 44); } vand(vor(a, b), mask)の形 vpternlogq(a, b, mask, 0b11101000)と同値 37 / 45
  32. 64bit Montgomery表現から52bit Montgomery表現への変換 2回の変換回数を1回に減らす toMont (fromMont ( )) , ,

    , , , fromMont ( ) = toMont ( ) = を掛ければ一度で変換できる 8回のfromMont と1回のtoMont を1回のMont で実現 39 / 45
  33. スカラー倍算の実装 バイナリ法 点 を 倍する を2進数で表現して上位ビットから順に計算 n 1 1 0

    1 0 0 1 += Q Q Q Q dbl dbl dbl P += P += P def mul(P : Ec, n : int): bs = bin(n)[2:] Q = Ec() # zero for b in bs: # 上位ビットから順次計算 Q = dbl(Q) if b == '1': Q = add(Q, P) return Q 40 / 45
  34. マルチスカラー倍算 を計算する ECDSAなどの中で現れる式 , を別々に計算して足すのではなく一緒に計算する def mul(P1 : Ec, n1

    : int, P2 : Ec, n2 : int): bs1 = bin(n1)[2:] bs2 = bin(n2)[2:] # bs1とbs2の短い方の先頭に0を追加して同じ長さにする(コード略) Q = Ec() # zero for i in range(len(bs1)): Q = dbl(Q) if i < len(bs1) and bs1[i] == '1': Q = add(Q, P1) if i < len(bs2) and bs2[i] == '1': Q = add(Q, P2) return Q こうすると dbl(Q) の呼び出し回数が半分になる( と のビット数が同じなら) 41 / 45
  35. 自己準同型写像 数学の準備 扱っている楕円曲線 の方程式は という形 を1の原始3乗根とする( , ) とする なので

    を3回適用すると恒等写像になる また は準同型であることが(比較的容易に)示される このとき は点 のある定数倍写像になることが知られている となる定数 が存在する 42 / 45
  36. GLV法 1個のスカラー倍算を2個のマルチスカラー倍算に変換する 先程の を使う 点 の 倍を計算するとき として を計算する ここで

    , は の約半分のbit数になるように選ぶ は簡単に計算できる写像 dbl(P) の回数が約半分にできる が256bitなら は約128bit SIMD向けにちょっと思いついた工夫 約半分では扱いづらい(ごくまれに128bitを越えるとSIMDできない) が保証される選び方を探す必要がある パラメータとアルゴリズムを次のようにするとうまくいくのが示せた これで安心してGLV法を使える 43 / 45
  37. 安全なsplit関数 パラメータと分割関数 M = 1<<256 H = 1<<128 z =

    -0xd201000000010000 L = z*z - 1 r = L*L + L + 1 s = r.bit_length() # 255 (256だと駄目) S = 1<<s v = S // L r0 = S % L def split(n): # a + bL == x for (a, b) = split(x) b = (n * v) >> s a = n - b * L return (a, b) Prop. : なら である 44 / 45
  38. GLV法でヤコビ座標やアフィン座標を使う SIMDでは射影座標だった add(P, Q)でP = Qのときにdbl(P)を呼ぶ例外処理ができないため 射影座標はヤコビ座標よりも若干遅い GLV法の中でadd(P, Q)でP ≠

    Qとなることが保証できれば射影座標が使える 注意深く演算順序を変えてヤコビ座標を利用できるようにした アフィン座標を組み合わせる 複数の楕円曲線の場合、まとめてアフィン座標に変換する効率のよい方法がある 複数の逆数をまとめて計算する方法 これを用いてGLV法で利用するテーブルをアフィン座標にしておく 「ヤコビ座標+ヤコビ座標」よりも「ヤコビ座標+アフィン座標」の方が速い 実装はmcl/src/msm_avx.cpp参照 zenn.dev/herumiでも解説予定 45 / 45