Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Polars を Kaggle コンペで使ってみた(LMSYS Chatbot Arena)

Kohecchi
August 22, 2024

Polars を Kaggle コンペで使ってみた(LMSYS Chatbot Arena)

Kohecchi

August 22, 2024
Tweet

More Decks by Kohecchi

Other Decks in Programming

Transcript

  1. Polars Data Crunch #2 までに学んだ基本的な操作 • read_csv • read_parquet •

    head • tail • with_columns • select • drop • filter • group_by • over • alias • cast • sort • concat • slice • join • dt • replace • fill_null • pipe 上記のほぼ全てが chumajin さんの神 Notebookで紹介されています!必見! chumajin's room - polars basic - (日本語ver)
  2. LMSYSコンペ https://www.kaggle.com/competitions/lmsys-chatbot-arena ❖ 開催期間: 2024/5/3 ~ 8/3 ? ❖ 主催は

    LMSYS Chatbot Arena ❖ 2つのLLMに与えた Prompt と モデルの回答 から勝敗を予測するコ ンペ
  3. 学習データ train.csv ❖ id - ユニークなID ❖ model_a - LLMのモデル名

    ❖ model_b - LLMのモデル名 ❖ prompt - LLM に渡すプロンプト ❖ response_a - LLM モデル a からの回答 ❖ response_b - LLM モデル b からの回答 ❖ winner_model_[a/b/tie] - 結果(a, b, 引き分け) ← Target
  4. prompt, response_a, response_b のデータ train = pl.read_csv("/kaggle/input/lmsys-chatbot-arena/train.csv") train.filter(pl.col("id")==2929096534) id prompt

    response_a response_b i64 str str str 2929096534 ["What is 2+2?", "You are wrong. 2+2=5"] ["2 + 2 = 4", "No, 2+2 equals 4. That's a basic math fact."] ["2+2=4", "No, I apologize but 2+2=4. That is a mathematical fact."] What is 2+2? You are wrong. 2+2=5 2 + 2 = 4 No, 2+2 equals 4. That’s basic math fact. 2+2=4 No, I apologize but 2+2=4. That is a mathematical fact. 複数回の LLM とのやり取りが テキスト配列の文字列 として格納されている。 1ターン目 2ターン目
  5. str → List 型への変換① def str2list(input_str): stripped_str = input_str.strip('[]') sentences

    = [s.strip('"') for s in stripped_str.split('","')] return sentences train.with_columns( prompt_list = pl.col("prompt").map_elements(str2list, return_dtype=pl.List(pl.String)), ) .map_elements()  ※Pandas の apply とほぼ同じ。遅い。 id prompt prompt_list i64 str list[str] 2929096534 ["What is 2+2?", "You are wrong. 2+2=5"] ["What is 2+2?", "You are wrong. 2+2=5"]
  6. str → List 型への変換② train.select( pl.col("id"), pl.col("prompt"), pl.col("prompt").str.json_decode().alias("prompt_list"), ) .str.json_decode()

    • list 形式も JSON フォーマットの1つなので json_decode() が使える id prompt prompt_list i64 str list[str] 2929096534 ["What is 2+2?", "You are wrong. 2+2=5"] ["What is 2+2?", "You are wrong. 2+2=5"]
  7. 厳密には結果が異なる • json_decode ではエスケープされた unicode が変換される。 train.select( pl.col("prompt").str.json_decode().alias("prompt_list_jd"), pl.col("prompt").map_elements(str2list, return_dtype=pl.List(pl.String)).alias("prompt_list_me"),

    ).filter(pl.col("prompt_list_jd") != pl.col("prompt_list_me")) id prompt_list_jd prompt_list_me i64 list[str] list[str] 4294656694 ["A simple mnemonic for π: "How I wish I could enumerate pi easily" The number of letters in each word is a digit of π. Show this is true."] ["A simple mnemonic for \u03c0:\n\"How I wish I could enumerate pi easily\"\n\nThe number of letters in each word is a digit of \u03c0.\nShow this is true."]
  8. tokenize 後の結果も変わる • tokenize の結果も変わる。json_decode() を使った方が token 数を抑えられ そう。 from

    transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct-bnb-4bit") prompt_list_jd_token = tokenizer(prompt_list_jd, max_length=500, truncation=True)["input_ids"] prompt_list_me_token = tokenizer(prompt_list_me, max_length=500, truncation=True)["input_ids"] print("json_decode:", prompt_list_jd_token) print("map_element:", prompt_list_me_token) json_decode: [128000, 32, 4382, 87901, 369, 52845, 512, 72059, 358, 6562, 358, 1436, 13555, 9115, 6847, 1875, 791, 1396, 315, 12197, 304, 1855, 3492, 374, 264, 16099, 315, 52845, 627, 7968, 420, 374, 837, 13] map_element: [128000, 32, 4382, 87901, 369, 1144, 84, 2839, 66, 15, 7338, 77, 2153, 4438, 358, 6562, 358, 1436, 13555, 9115, 6847, 23041, 77, 1734, 791, 1396, 315, 12197, 304, 1855, 3492, 374, 264, 16099, 315, 1144, 84, 2839, 66, 15, 7255, 77, 7968, 420, 374, 837, 13] \u03c0 π
  9. List の操作 train.with_columns( num_turns = pl.col("prompt_list").list.len(), prompt_first = pl.col("prompt_list").list.get(0, null_on_oob=True),

    prompt_first2 = pl.col("prompt_list").list.get(1, null_on_oob=True), prompt_last = pl.col("prompt_list").list.get(-1, null_on_oob=True), prompt_last2 = pl.col("prompt_list").list.get(-2, null_on_oob=True), ) .list.len() : list の要素数を取得 .list.get() : list の要素を取得 id prompt_list num_turn prompt_first prompt_last i64 list[str] u32 str str 2929096534 ["What is 2+2?", "You are wrong. 2+2=5"] 2 “What is 2+2?” "You are wrong. 2+2=5"
  10. 学習データ train.csv ❖ id - ユニークなID ❖ model_a - LLMのモデル名

    ❖ model_b - LLMのモデル名 ❖ prompt - LLM に渡すプロンプト ❖ response_a - LLM モデル a からの回答 ❖ response_b - LLM モデル b からの回答 ❖ winner_model_[a/b/tie] - 結果(a, b, 引き分け) ← Target 同じ処理を3回繰り替えす