class: center, middle, inverse, title-slide

# Interpretable Machine Learning
## with R

### Brad Boehmke
### 2018-09-14

---
class: center, middle, inverse

# Introduction

---
# About me

.pull-left[

<img src="images/name-tag.png" width="1360" style="display: block; margin: auto;" />

* Web: bradleyboehmke.github.io
* GitHub: @bradleyboehmke
* Twitter: @bradleyboehmke
* LinkedIn: @bradleyboehmke
* Email: bradleyboehmke@gmail.com

]

.pull-right[

#### Family

<img src="images/family.png" align="right" alt="family" width="130" />

* Dayton, OH
* Kate, Alivia (9), Jules (6)

#### Professional

* 84.51° <img src="images/logo8451.jpg" align="right" alt="family" width="150" />

#### Academic

* University of Cincinnati <img src="images/uc.png" align="right" alt="family" width="100" />
* Air Force Institute of Technology

#### R Community

<img src="images/r-contributions.png" alt="family" width="400" />

]

---
# Your turn!

<br><br><br><br><br><br>

.center[.font150[What does machine learning interpretability mean to you?]]

---
# A mental model

.pull-left[

.bolder[.font120[Philosophical: Political & Social]]

* Data ethics
* Fairness, Accountability, Transparency (FAT)
* Regulatory examples:
    - Civil Rights Acts
    - Americans with Disabilities Act
    - Genetic Information Nondiscrimination Act
    - Health Insurance Portability and Accountability Act
    - Equal Credit Opportunity Act
    - Fair Credit Reporting Act
    - Fair Housing Act
    - European Union General Data Protection Regulation
* https://www.fatml.org/resources/relevant-scholarship

.center[.font120[.blue[___Right to explanation___]]]

]

.pull-right[

<img src="images/santa.jpg" style="display: block; margin: auto;" />

]

---
# A mental model

.pull-left[

<img src="images/black-box.gif" style="display: block; margin: auto;" />

]

.pull-right[

.bolder[.font120[Pragmatic: Model Logic]]

* Performance analysis
    - Residual plots
    - Lift charts
    - ROC curves
* Sensitivity analysis
    - Simulated data
    - Perturbation
    - Accuracy vs explanation
* Feature analysis
    - Feature importance
    - Feature effects

<hr style="height:40px; visibility:hidden;" /></hr>

.center[.font120[.blue[___Ability to explain___]]]

]

---
# Today's focus

.pull-left[
.opacity10[

.bolder[.font120[Philosophical: Political & Social]]

* Data ethics
* Fairness, Accountability, Transparency (FAT)
* Regulatory examples:
    - Civil Rights Acts
    - Americans with Disabilities Act
    - Genetic Information Nondiscrimination Act
    - Health Insurance Portability and Accountability Act
    - Equal Credit Opportunity Act
    - Fair Credit Reporting Act
    - Fair Housing Act
    - European Union General Data Protection Regulation
* https://www.fatml.org/resources/relevant-scholarship

.center[.font120[___Right to explanation___]]

]
]

.pull-right[

.bolder[.font120[Pragmatic: Model Logic]]

.opacity10[
* Performance analysis
    - Residual plots
    - Lift charts
    - ROC curves
* Sensitivity analysis
    - Simulated data
    - Perturbation
    - Accuracy vs explanation
]

* Feature analysis
    - Feature importance
    - Feature effects

<hr style="height:20px; visibility:hidden;" /></hr>

.center[.font120[.blue[___Ability to explain___]]]

]
---
class: center, middle, inverse

# Terminology to consider

---
# Interpretable models vs model interpretation

- The complexity of a machine learning model is directly related to its interpretability.
- Generally, the more complex the model, the more difficult it is to interpret and explain.

<img src="images/interpretable-models.png" width="80%" height="80%" style="display: block; margin: auto;" />

---
# Model complexity

Even naturally interpretable models (e.g., GLMs) can become quite complex. Consider the following from Harrison Jr and Rubinfeld [1]:

<br>

$$ \widehat{\text{log}(y)} = 9.76 + 0.0063RM^2 + 8.98 \times 10^{-5}AGE - 0.19\text{log}(DIS) + 0.096\text{log}(RAD) - \dots $$
$$ 4.20 \times 10^{-4}TAX - 0.031PTRATIO + 0.36(B - 0.63)^2 - 0.37\text{log}(LSTAT) - \dots $$
$$ 0.012CRIM + 8.03 \times 10^{-5}ZN + 2.41 \times 10^{-4}INDUS + 0.088CHAS - 0.0064NOX^2 $$

<br><br><br><br>

.center[.content-box-gray[.font110[Is this really any more interpretable than a random forest model?]]]

---
# Model complexity

.opacity10[

Even naturally interpretable models (e.g., GLMs) can become quite complex. Consider the following from Harrison Jr and Rubinfeld [1]:

<br>

$$ \widehat{\text{log}(y)} = 9.76 + 0.0063RM^2 + 8.98 \times 10^{-5}AGE - 0.19\text{log}(DIS) + 0.096\text{log}(RAD) - \dots $$
$$ 4.20 \times 10^{-4}TAX - 0.031PTRATIO + 0.36(B - 0.63)^2 - 0.37\text{log}(LSTAT) - \dots $$
$$ 0.012CRIM + 8.03 \times 10^{-5}ZN + 2.41 \times 10^{-4}INDUS + 0.088CHAS - 0.0064NOX^2 $$

<br><br><br>

]

.center[.content-box-gray[.font130[.blue[We need additional approaches for robust model interpretability.]]]]

---
# Model specific vs Model agnostic

.pull-left[

.bolder[.font120[Model specific]]

- Limited to specific ML classes
- Incorporates model-specific logic
- Examples:
    - coefficients in linear models
    - impurity in tree-based models
- ___.red[limited application]___

]

.pull-right[

.bolder[.font120[Model agnostic]]

- Can be applied to any type of ML algorithm
- Assesses inputs and outputs
- Examples:
    - Permutation-based variable importance
    - PDPs, ICE curves
    - LIME, Shapley, Breakdown
- .blue[___most of what you'll see today are model agnostic approaches___]

]

<br><br><br>

.center[.content-box-gray[.font110[When possible, it's good practice to compare model specific and model agnostic approaches.]]]

---
# Scope of interpretability

.pull-left[

.bolder[Global interpretability]

- How do features influence overall model performance?
- What is the overall relationship between features and the target?

<hr style="height:1px; visibility:hidden;" /></hr>

<img src="images/global.png" width="80%" height="80%" style="display: block; margin: auto;" />

.center[.content-box-gray[.bolder[Averages effects over data dimensions]]]

]

.pull-right[

.bolder[Local interpretability]

- How do our features influence individual predictions?
- What are the observation-level relationships between features and the target?

<img src="images/local.png" width="80%" height="80%" style="display: block; margin: auto;" />

.center[.content-box-gray[.bolder[Assesses individual effects]]]

]
<img src="images/local.png" width="80%" height="80%" style="display: block; margin: auto;" /> .center[.content-box-gray[.bolder[Assesses individual effects]]] ] --- class: center, middle, inverse # Prerequisites --- # Packages & Data .pull-left[ .bolder[Packages] ```r # helper packages library(ggplot2) library(dplyr) # setting up machine learning models library(rsample) library(h2o) # packages for explaining our ML models *library(pdp) *library(vip) *library(iml) *library(DALEX) *library(lime) # initialize h2o session h2o.no_progress() h2o.init() ## Connection successful! ## ## R is connected to the H2O cluster: ## H2O cluster uptime: 6 minutes 43 seconds ## H2O cluster timezone: America/New_York ## H2O data parsing timezone: UTC ## H2O cluster version: 3.20.0.2 ## H2O cluster version age: 2 months and 27 days ## H2O cluster name: H2O_started_from_R_b294776_ony639 ## H2O cluster total nodes: 1 ## H2O cluster total memory: 3.99 GB ## H2O cluster total cores: 8 ## H2O cluster allowed cores: 8 ## H2O cluster healthy: TRUE ## H2O Connection ip: localhost ## H2O Connection port: 54321 ## H2O Connection proxy: NA ## H2O Internal Security: FALSE ## H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4 ## R Version: R version 3.5.1 (2018-07-02) ``` ] .pull-right[ .bolder[Data] ```r # classification data df <- rsample::attrition %>% mutate_if(is.ordered, factor, ordered = FALSE) %>% mutate(Attrition = ifelse(Attrition == "Yes", 1, 0) %>% as.factor()) # convert to h2o object df.h2o <- as.h2o(df) # variable names for resonse & features y <- "Attrition" x <- setdiff(names(df), y) ``` ] --- # Models .scrollable90[ .pull-left[ .bolder[4 machine learning models] * Elastic net (AUC = 0.836) * Random forest (AUC = 0.788) * Gradient boosting machine (AUC = 0.8105) * Ensemble(AUC = 0.835) ] .pull-right[ .bolder[Models] ```r # elastic net model glm <- h2o.glm( x = x, y = y, training_frame = df.h2o, nfolds = 5, fold_assignment = "Modulo", keep_cross_validation_predictions = TRUE, family = "binomial", seed = 123 ) # random forest model rf <- h2o.randomForest( x = x, y = y, training_frame = df.h2o, nfolds = 5, fold_assignment = "Modulo", keep_cross_validation_predictions = TRUE, ntrees = 1000, stopping_metric = "AUC", stopping_rounds = 10, stopping_tolerance = 0.005, seed = 123 ) # gradient boosting machine model gbm <- h2o.gbm( x = x, y = y, training_frame = df.h2o, nfolds = 5, fold_assignment = "Modulo", keep_cross_validation_predictions = TRUE, ntrees = 1000, stopping_metric = "AUC", stopping_rounds = 10, stopping_tolerance = 0.005, seed = 123 ) # ensemble ensemble <- h2o.stackedEnsemble( x = x, y = y, training_frame = df.h2o, metalearner_nfolds = 5, model_id = "ensemble", base_models = list(glm, rf, gbm), metalearner_algorithm = "glm" ) # model performance h2o.auc(glm, xval = TRUE) ## [1] 0.8363927 h2o.auc(rf, xval = TRUE) ## [1] 0.7882236 h2o.auc(gbm, xval = TRUE) ## [1] 0.810503 h2o.auc(ensemble, xval = TRUE) ## [1] 0.8285921 ``` ] ] --- # Model agnostic procedures .pull-left[ In order to work with the __DALEX__ and __iml__ packages, we need to: 1. Get 3 key ingredients - data frame of just features - numeric vector of response - custom prediction function ] .pull-right[ ```r # 1. create a data frame with just the features features <- as.data.frame(df) %>% select(-Attrition) # 2. Create a numeric vector with the actual responses response <- as.numeric(as.character(df$Attrition)) # 3. 
---
# Model agnostic procedures

.pull-left[

In order to work with the __DALEX__ and __iml__ packages, we need to:

1. Get 3 key ingredients
    - data frame of just features
    - numeric vector of response
    - custom prediction function

]

.pull-right[

```r
# 1. create a data frame with just the features
features <- as.data.frame(df) %>% select(-Attrition)

# 2. Create a numeric vector with the actual responses
response <- as.numeric(as.character(df$Attrition))

# 3. Create custom predict function that returns the predicted values as a
#    vector (probability of attrition in our example)
pred <- function(model, newdata)  {
  results <- as.data.frame(h2o.predict(model, as.h2o(newdata)))
  return(results[[3L]])
}

# example of prediction output
*pred(gbm, features) %>% head()
## [1] 0.56528449 0.01921764 0.74270046 0.11967059 0.19895407 0.03881329
```

]

---
# Model agnostic procedures

.pull-left[

In order to work with the __DALEX__ and __iml__ packages, we need to:

1. Get 3 key ingredients
    - data frame of just features
    - numeric vector of response
    - custom prediction function
2. Create a model agnostic object
    - __iml__: Class `Predictor`
    - __DALEX__: Class `Explainer`

]

.pull-right[

```r
# GBM predictor object
iml_predictor_gbm <- Predictor$new(
  model = gbm,
* data = features,
* y = response,
* predict.fun = pred,
  class = "classification"
)

# GBM explainer
dalex_explainer_gbm <- DALEX::explain(
  model = gbm,
* data = features,
* y = response,
* predict_function = pred,
  label = "gbm"
)
```

]

<hr style="height:2px; visibility:hidden;" /></hr>

.center[.content-box-gray[These objects simply pass key information from the ML model to downstream functions.]]

---
class: center, middle, inverse

# Global Interpretation

---
# Global Interpretation

.pull-left[

How do features influence overall model performance?

* Feature importance
    - model specific
    - model agnostic

]

.pull-right[

What is the overall relationship between features and the target?

* Feature effects
    - Partial dependence
    - Interactions

]

---
# Global feature importance

.scrollable[

.pull-left[

How do features influence overall model performance?

* Feature importance
    - model specific
        - GLM: absolute standardized coefficients or t-statistic
        - RF & GBM: improvement in gini
        - ensemble: NA

]

.pull-right[

```r
vip::vip(glm)
```

<img src="slides-source_files/figure-html/vip-model-specific-1.png" style="display: block; margin: auto;" />

```r
vip::vip(rf)
```

<img src="slides-source_files/figure-html/vip-model-specific-2.png" style="display: block; margin: auto;" />

```r
vip::vip(ensemble)
## Error: Column indexes must be at most 0 if positive, not 1, 2
```

]

]

<br>

.center[.content-box-gray[[__vip__](https://github.com/koalaverse/vip) provides consistent VIP plotting regardless of ML model.]]

---
# Global feature importance

.pull-left[

How do features influence overall model performance?

* Feature importance
    - .opacity10[model specific]
    - model agnostic: .bolder[Permutation-based]

]

.pull-right[

<img src="images/vip-permute.png" width="90%" height="90%" style="display: block; margin: auto;" />

]

.center[.content-box-gray[.font90[Permutation breaks the relationship between the feature and response by randomizing the feature values.]]]
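---
# Permutation importance - under the hood

The packages on the next slides add resampling, multiple loss functions, and plotting, but the core permutation logic is simple. A minimal sketch (illustration only), reusing the `features`, `response`, and `pred()` objects created earlier:

```r
# permutation importance sketch: a feature's importance is the increase in
# loss after its values are shuffled, one column at a time
perm_importance <- function(model, features, response, loss) {
  baseline <- loss(response, pred(model, features))
  sapply(names(features), function(f) {
    permuted <- features
    permuted[[f]] <- sample(permuted[[f]])  # break the feature-response link
    loss(response, pred(model, permuted)) - baseline
  })
}

# e.g., squared-error loss on the predicted attrition probability
mse <- function(y, y_hat) mean((y - y_hat)^2)
sort(perm_importance(gbm, features, response, mse), decreasing = TRUE)
```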
---
# Global feature importance

.scrollable90[

.pull-left[

.bolder[iml]

```r
# compute feature importance with specified loss metric
iml_vip <- FeatureImp$new(iml_predictor_ensemble, loss = "logLoss")

# output as a data frame
head(iml_vip$results)
##                   feature original.error permutation.error importance
## 1                OverTime      0.1558804         0.2757361   1.768895
## 2                 JobRole      0.1558804         0.2047531   1.313527
## 3 EnvironmentSatisfaction      0.1558804         0.1926294   1.235751
## 4    YearsWithCurrManager      0.1558804         0.1853997   1.189372
## 5        StockOptionLevel      0.1558804         0.1837507   1.178793
## 6         JobSatisfaction      0.1558804         0.1832634   1.175667

# plot output
plot(iml_vip) + ggtitle("Ensemble variable importance")
```

<img src="slides-source_files/figure-html/vip-iml-1.png" style="display: block; margin: auto;" />

]

.pull-right[

.bolder[DALEX]

```r
dalex_vip_glm <- variable_importance(dalex_explainer_glm, n_sample = -1)
dalex_vip_rf <- variable_importance(dalex_explainer_rf, n_sample = -1)
dalex_vip_gbm <- variable_importance(dalex_explainer_gbm, n_sample = -1)
dalex_vip_ensemble <- variable_importance(dalex_explainer_ensemble, n_sample = -1)

plot(dalex_vip_glm, dalex_vip_rf, dalex_vip_gbm, dalex_vip_ensemble, max_vars = 10)
```

<img src="slides-source_files/figure-html/vip-dalex-1.png" style="display: block; margin: auto;" />

]

]

---
# Global feature effects

What is the overall relationship between features and the target?

.pull-left-narrow[

* Feature effects
    - Partial dependence
    - .opacity10[Interactions]

]

.pull-right-wide[

<img src="images/pdp.png" width="90%" height="90%" style="display: block; margin: auto;" />

]

---
# Partial dependence

.scrollable90[

.pull-left[

.bolder[pdp]

```r
pdp_fun <- function(object, newdata) {
  # compute partial dependence
  pd <- mean(predict(object, as.h2o(newdata))[[3L]])
  # return data frame with average predicted value
  return(as.data.frame(pd))
}

# partial dependence values
pd_df <- partial(
  ensemble,
  pred.var = "OverTime",
  train = features,
  pred.fun = pdp_fun
)
```

]

.pull-right[

<br>

```r
# partial dependence
pd_df
##   OverTime      yhat
## 1       No 0.1137813
## 2      Yes 0.3071652

# partial dependence plot
autoplot(pd_df)
```

<img src="slides-source_files/figure-html/pdp-pdp1-plot-1.png" style="display: block; margin: auto;" />

]

]

---
# Partial dependence

.scrollable90[

.pull-left[

.bolder[pdp]

```r
# partial dependence values
partial(
  ensemble,
  pred.var = "Age",
  train = features,
  pred.fun = pdp_fun,
  grid.resolution = 20
) %>%
  autoplot(rug = TRUE, train = features) + ggtitle("Age")
```

<br>

.bolder[iml]

```r
Partial$new(iml_predictor_ensemble, "Age", ice = FALSE, grid.size = 20) %>%
  plot() + ggtitle("Age")
```

<br>

.bolder[DALEX]

```r
p1 <- variable_response(dalex_explainer_glm, variable = "Age", type = "pdp", grid.resolution = 20)
p2 <- variable_response(dalex_explainer_rf, variable = "Age", type = "pdp", grid.resolution = 20)
p3 <- variable_response(dalex_explainer_gbm, variable = "Age", type = "pdp", grid.resolution = 20)
p4 <- variable_response(dalex_explainer_ensemble, variable = "Age", type = "pdp", grid.resolution = 20)

plot(p1, p2, p3, p4)
```

]

.pull-right[

<br>

<img src="slides-source_files/figure-html/all-pdp-plots-1.png" style="display: block; margin: auto;" />

]

]
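---
# Partial dependence - under the hood

Each partial dependence value is just the average prediction over the training data with one feature held fixed. A minimal sketch (illustration only) of what `partial()` computes, reusing the `features` and `pred()` objects created earlier:

```r
# manual partial dependence for a single feature over a grid of values
pd_manual <- function(model, data, feature, grid) {
  sapply(grid, function(v) {
    tmp <- data
    tmp[[feature]] <- v      # hold the feature fixed at one grid value
    mean(pred(model, tmp))   # average prediction = partial dependence
  })
}

age_grid <- seq(min(features$Age), max(features$Age), length.out = 20)
pd_manual(gbm, features, "Age", age_grid)
```

Skipping the `mean()` and keeping each observation's prediction yields the ICE curves shown later.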
---
# Global feature effects

What is the overall relationship between features and the target?

.pull-left[

* Feature effects
    - .opacity10[Partial dependence]
    - Interactions
        - one-way interactions

```r
1: for variable i in {1,...,p} do
     | f(x) = estimate predicted values with original model
     | pd(x) = partial dependence of variable i
     | pd(!x) = partial dependence of all features excluding i
*    | upper = sum((f(x) - pd(x) - pd(!x))^2)
     | lower = variance(f(x))
     | rho = upper / lower
   end
2. Sort variables by descending rho (interaction strength)
```

]

.pull-right[

<img src="images/h-statistic.png" width="100%" height="100%" style="display: block; margin: auto;" />

]

---
# H-statistic: 1-way interaction

.pull-left[

.bolder[iml]

```r
interact <- Interaction$new(iml_predictor_ensemble)
plot(interact)
```

- One of only a few implementations
- Computationally intense ( `\(2n^2\)` runs)
    - took 53 minutes for a data set with 100 features
- Can parallelize (`vignette("parallel", package = "iml")`)

]

.pull-right[

<img src="slides-source_files/figure-html/h-statistic-plot-1.png" style="display: block; margin: auto;" />

]
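---
# H-statistic - the underlying formula

The pseudocode on the previous slides approximates Friedman's H-statistic. For feature `\(j\)`, with all terms mean-centered, the one-way statistic is:

$$ H^2_j = \frac{\sum^n_{i=1} \left[ \hat{f}(x^{(i)}) - PD_j(x_j^{(i)}) - PD_{-j}(x_{-j}^{(i)}) \right]^2}{\sum^n_{i=1} \hat{f}(x^{(i)})^2} $$

`\(H^2_j \approx 0\)` means feature `\(j\)`'s effect is essentially additive (no interaction); values approaching 1 mean nearly all of its effect comes through interactions.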
---
# Global feature effects

What is the overall relationship between features and the target?

.pull-left[

* Feature effects
    - .opacity10[Partial dependence]
    - Interactions
        - .opacity10[one-way interactions]
        - two-way interactions

```r
1: i = a selected variable of interest
2: for remaining variables j in {1,...,p} do
*    | pd(ij) = interaction partial dependence of variables i and j
     | pd(i) = partial dependence of variable i
     | pd(j) = partial dependence of variable j
     | upper = sum((pd(ij) - pd(i) - pd(j))^2)
     | lower = variance(pd(ij))
     | rho = upper / lower
   end
3. Sort interaction relationships by descending rho (interaction strength)
```

]

.pull-right[

<img src="images/h-statistic-2way.png" width="75%" height="75%" style="display: block; margin: auto;" />

]

---
# H-statistic: 2-way interaction

.scrollable[

.pull-left[

```r
interact_2way <- Interaction$new(iml_predictor_ensemble, feature = "OverTime")
plot(interact_2way)
```

<img src="slides-source_files/figure-html/h-statistic-2way-plot-1.png" style="display: block; margin: auto;" />

]

.pull-right[

<img src="slides-source_files/figure-html/pdp-2way-plot-1.png" style="display: block; margin: auto;" />

]

]

<br>

.center[.content-box-gray[The H-statistic can point you to which interaction PDPs to look at.]]

---
class: center, middle, inverse

# Local Interpretation

---
# Local Interpretation

.pull-left[

Are feature effects uniform over all observations?

* Feature effects
    - Individual conditional expectation curves (ICE)

]

.pull-right[

How do features influence individual predictions?

* Feature importance
    - Local interpretable model-agnostic explanations (LIME)
    - Shapley values
    - Breakdown

]

---
# ICE curves

.pull-left[

Are feature effects uniform over all observations?

* Feature effects
    - Individual conditional expectation curves (ICE)

<br>

.center[.blue[.font90[Same as PDPs, but rather than averaging the effect across all observations, we keep and plot each observation's predictions]]]

]

.pull-right[

<img src="slides-source_files/figure-html/example-ice-curve-1.png" style="display: block; margin: auto;" />

]

---
# ICE curves

.scrollable90[

.pull-left[

.bolder[pdp]

```r
# create custom predict function --> must return a vector of predictions
ice_fun <- function(object, newdata) {
  as.data.frame(predict(object, newdata = as.h2o(newdata)))[[3L]]
}

# individual conditional expectation values
pd_df <- partial(
  gbm,
  pred.var = "Age",
  train = features,
  pred.fun = ice_fun,
  grid.resolution = 20
)

# ICE plots
p1 <- autoplot(pd_df, alpha = 0.1) + ggtitle("Non-centered")
*p2 <- autoplot(pd_df, alpha = 0.1, center = TRUE) + ggtitle("Centered")
gridExtra::grid.arrange(p1, p2, ncol = 1)
```

]

.pull-right[

<img src="slides-source_files/figure-html/example-ice-curve-hidden-1.png" style="display: block; margin: auto;" />

]

]

---
# ICE curves

.scrollable90[

.pull-left[

.bolder[iml]

```r
iml_ice <- Partial$new(iml_predictor_gbm, "Age", ice = TRUE, grid.size = 20)
*iml_ice$center(min(features$Age))
plot(iml_ice)
```

<br><br>

```r
iml_ice <- Partial$new(iml_predictor_gbm, "OverTime", ice = TRUE, grid.size = 20)
plot(iml_ice)
```

]

.pull-right[

<img src="slides-source_files/figure-html/iml-ice-plot-output-1.png" style="display: block; margin: auto;" />

]

]

---
# Local interpretation

Given the following observations, how do features influence individual predictions?

```r
# predictions
predictions <- predict(gbm, df.h2o) %>% .[[3L]] %>% as.vector()

# highest and lowest probabilities
paste("Observation", which.max(predictions), "has", round(max(predictions), 2), "probability of attrition")
## [1] "Observation 128 has 0.97 probability of attrition"
paste("Observation", which.min(predictions), "has", round(min(predictions), 2), "probability of attrition")
## [1] "Observation 277 has 0.01 probability of attrition"

# get these observations
high_prob_ob <- df[which.max(predictions), ]
low_prob_ob  <- df[which.min(predictions), ]
```

---
# LIME

.pull-left[

How do features influence individual predictions?

* Feature importance
    - Local interpretable model-agnostic explanations (LIME)

LIME algorithm:

.font80[

1. __Permute__ your training data to create replicated feature data with the same distribution.
2. Compute a __similarity distance measure__ between the single observation of interest and the permuted observations.
3. Apply the selected machine learning model to __predict outcomes__ of the permuted data.
4. __Select m features__ that best describe the predicted outcomes.
5. __Fit a simple model__ to the permuted data, explaining the complex model outcome with the m features from the permuted data, weighted by similarity to the original observation.
6. Use the resulting __feature weights to explain local behavior__.

]

]

.pull-right[

<img src="images/lime-fitting-1.png" width="100%" height="100%" style="display: block; margin: auto;" />

]
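---
# LIME - the weighting step

To make steps 2 and 5 concrete, here is a minimal sketch (illustration only). `perms` and `perm_pred` are hypothetical stand-ins for the permuted observations from step 1 and their black-box predictions from step 3; __lime__ performs all of this (plus binning and feature selection) internally:

```r
library(gower)

# step 2: Gower distance between the observation of interest and each
# permuted observation, converted to kernel similarity weights
d <- gower_dist(x = high_prob_ob[, names(features)], y = perms)
w <- exp(-(d^2) / (0.75^2))  # exponential kernel, width = 0.75

# step 5: a weighted linear model explains the black box locally
local_fit <- lm(perm_pred ~ ., data = perms, weights = w)

# step 6: the coefficients are the local explanation
coef(local_fit)
```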
---
# LIME

.pull-left[

.bolder[lime]

```r
# create explainer object
lime_explainer <- lime(
  x = df[, names(features)],
  model = ensemble,
  n_bins = 5
)

# perform lime algorithm
lime_explanation <- lime::explain(
  x = high_prob_ob[, names(features)],
  explainer = lime_explainer,
  n_permutations = 5000,
  dist_fun = "gower",
  kernel_width = .75,
  n_features = 10,
  feature_select = "highest_weights",
  label = "p1"
)
```

]

.pull-right[

<br>

```r
plot_features(lime_explanation)
```

<img src="slides-source_files/figure-html/plot-lime-explanation-1.png" style="display: block; margin: auto;" />

]

---
# LIME

.pull-left[

.bolder[lime]

```r
# create explainer object
lime_explainer <- lime(
  x = df[, names(features)],
  model = ensemble,
* n_bins = 6
)

# perform lime algorithm
lime_explanation <- lime::explain(
  x = high_prob_ob[, names(features)],
  explainer = lime_explainer,
* n_permutations = 8000,
* dist_fun = "manhattan",
* kernel_width = 3,
* n_features = 10,
* feature_select = "lasso_path",
  label = "p1"
)
```

]

.pull-right[

<br>

```r
plot_features(lime_explanation)
```

<img src="slides-source_files/figure-html/plot-lime-explanation2-1.png" style="display: block; margin: auto;" />

]

---
# LIME

.bolder[iml]

.scrollable[

.pull-left[

```r
# fit local model to the high probability observation
lime <- LocalModel$new(
  predictor = iml_predictor_gbm,
  x.interest = high_prob_ob[, names(features)],
  dist.fun = "gower",
  kernel.width = NULL,
  k = 10
)
plot(lime)
```

```r
# reapply model to the low probability observation
*lime$explain(x.interest = low_prob_ob)
plot(lime)
```

]

.pull-right[

<img src="slides-source_files/figure-html/iml-lime-plots-1.png" style="display: block; margin: auto;" />

]

]

<br>

.center[.content-box-gray[__iml__'s implementation of LIME is easier to apply but not as robust as __lime__'s]]

---
# Shapley values

.pull-left[

How do features influence individual predictions?

* Feature importance
    - .opacity10[Local interpretable model-agnostic explanations (LIME)]
    - Shapley values
        - a method from coalitional game theory
        - tells us how to fairly distribute the "payout" among contributors
        - exact computation is infeasible for any normally-sized data set

]

.pull-right[

<img src="images/shapley-idea.png" width="30%" style="display: block; margin: auto;" />

]

---
# Shapley values - an approximation

<img src="images/approx-shapley-idea.png" width="95%" style="display: block; margin: auto;" />

---
# Shapley values - an approximation

.scrollable90[

.pull-left[

.bolder[iml]

- one of only a few implementations
- adjust `sample.size` for greater accuracy ( `\(\uparrow\)` ) or improved computational speed ( `\(\downarrow\)` )

```r
# compute shapley values
shapley <- Shapley$new(
  iml_predictor_gbm,
  x.interest = high_prob_ob[, names(features)],
* sample.size = 500
)
shapley
## Interpretation method:  Shapley 
## Predicted value: 0.968981, Average prediction: 0.161931 (diff = 0.807050)
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 1470 rows and 30 columns.
## 
## Head of results:
##                   feature          phi      phi.var                feature.value
## 1                     Age  0.045683646 4.968576e-03                       Age=19
## 2          BusinessTravel -0.006991512 1.596109e-03 BusinessTravel=Travel_Rarely
## 3               DailyRate  0.014440487 2.616942e-03                DailyRate=528
## 4              Department  0.003553409 7.527069e-05             Department=Sales
## 5        DistanceFromHome  0.049930408 4.935489e-03          DistanceFromHome=22
## 6               Education  0.008533806 1.132339e-03      Education=Below_College
```

]

.pull-right[

```r
# plot
plot(shapley)
```

<img src="slides-source_files/figure-html/iml-shapley-plot-1.png" style="display: block; margin: auto;" />

]

]
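---
# Shapley values - the sampling idea

A minimal sketch (illustration only, and far slower than __iml__'s batched implementation) of the Monte Carlo approximation behind `Shapley$new()`: sample a reference observation and a random feature ordering, then average the marginal contribution of the feature of interest. Reuses `features`, `pred()`, and `high_prob_ob`:

```r
shapley_approx <- function(model, x0, feature, data, m = 500) {
  phi <- numeric(m)
  for (i in seq_len(m)) {
    z <- data[sample(nrow(data), 1), ]           # random reference observation
    ord <- sample(names(data))                   # random feature ordering
    keep <- ord[seq_len(which(ord == feature))]  # feature + those ordered before it
    x_plus <- z
    x_plus[keep] <- x0[keep]          # x0's values up through `feature`
    x_minus <- x_plus
    x_minus[feature] <- z[feature]    # ...but `feature` taken from z
    phi[i] <- pred(model, x_plus) - pred(model, x_minus)
  }
  mean(phi)  # average marginal contribution approximates the Shapley value
}

shapley_approx(gbm, high_prob_ob[, names(features)], "OverTime", features)
```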
---
# Breakdown

.pull-left[

How do features influence individual predictions?

* Feature importance
    - .opacity10[Local interpretable model-agnostic explanations (LIME)]
    - .opacity10[Shapley values]
    - Breakdown
        - less popular than LIME & Shapley
        - two approaches: _step up_ & _step down_
        - computationally expensive

]

.pull-right[

<br>

<img src="images/breakdown-idea.png" width="110%" style="display: block; margin: auto;" />

]

---
# Breakdown

.scrollable90[

.pull-left[

.bolder[DALEX]

- uses the __breakDown__ package to compute
- the only implementation of this algorithm
- took __13 minutes__ to compute

```r
# compute breakdown values
high_prob_breakdown <- prediction_breakdown(
  dalex_explainer_gbm,
  observation = high_prob_ob[, names(features)],
* direction = "up"
)

# check out the top 10 influential variables for this observation
high_prob_breakdown[1:10, c(1, 2, 5)]
##                                             variable contribution
## 1                                        (Intercept)   0.00000000
## OverTime                            + OverTime = Yes   0.11089285
## MonthlyIncome                 + MonthlyIncome = 1675   0.14161426
## Age                                       + Age = 19   0.10594394
## YearsWithCurrManager      + YearsWithCurrManager = 0   0.10226719
## JobLevel                              + JobLevel = 1   0.09586214
## JobRole             + JobRole = Sales_Representative   0.08280089
## DailyRate                          + DailyRate = 528   0.05050095
## DistanceFromHome             + DistanceFromHome = 22   0.02784709
## WorkLifeBalance             + WorkLifeBalance = Good   0.01935530
##                      cummulative
## 1                      0.0000000
## OverTime               0.1108928
## MonthlyIncome          0.2525071
## Age                    0.3584510
## YearsWithCurrManager   0.4607182
## JobLevel               0.5565804
## JobRole                0.6393813
## DailyRate              0.6898822
## DistanceFromHome       0.7177293
## WorkLifeBalance        0.7370846
```

]

.pull-right[

```r
# plot
plot(high_prob_breakdown)
```

<img src="slides-source_files/figure-html/dalex-breakdown-plot-1.png" style="display: block; margin: auto;" />

]

]
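---
# Breakdown - the greedy "step up" idea

One plausible reading of the _step up_ approach (illustration only; __breakDown__'s actual implementation differs in its details): starting from the average prediction, greedily fix the variable that moves the mean prediction the most, record that move as its contribution, and repeat. The nested loop over variables is why the computation on the previous slide took 13 minutes.

```r
step_up_breakdown <- function(model, x0, data) {
  remaining <- names(data)
  current <- data  # all variables still free
  baseline <- mean(pred(model, current))
  contrib <- c()
  while (length(remaining) > 0) {
    # mean prediction after additionally fixing each candidate at x0's value
    means <- sapply(remaining, function(v) {
      tmp <- current
      tmp[[v]] <- x0[[v]]
      mean(pred(model, tmp))
    })
    best <- remaining[which.max(abs(means - baseline))]  # largest move wins
    contrib[best] <- means[best] - baseline
    current[[best]] <- x0[[best]]
    baseline <- means[best]
    remaining <- setdiff(remaining, best)
  }
  contrib  # per-variable contributions; they sum to f(x0) - mean(f(data))
}
```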
---
class: center, middle, inverse

# Summary of Solutions

---
# Packages & capabilities

<br>
<br>

<img src="images/summary.png" width="100%" style="display: block; margin: auto;" />

---
class: center, middle, inverse

# Concluding Remarks

---
# Final thoughts

.pull-left[

<br>

<img src="slides-source_files/figure-html/chuck-norris-1.gif" style="display: block; margin: auto;" />

]

.pull-right[

.bolder[Learn more]

* [Interpretable machine learning: A guide for making black box models explainable](https://christophm.github.io/interpretable-ml-book/)
* [An introduction to machine learning interpretability](https://www.safaribooksonline.com/library/view/an-introduction-to/9781492033158/)
* [H2O's machine learning interpretability resources](https://github.com/h2oai/mli-resources)
* [Patrick Hall's machine learning interpretability resources](https://github.com/jphall663/awesome-machine-learning-interpretability)
* [UC Business Analytics R Programming Guide](http://uc-r.github.io/)

]

---
# References

.scrollable[

Harrison Jr, D. and D. L. Rubinfeld (1978). "Hedonic housing prices and the demand for clean air". In: _Journal of Environmental Economics and Management_ 5.1, pp. 81-102.

]