Interpreting Machine Learning Models with Counterfactual Explanations / TokyoR99

Slides from the 99th R study meetup @ Tokyo (#TokyoR), held on June 4, 2022.
https://tokyor.connpass.com/event/249096/

The R code used in the slides is available here.
There is also a notebook that uses DiCE, although it is written in Python.
https://github.com/dropout009/TokyoR99

森下光之助

May 23, 2022

Transcript

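  1. • Why is the prediction 800? SHapley Additive exPlanations (SHAP) answer this kind of question.

    • SHAP decomposes the gap between the expected prediction $\mathbb{E}[\hat{f}(\boldsymbol{X})]$ and the prediction $\hat{f}(\boldsymbol{x}_i)$ for instance $i$ into per-feature contributions.
    • With a baseline of 500 and a prediction of 800 for instance $i$, the gap of 300 is split across four features:

    $$\hat{f}(\boldsymbol{x}_i) = \phi_0 + \phi_{i,1} + \phi_{i,2} + \phi_{i,3} + \phi_{i,4}$$

    [Figure: contribution chart with baseline 500, contributions +200, +200, +200 and -300, and prediction 800]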
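  2. • Three skills, DS, DE and Biz, are each scored out of 100.

    • The prediction model is $\hat{f}(X_{DS}, X_{DE}, X_{Biz}) = 8X_{DS} + 4X_{DE} + 10X_{Biz}$.
    • For DS 60, DE 30 and Biz 20: $\hat{f}(60, 30, 20) = 8 \times 60 + 4 \times 30 + 10 \times 20 = 800$.
    • What would it take to reach 1,000? Counterfactual Explanations (CE) answer this.
    • Source for the three skills: (2014.12.10) (http://www.datascientist.or.jp/news/2014/pdf/1210.pdf)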
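  3. $\hat{f}(X_{DS}, X_{DE}, X_{Biz}) = 8X_{DS} + 4X_{DE} + 10X_{Biz}$, so $\hat{f}(60, 30, 20) = 8 \times 60 + 4 \times 30 + 10 \times 20 = 800$. Ways to reach 1,000:

    • Change 1 skill
      • DS 60 → 85
      • DE 30 → 80
      • Biz 20 → 40
    • Change 2 skills
      • DS → 70 and DE → 60
    • Each of these is a CE.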
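
The arithmetic on this slide can be checked with a few lines of base R. This is a minimal sketch: the coefficients and scores come from the slides, while the variable names (`coefs`, `current`, `required_single`, `two_skill`) are made up for illustration.

```r
# Linear model from the slides: f_hat = 8 * DS + 4 * DE + 10 * Biz
coefs   <- c(ds = 8, de = 4, biz = 10)
current <- c(ds = 60, de = 30, biz = 20)
target  <- 1000

f_hat <- function(x) sum(coefs * x)

# Shortfall between the target and the current prediction
gap <- target - f_hat(current)

# Score each skill would need if it alone had to close the gap
required_single <- current + gap / coefs

print(f_hat(current))
print(required_single)

# Two-skill option from the slide: DS 70 and DE 60, Biz unchanged
two_skill <- c(ds = 70, de = 60, biz = 20)
print(f_hat(two_skill))
```

Because the model is linear, the required increase for a single skill is just the gap divided by that skill's coefficient.
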
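  4. • With the linear $\hat{f}(X_{DS}, X_{DE}, X_{Biz})$, inputs that reach 1,000 are easy to work out.

    • When $\hat{f}(X_{DS}, X_{DE}, X_{Biz})$ is a Random Forest or a Neural Net, this is no longer straightforward.
    • One option is to search all inputs for those where $\hat{f}(X_{DS}, X_{DE}, X_{Biz})$ reaches 1,000.
    • With scores from 0 to 100, each skill takes 101 values, so 3 skills give $101^3$ (roughly $100^3$) combinations.
    • CE addresses this search problem.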
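
The brute-force search mentioned here can be sketched in base R. The example below uses the linear model from the earlier slides as a stand-in for a black-box model; the integer grid, the Euclidean distance, and the rule that no skill may decrease are assumptions made for illustration, and the names (`grid`, `feasible`, `best`) are hypothetical.

```r
# Same linear model as on the slides; any prediction function could be swapped in
f_hat <- function(ds, de, biz) 8 * ds + 4 * de + 10 * biz

current <- c(ds = 60, de = 30, biz = 20)
target  <- 1000

# All 101^3 integer score combinations (0 to 100 for each skill)
grid <- expand.grid(ds = 0:100, de = 0:100, biz = 0:100)
grid$pred <- f_hat(grid$ds, grid$de, grid$biz)

# Keep combinations that reach the target without lowering any skill
feasible <- grid[
  grid$pred >= target &
    grid$ds >= current[["ds"]] &
    grid$de >= current[["de"]] &
    grid$biz >= current[["biz"]],
]

# Euclidean distance from the current scores
feasible$dist <- sqrt(
  (feasible$ds - current[["ds"]])^2 +
    (feasible$de - current[["de"]])^2 +
    (feasible$biz - current[["biz"]])^2
)

# Closest feasible combination
best <- feasible[which.min(feasible$dist), ]
print(nrow(grid))
print(best)
```

About a million rows is still manageable for three features, but the grid grows exponentially with the number of features, which is why an optimization-based formulation is preferable.
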
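  5. • Given a target value $y$, find the closest input $\boldsymbol{x}^* = (x_{DS}, x_{DE}, x_{Biz})$ that reaches it:

    $$\boldsymbol{x}^* = \arg\min_{\boldsymbol{x}} \lVert \boldsymbol{x} - \boldsymbol{x}_i \rVert_2^2 \quad \text{s.t.} \quad \hat{f}(\boldsymbol{x}) \ge y, \quad \boldsymbol{x} \ge \boldsymbol{x}_i$$

    • $\boldsymbol{x}_i = (x_{i,DS}, x_{i,DE}, x_{i,Biz})$ is the current input of instance $i$, and $\lVert \boldsymbol{x} - \boldsymbol{x}_i \rVert_2^2 = \sum_j (x_j - x_{i,j})^2$ is the squared L2 distance; the L1 distance is an alternative.
    • [Text lost in extraction; surviving tokens: 100, Biz, DS]
    • The constraint $\boldsymbol{x} \ge \boldsymbol{x}_i$ rules out solutions that lower a skill.
    • The solution $\boldsymbol{x}^*$ is the CE.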
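
For the linear model from the earlier slides, this optimization problem has a closed-form solution, which makes a useful sanity check. The derivation below is a sketch under two assumptions that are not stated on the slide: $\hat{f}(\boldsymbol{x}) = \boldsymbol{w}^\top \boldsymbol{x}$ with all weights positive, and a current prediction below the target. The symbol $\boldsymbol{w}$ is introduced here for illustration.

```latex
% Closest point to x_i on the half-space w^T x >= y (orthogonal projection).
% Because every component of w is positive, x^* >= x_i holds automatically.
\boldsymbol{x}^* = \boldsymbol{x}_i
  + \frac{y - \boldsymbol{w}^\top \boldsymbol{x}_i}{\lVert \boldsymbol{w} \rVert_2^2}\, \boldsymbol{w}

% With w = (8, 4, 10), x_i = (60, 30, 20), y = 1000:
% ||w||^2 = 64 + 16 + 100 = 180 and y - w^T x_i = 1000 - 800 = 200
\boldsymbol{x}^* = (60, 30, 20) + \frac{200}{180}\,(8, 4, 10)
  \approx (68.9,\ 34.4,\ 31.1),
\qquad
\lVert \boldsymbol{x}^* - \boldsymbol{x}_i \rVert_2 = \frac{200}{\sqrt{180}} \approx 14.9
```

This distance is smaller than any of the single-skill changes (20 for Biz, 25 for DS, 50 for DE), since the closest point spreads the change across all three skills in proportion to their coefficients.
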
  6. R

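  7. $$Income = 0.05 \, X_{DS} \, X_{DE}^{0.6} \, X_{Biz}^{0.8} \, \epsilon$$

    $$X_{DS} \sim \mathrm{Beta}(13, 7) \times 100, \quad X_{DE} \sim \mathrm{Beta}(10, 10) \times 100, \quad X_{Biz} \sim \mathrm{Beta}(7, 13) \times 100, \quad \epsilon \sim \mathcal{N}(1, 0.05^2)$$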
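  8. • The model used here is an Elastic Net; Counterfactual Explanations can also be applied to Neural Nets and tree-based models.

    • Interaction terms and polynomial terms up to degree 4 are added as features.
    • Elastic Net is a linear regression with both L1 and L2 regularization:

    $$\boldsymbol{\beta}_{EN} = \arg\min_{\boldsymbol{b}} \sum_{i=1}^{N} \left( y_i - \boldsymbol{x}_i^\top \boldsymbol{b} \right)^2 + \lambda \left( \theta \lVert \boldsymbol{b} \rVert_1 + (1 - \theta) \lVert \boldsymbol{b} \rVert_2^2 \right)$$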
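
To make the roles of $\lambda$ and $\theta$ concrete, the Elastic Net objective on this slide can be written as a small base R function. This is an illustrative sketch, not the glmnet implementation (glmnet scales the terms differently); the function name `elastic_net_objective` and the toy data are made up.

```r
# Objective from the slide:
# squared error + lambda * (theta * L1 norm + (1 - theta) * squared L2 norm)
elastic_net_objective <- function(b, X, y, lambda, theta) {
  residuals <- y - as.vector(X %*% b)
  penalty <- theta * sum(abs(b)) + (1 - theta) * sum(b^2)
  sum(residuals^2) + lambda * penalty
}

# Tiny made-up data set, only to show the call; b fits y exactly here
X <- matrix(c(1, 0,
              0, 1,
              1, 1), ncol = 2, byrow = TRUE)
y <- c(1, 2, 3)
b <- c(1, 2)

# theta = 1 leaves only the L1 (lasso) penalty, theta = 0 only the L2 (ridge) penalty
lasso_value <- elastic_net_objective(b, X, y, lambda = 0.5, theta = 1)
ridge_value <- elastic_net_objective(b, X, y, lambda = 0.5, theta = 0)
print(lasso_value)
print(ridge_value)
```

In the parsnip code used later in the deck, `penalty` plays the role of $\lambda$ and `mixture` the role of $\theta$.
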
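  9. • Keep DE fixed and find how far DS and Biz must rise to reach 1,000.

    $$(x_{DS}^*, x_{DE}^*, x_{Biz}^*) = \arg\min_{x_{DS}, x_{DE}, x_{Biz}} \lVert (x_{DS}, x_{DE}, x_{Biz}) - (x_{i,DS}, x_{i,DE}, x_{i,Biz}) \rVert_2^2$$

    $$\text{s.t.} \quad \hat{f}(x_{DS}, x_{DE}, x_{Biz}) \ge 1000, \quad x_{DS} \ge x_{i,DS}, \quad x_{DE} = x_{i,DE}, \quad x_{Biz} \ge x_{i,Biz}$$

    • The solution is the Counterfactual.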
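  10. • Simulate the data and split it with Tidymodels

        library(tidymodels) # assumed setup: attaches tibble, rsample and the %>% pipe

        N <- 2000

        # True income function (before noise)
        f <- function(ds, de, biz) {
          ds * de^0.6 * biz^0.8 / 20
        }

        # Skill scores on a 0-100 scale, income with multiplicative noise
        df <- tibble::tibble(
          ds = rbeta(N, 13, 7) * 100,
          de = rbeta(N, 10, 10) * 100,
          biz = rbeta(N, 7, 13) * 100,
          income = f(ds, de, biz) * rnorm(N, 1, 0.05)
        )

        # Train/test split, then 5-fold cross-validation on the training data
        train_test_split <- rsample::initial_split(df)
        df_train <- rsample::training(train_test_split)
        df_test <- rsample::testing(train_test_split)
        cv_split <- rsample::vfold_cv(df_train, v = 5)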
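  11. • Define the model, the preprocessing recipe and the workflow

        # Elastic Net (glmnet) with penalty and mixture left for tuning
        model <- parsnip::linear_reg(
          penalty = tune::tune(),
          mixture = tune::tune()
        ) %>%
          parsnip::set_mode("regression") %>%
          parsnip::set_engine("glmnet")

        # Interaction terms, degree-4 polynomials, then standardization
        rec <- recipes::recipe(income ~ ds + de + biz, data = df) %>%
          recipes::step_interact(terms = ~ ds:de + de:biz + ds:biz + ds:de:biz) %>%
          recipes::step_poly(recipes::all_predictors(), degree = 4) %>%
          recipes::step_normalize(recipes::all_numeric_predictors())

        # Combine both; defined last so that model and rec already exist
        wf <- workflows::workflow() %>%
          workflows::add_model(model) %>%
          workflows::add_recipe(rec)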
  12. • Tune the hyperparameters with Bayesian optimization

        bayes_result <- wf %>%
          tune::tune_bayes(
            resamples = cv_split,
            param_info = tune::parameters(
              list(
                penalty = dials::penalty(),
                mixture = dials::mixture()
              )
            ),
            # rmse is namespaced so the call works without attaching yardstick
            metrics = yardstick::metric_set(yardstick::rmse),
            initial = 5,
            iter = 30,
            control = tune::control_bayes(verbose = TRUE, no_improve = 5)
          )
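  13. • Accuracy of the final model: RMSE: 31.3, R2: 0.974

        # Pick the best hyperparameters
        best_model <- bayes_result %>%
          tune::select_best()

        # Plug them into the workflow
        wf_final <- wf %>%
          tune::finalize_workflow(best_model)

        # Fit on the training data and evaluate on the test data
        final_result <- wf_final %>%
          tune::last_fit(train_test_split)

        # Metrics
        final_result %>%
          tune::collect_metrics()

        # Predictions
        final_result %>%
          tune::collect_predictions()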
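
For readers who want to see what the two reported metrics measure, here is a minimal base R sketch. It assumes, as I understand the yardstick defaults, that the reported R2 is the squared correlation between truth and prediction; the function names `rmse_manual` and `rsq_manual` and the toy vectors are made up for illustration.

```r
# RMSE: square root of the mean squared prediction error
rmse_manual <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}

# R2 as the squared correlation between truth and prediction (assumed definition)
rsq_manual <- function(truth, estimate) {
  cor(truth, estimate)^2
}

# Toy values, not taken from the deck
truth <- c(1, 2, 3)
estimate <- c(1, 2, 4)

print(rmse_manual(truth, estimate))
print(rsq_manual(truth, estimate))
```

With the deck's data, `truth` and `estimate` would correspond to the `income` and `.pred` columns returned by `tune::collect_predictions()`.
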
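  14. • A function that computes Counterfactual Explanations: keep DE fixed and find the DS and Biz needed to reach 1,000

        counterfactual_explanations <- function(model, current_x, desired_y) {
          # Turn a numeric vector into the one-row data frame the model expects
          as_input <- function(x) tibble::tibble(ds = x[1], de = x[2], biz = x[3])

          # Predict and return a plain numeric instead of a data frame
          predict_num <- function(model, x) dplyr::pull(predict(model, as_input(x)))

          # Inequality constraint: prediction minus the desired value
          constraint <- function(x) predict_num(model, x) - desired_y

          # Objective: L2 distance from the current input
          distance <- function(x) norm(current_x - x, type = "2")

          # Solve the constrained minimization
          solution <- Rsolnp::solnp(
            pars = current_x + 1e-3, # start just above the current input
            fun = distance,
            ineqfun = constraint,
            ineqLB = 0, # prediction must be at least desired_y
            ineqUB = 0.1, # and at most desired_y + 0.1
            LB = current_x, # no skill may decrease
            UB = c(100, current_x[2] + 1e-2, 100), # DE is effectively fixed
            control = list(tol = 1e-5)
          )

          result <- list(
            current_x = as_input(current_x), # current input
            current_y = predict_num(model, current_x), # current prediction
            desired_y = desired_y, # target prediction
            required_x = as_input(solution$pars), # input needed to reach the target
            predicted_y = predict_num(model, solution$pars) # prediction at that input
          )

          return(result)
        }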
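  15. • Current input: $\boldsymbol{x}_i = (x_{i,DS}, x_{i,DE}, x_{i,Biz}) = (66.0, 66.9, 28.9)$, with a prediction of 607.

    • Counterfactual: $\boldsymbol{x}^* = (77.4, 66.9, 45.8)$, which reaches 1,000.
    • This $\boldsymbol{x}^*$ is the CE.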
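
As a rough cross-check of this result, the same question can be asked of the true data-generating function instead of the fitted Elastic Net. This is a different technique from the deck's `Rsolnp::solnp` call: since DE is fixed and the function increases in DS and Biz, the constraint holds with equality at the optimum, so Biz can be solved for as a function of DS, leaving a one-dimensional search with base R's `optimize()`. The names `f_true`, `biz_needed` and `sq_dist` are hypothetical, and the answer is not expected to match the slide exactly because the fitted model is only an approximation of the true function.

```r
# True data-generating function from the simulation (noise omitted)
f_true <- function(ds, de, biz) ds * de^0.6 * biz^0.8 / 20

current <- c(ds = 66.0, de = 66.9, biz = 28.9)
target  <- 1000

# With DE fixed, the Biz score that exactly reaches the target for a given DS
biz_needed <- function(ds) {
  (20 * target / (ds * current[["de"]]^0.6))^(1 / 0.8)
}

# Squared distance from the current scores along the boundary f_true = target
sq_dist <- function(ds) {
  (ds - current[["ds"]])^2 + (biz_needed(ds) - current[["biz"]])^2
}

# One-dimensional search over DS between the current score and 100
opt <- optimize(sq_dist, interval = c(current[["ds"]], 100))

required <- c(
  ds = opt$minimum,
  de = current[["de"]],
  biz = biz_needed(opt$minimum)
)
print(round(required, 1))
print(f_true(required[["ds"]], required[["de"]], required[["biz"]]))
```

The reduction to one dimension only works because there are two free skills and a monotone function; for a general black-box model the constrained solver used in the deck is the more flexible choice.
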
  16. • Wachter, Sandra, Brent Mittelstadt, and Chris Russell. "Counterfactual explanations without opening the black box: Automated decisions and the GDPR." Harv. JL & Tech. 31 (2017): 841.

    • Mothilal, Ramaravind K., Amit Sharma, and Chenhao Tan. "Explaining machine learning classifiers through diverse counterfactual explanations." Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency. 2020.
    • Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable." (2019). https://christophm.github.io/interpretable-ml-book/.
  17. R