Explaining Predictions: Random Forest Post-hoc Analysis (permutation & impurity variable importance)

Intro

Recap

There are 2 approaches to explaining models

  1. Use simple interpretable models. This approach was covered in the previous posts where we looked at logistic regression and decision trees as examples of white box models.
  2. Conduct post-hoc interpretation on models. There are two types of post-hoc analysis which can be done: model specific and model agnostic.

Direction of post

In the next few posts, we will look at model specific post-hoc analysis, which involves ranking the variables according to their importance to the model. Though these interpretations can be applied to both white box and black box models, we will be explaining black box models in these posts, as white box models can be explained by the models themselves. Specifically, we will explain random forest in this post and gradient boosting in future posts. As in the previous posts, the Cleveland heart dataset will be used, along with the principles of tidymodels.

#library
library(tidyverse)
library(tidymodels)
#import
heart<-read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data", col_names = F)
# Renaming var 
colnames(heart)<- c("age", "sex", "rest_cp", "rest_bp",
"chol", "fast_bloodsugar","rest_ecg","ex_maxHR","ex_cp",
"ex_STdepression_dur", "ex_STpeak","coloured_vessels", "thalassemia","heart_disease")
#elaborating cat var
##simple ifelse conversion 
heart<-heart %>% mutate(
  sex= ifelse(sex=="1", "male", "female"),
  fast_bloodsugar= ifelse(fast_bloodsugar=="1", ">120", "<120"),
  ex_cp=ifelse(ex_cp=="1", "yes", "no"),
  heart_disease=ifelse(heart_disease=="0", "no", "yes"))

heart<-heart %>% mutate(
  rest_cp=case_when(
    rest_cp=="1" ~ "typical", rest_cp=="2" ~ "atypical",
    rest_cp=="3" ~ "non-CP pain", rest_cp=="4" ~ "asymptomatic"),
  rest_ecg=case_when(
    rest_ecg=="0" ~ "normal", rest_ecg=="1" ~ "ST-T abnorm",
    rest_ecg=="2" ~ "LV hyperthrophy"),
  ex_STpeak=case_when(
    ex_STpeak=="1" ~ "up/norm", ex_STpeak=="2" ~ "flat", ex_STpeak=="3" ~ "down"),
  thalassemia=case_when(
    thalassemia=="3.0" ~ "norm", thalassemia=="6.0" ~ "fixed",
    thalassemia=="7.0" ~ "reversable"))
# convert missing value "?" into NA
heart<-heart %>% mutate_if(is.character, ~replace(., .=="?", NA))
# convert char into factors
heart<-heart %>% mutate_if(is.character, as.factor)
#train/test set 
set.seed(4595)
data_split <- initial_split(heart, prop=0.75, strata = "heart_disease")
heart_train <- training(data_split)
heart_test <- testing(data_split)
# create recipe object
heart_recipe<-recipe(heart_disease ~., data= heart_train) %>%
  step_knnimpute(all_predictors()) %>% 
  step_dummy(all_nominal(), -heart_disease) # need dummy variables to be created for some `randomForestExplainer` functions, though the random forest model itself doesn't need explicit one-hot encoding
# process the training set/ prepare recipe (non-cv)
heart_prep <-heart_recipe %>% prep(training = heart_train, retain = TRUE)
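
As an optional sanity check (not part of the original workflow), we can peek at the processed training set to confirm that step_dummy has expanded the nominal predictors into indicator columns:

juice(heart_prep) %>% glimpse()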

Random forest

Ensemble Models (tree-based)

Random forests and gradient boosting are both examples of tree-based ensemble models. In tree-based ensemble models, multiple variants of a simple tree are combined to develop a more complex model with higher predictive power. The trade-off of the ensemble is that the model becomes a black box and cannot be interpreted as simply as a single decision tree. Thus, post-hoc analysis is required to explain the ensemble model.

Random forest mechanism

Random forest is a step up from bootstrap aggregating (bagging), and bagging is a step up from a single decision tree. In bagging, multiple bootstrap samples of the training data are used to train multiple individual trees. As bootstrap samples are drawn with replacement, the same observation from the training set can be used more than once to train a tree. For regression problems, the ensemble occurs when the predictions of each tree are averaged to provide the predicted value for the bagging model. For classification problems, the ensemble occurs when the class probabilities of each tree are averaged to provide the predicted class; alternatively, the majority predicted class across the ensemble of trees becomes the predicted class for the model. Random forest is similar to bagging except that only a random subset of variables is considered at each split. The number of variables in this random subset is known as mtry. The default mtry for regression problems is a third of the total number of available variables in the dataset, while the default mtry for classification problems is the square root of the total number of available variables.
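
As a quick illustration of these defaults (a sketch only; the actual values are computed inside the randomForest package), mtry can be worked out from the number of predictors:

# illustrative only: default mtry values described above
p <- 13           # number of predictors in the heart data before dummy encoding
floor(p / 3)      # default mtry for regression problems (4)
floor(sqrt(p))    # default mtry for classification problems (3)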

Training random forest models

We will train a random forest model and extract two ranking measurements for feature importance from it. The two ranking measurements are:

  1. Permutation based. Permuting the values of a variable decouples any relationship between the predictor and the outcome, which effectively removes the variable's information from the model while keeping it present. When the out of bag sample is passed down the model with the permuted variable, the increase in error/ decrease in accuracy of the model can be calculated. This accuracy/error is compared against the baseline accuracy/error (for regression problems, the measurement will be R2) to calculate the importance score. This computation is repeated for all variables. The variable with the largest decrease in accuracy/ largest increase in error is considered the most important variable because losing that variable's information carries the largest penalty on prediction. In the case of the randomForest package, the score measures the decrease in accuracy and is referred to as MeanDecreaseAccuracy. A rough manual sketch of this idea is shown after this list.

  2. Impurity based. Impurity refers to the Gini impurity/ Gini index. The concept of impurity for random forest is the same as for a single decision tree. Features which are more important have a lower impurity score/ higher purity score/ higher decrease in impurity. The randomForest package adopts the latter score, which is known as MeanDecreaseGini.
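
To make the permutation idea concrete, here is a minimal, hypothetical sketch (the helper name and arguments are made up for illustration; the randomForest package itself computes this per tree on the out of bag samples): permute one predictor in a held-out dataset and record how much the accuracy drops.

# hypothetical helper to illustrate permutation importance; assumes `model` has a
# predict() method that returns class labels and `data` contains the outcome column
permutation_importance <- function(model, data, outcome, predictor) {
  baseline <- mean(predict(model, data) == data[[outcome]])   # baseline accuracy
  shuffled <- data
  shuffled[[predictor]] <- sample(shuffled[[predictor]])      # break predictor-outcome link
  permuted <- mean(predict(model, shuffled) == data[[outcome]])
  baseline - permuted                                         # decrease in accuracy = importance
}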

set.seed(69)
rf_model<-rand_forest(trees = 2000, mtry = 4, mode = "classification") %>%
  set_engine("randomForest",
             # importance = T to have permutation score calculated
             importance = T,
             # localImp = T for randomForestExplainer (next post)
             localImp = T) %>%
  fit(heart_disease ~ ., data = juice(heart_prep))

Feature Importance

We can extract the permutation-based importance as well as the Gini-based importance directly from the model trained by the randomForest package.

Permutation-based importance

Using the tidyverse approach to extract the results, remember to convert MeanDecreaseAccuracy from character to numeric form so that arrange sorts the variables correctly. Otherwise, R will order the values as text based on the first digit while ignoring the exponent in scientific notation. For instance, if MeanDecreaseAccuracy were in character format, rest_ecg_ST.T.abnorm would have the highest value, as R recognises the first digit in 4.1293934444743e-05 and fails to recognise that e-05 renders the value close to 0.

cbind(rownames(rf_model$fit$importance), rf_model$fit$importance) %>% as_tibble() %>% select(predictor= V1, MeanDecreaseAccuracy)   %>% 
 mutate(MeanDecreaseAccuracy=as.numeric(MeanDecreaseAccuracy)) %>% arrange(desc((MeanDecreaseAccuracy))) %>% head()
## # A tibble: 6 x 2
##   predictor              MeanDecreaseAccuracy
##   <chr>                                 <dbl>
## 1 ex_STdepression_dur                  0.0365
## 2 thalassemia_norm                     0.0332
## 3 thalassemia_reversable               0.0323
## 4 ex_maxHR                             0.0226
## 5 sex_male                             0.0111
## 6 coloured_vessels_X2.0                0.0109

As the model is trained using the randomForest package, we can use randomForest::importance to extract the results too. Use the argument type=1 to extract permutation-based importance and set scale=F to prevent the raw importance from being divided by its standard error. There are benefits to not scaling the permutation-based score.

cbind(rownames(randomForest::importance(rf_model$fit, type=1, scale=F)) %>% as_tibble(), randomForest::importance(rf_model$fit, type=1, scale=F) %>% as_tibble())%>% arrange(desc(MeanDecreaseAccuracy)) %>% head()
##                    value MeanDecreaseAccuracy
## 1    ex_STdepression_dur           0.03651391
## 2       thalassemia_norm           0.03321383
## 3 thalassemia_reversable           0.03231399
## 4               ex_maxHR           0.02256749
## 5               sex_male           0.01106189
## 6  coloured_vessels_X2.0           0.01087665

Impurity-based importance

Likewise, we can use both approaches to calculate impurity-based importance. Using the tidyverse way:

cbind(rownames(rf_model$fit$importance), rf_model$fit$importance) %>% as_tibble() %>% select(predictor= V1, MeanDecreaseGini)   %>% 
 mutate(MeanDecreaseGini=as.numeric(MeanDecreaseGini)) %>% arrange(desc((MeanDecreaseGini))) %>% head(10)
## # A tibble: 10 x 2
##    predictor              MeanDecreaseGini
##    <chr>                             <dbl>
##  1 ex_maxHR                          14.9 
##  2 ex_STdepression_dur               14.5 
##  3 thalassemia_norm                  11.1 
##  4 thalassemia_reversable             9.90
##  5 rest_bp                            8.84
##  6 age                                8.69
##  7 chol                               8.00
##  8 ex_cp_yes                          5.56
##  9 coloured_vessels_X2.0              3.77
## 10 rest_cp_typical                    3.34

For the randomForest::importance function, we change the argument type to 2. The scale argument only applies to permutation-based scores, so it is not relevant here.

cbind(rownames(randomForest::importance(rf_model$fit, type=2)) %>% as_tibble(), randomForest::importance(rf_model$fit, type=2) %>% as_tibble())%>% arrange(desc(MeanDecreaseGini)) %>% head(10)
##                     value MeanDecreaseGini
## 1                ex_maxHR        14.850652
## 2     ex_STdepression_dur        14.537241
## 3        thalassemia_norm        11.128131
## 4  thalassemia_reversable         9.902711
## 5                 rest_bp         8.838706
## 6                     age         8.686010
## 7                    chol         7.997923
## 8               ex_cp_yes         5.557652
## 9   coloured_vessels_X2.0         3.769974
## 10        rest_cp_typical         3.342365

Comments

The script for randomForest::importance is shorter, but users need to be familiar with the details of the package and its functions. On the other hand, the tidyverse approach is more verbose but friendlier to everyday R users.

The two measurements ranked some variables differently. This is expected, as each measurement uses a different method to calculate the score. There are some articles suggesting permutation-based importance as the preferred measurement of feature importance.

To sum up

In this post, we did a post-hoc analysis on a random forest model, using permutation-based and impurity-based variable importance rankings to understand the model. In the next post, we will continue the post-hoc analysis on the random forest model using other techniques.