Chapter 9 Decision Trees

Tree-based models are a class of nonparametric algorithms that work by partitioning the feature space into a number of smaller (non-overlapping) regions with similar response values using a set of splitting rules. Predictions are obtained by fitting a simpler model (e.g., a constant like the average response value) in each region. Such divide-and-conquer methods can produce simple rules that are easy to interpret and visualize with tree diagrams. As we’ll see, decision trees offer many benefits; however, they typically lack the predictive performance of more complex algorithms like neural networks and MARS. Later chapters will discuss powerful ensemble algorithms—like random forests and gradient boosting machines—that are constructed by combining many decision trees in a clever way. This chapter will provide you with a strong foundation in decision trees.

9.1 Prerequisites

In this chapter we’ll use the following packages:

# Helper packages
library(dplyr)       # for data wrangling
library(ggplot2)     # for awesome plotting

# Modeling packages
library(rpart)       # direct engine for decision tree application
library(caret)       # meta engine for decision tree application

# Model interpretability packages
library(rpart.plot)  # for plotting decision trees
library(vip)         # for feature importance
library(pdp)         # for feature effects

We’ll continue to illustrate the main concepts using the Ames housing example from Section 2.7.

9.2 Structure

There are many methodologies for constructing decision trees but the most well-known is the classification and regression tree (CART) algorithm proposed in Breiman (1984).26 A basic decision tree partitions the training data into homogeneous subgroups (i.e., groups with similar response values) and then fits a simple constant in each subgroup (e.g., the mean of the within-group response values for regression). The subgroups (also called nodes) are formed recursively using binary partitions formed by asking simple yes-or-no questions about each feature (e.g., is age < 18?). This is done a number of times until a suitable stopping criterion is satisfied (e.g., a maximum depth of the tree is reached). After all the partitioning has been done, the model predicts the output based on (1) the average response value of all observations that fall in that subgroup (regression problem), or (2) the class that has majority representation (classification problem). For classification, predicted probabilities can be obtained using the proportion of each class within the subgroup.

What results is an inverted tree-like structure such as that in Figure 9.1. In essence, our tree is a set of rules that allows us to make predictions by asking simple yes-or-no questions about each feature. For example, if the customer is loyal, has household income greater than $150,000, and is shopping in a store, the exemplar tree diagram in Figure 9.1 would predict that the customer will redeem a coupon.

Figure 9.1: Exemplar decision tree predicting whether or not a customer will redeem a coupon (yes or no) based on the customer’s loyalty, household income, last month’s spend, coupon placement, and shopping mode.

We refer to the first subgroup at the top of the tree as the root node (this node contains all of the training data). The final subgroups at the bottom of the tree are called the terminal nodes or leaves. Every subgroup in between is referred to as an internal node. The connections between nodes are called branches.

Figure 9.2: Terminology of a decision tree.

9.3 Partitioning

As illustrated above, CART uses binary recursive partitioning (it’s recursive because each split or rule depends on the splits above it). The objective at each node is to find the “best” feature (\(x_i\)) to partition the remaining data into one of two regions (\(R_1\) and \(R_2\)) such that the overall error between the actual response (\(y_i\)) and the predicted constant in each region (\(c_1\) and \(c_2\)) is minimized. For regression problems, the objective function to minimize is the total SSE as defined in Equation (9.1) below:

\[\begin{equation} \tag{9.1} SSE = \sum_{i \in R_1}\left(y_i - c_1\right)^2 + \sum_{i \in R_2}\left(y_i - c_2\right)^2 \end{equation}\]

For classification problems, the partitioning is usually made to maximize the reduction in cross-entropy or the Gini index (see Section 2.6).27

In both regression and classification trees, the objective of partitioning is to minimize dissimilarity in the terminal nodes. We suggest Therneau, Atkinson, and others (1997) for a more thorough discussion regarding binary recursive partitioning.

Having found the best feature/split combination, the data are partitioned into two regions and the splitting process is repeated on each of the two regions (hence the name binary recursive partitioning). This process is continued until a suitable stopping criterion is reached (e.g., a maximum depth is reached or the tree becomes “too complex”).
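
To make the split search concrete, below is a minimal sketch (not from the book; best_split, x, and y are hypothetical names) that finds the best cut point for a single numeric feature by exhaustively minimizing the SSE in Equation (9.1):

# Brute-force split search for one numeric feature: for every candidate cut
# point, predict the mean response in each region and record the total SSE
# from Equation (9.1)
best_split <- function(x, y) {
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2  # midpoints between unique values
  sse  <- sapply(cuts, function(cut_pt) {
    left  <- y[x < cut_pt]
    right <- y[x >= cut_pt]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(cut = cuts[which.min(sse)], sse = min(sse))
}

# Example: a step function plus noise; the chosen cut should land near x = 4
set.seed(123)
x <- runif(500, 0, 10)
y <- ifelse(x < 4, 1, 5) + rnorm(500, sd = 0.5)
best_split(x, y)

CART applies this same search to every feature at every node, keeping the feature/cut-point pair that yields the largest reduction in SSE.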

It’s important to note that a single feature can be used multiple times in a tree. For example, say we have data generated from a simple \(\sin\) function with Gaussian noise: \(Y_i \stackrel{iid}{\sim} N\left(\sin\left(X_i\right), \sigma^2\right)\), for \(i = 1, 2, \dots, 500\). A regression tree that makes only a single split (often referred to as a decision stump) leads to a split occurring at \(x = 3.1\).
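
The sketch below shows one way to simulate such data and fit a decision stump by restricting rpart() to a single level of depth; the seed and noise level are assumptions, so your split point and predicted values will differ slightly from those in Figure 9.3:

# Simulate y = sin(x) + Gaussian noise, then fit a decision stump
set.seed(1112)  # arbitrary seed; not necessarily the one used for Figure 9.3
sine_df <- data.frame(x = seq(0, 2 * pi, length.out = 500))
sine_df$y <- sin(sine_df$x) + rnorm(500, sd = 0.3)

# maxdepth = 1 restricts the tree to a single split (a decision stump)
sine_stump <- rpart(y ~ x, data = sine_df,
                    control = rpart.control(maxdepth = 1))
rpart.plot(sine_stump)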

Figure 9.3: Decision tree illustrating the single split on feature x (left). The resulting decision boundary illustrates the predicted value when x < 3.1 (0.64), and when x > 3.1 (-0.67) (right).

If we build a deeper tree, we’ll continue to split on the same feature (\(x\)) as illustrated in Figure 9.4. This is because \(x\) is the only feature available to split on so it will continue finding the optimal splits along this feature’s values until a pre-determined stopping criteria is reached.
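
Reusing sine_df from the previous sketch, allowing three levels of depth yields repeated splits on the same feature, similar to Figure 9.4 (the exact split points depend on the simulated data):

# Same single-feature data, but allow up to three levels of splits
sine_deep <- rpart(y ~ x, data = sine_df,
                   control = rpart.control(maxdepth = 3, cp = 0))
rpart.plot(sine_deep)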

Figure 9.4: Decision tree with depth = 3, resulting in 7 decision splits along values of feature x and 8 prediction regions (left). The resulting decision boundary (right).

However, even when many features are available, a single feature may still dominate if it continues to provide the best split after each successive partition. For example, a decision tree applied to the iris data set (Fisher 1936) where the species of the flower (setosa, versicolor, and virginica) is predicted based on two features (sepal width and sepal length) results in an optimal decision tree with two splits on each feature. Also, note how the decision boundary in a classification problem results in rectangular regions enclosing the observations. The predicted value is the response class with the greatest proportion within the enclosed region.
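
A sketch of the corresponding fit is shown below; the settings used to produce Figure 9.5 may differ slightly:

# Classification tree on two iris features; method = "class" for classification
iris_dt <- rpart(
  Species ~ Sepal.Length + Sepal.Width,
  data   = iris,
  method = "class"
)
rpart.plot(iris_dt)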

Figure 9.5: Decision tree for the iris classification problem (left). The decision boundary results in rectangular regions that enclose the observations. The class with the highest proportion in each region is the predicted value (right).

9.4 How deep?

This leads to an important question: how deep (i.e., complex) should we make the tree? If we grow an overly complex tree as in Figure 9.6, we tend to overfit to our training data resulting in poor generalization performance.

Figure 9.6: Overfit decision tree with 56 splits.

Consequently, there is a balance to be achieved in the depth and complexity of the tree to optimize predictive performance on future unseen data. To find this balance, we have two primary approaches: (1) early stopping and (2) pruning.

9.4.1 Early stopping

Early stopping explicitly restricts the growth of the tree. There are several ways we can restrict tree growth but two of the most common approaches are to restrict the tree depth to a certain level or to restrict the minimum number of observations allowed in any terminal node. When limiting tree depth we stop splitting after a certain depth (e.g., only grow a tree that has a depth of 5 levels). The shallower the tree, the less variance we have in our predictions; however, at some point we can start to inject too much bias as shallow trees (e.g., stumps) are not able to capture interactions and complex patterns in our data.

When restricting minimum terminal node size (e.g., leaf nodes must contain at least 10 observations for predictions) we are deciding to not split intermediate nodes which contain too few data points. At the far end of the spectrum, a minimum node size of one allows a single observation to be captured in a leaf node and used as a prediction (in this case, we’re interpolating the training data). This results in high variance and poor generalizability. On the other hand, large values restrict further splits, thereby reducing variance.
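
Both forms of early stopping are set through rpart.control(); the values below (and the object name ames_dt_stop) are purely illustrative, not recommendations:

# Early stopping: cap the tree depth and require a minimum terminal node size
ames_dt_stop <- rpart(
  formula = Sale_Price ~ .,
  data    = ames_train,
  method  = "anova",
  control = rpart.control(
    maxdepth  = 5,    # stop splitting once the tree is 5 levels deep
    minbucket = 10    # every terminal node must contain at least 10 observations
  )
)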

These two approaches can be implemented independently of one another; however, they do have interaction effects as illustrated by Figure 9.7.

Figure 9.7: Illustration of how early stopping affects the decision boundary of a regression decision tree. The columns illustrate how tree depth impacts the decision boundary and the rows illustrate how the minimum number of observations in the terminal node influences the decision boundary.

9.4.2 Pruning

An alternative to explicitly specifying the depth of a decision tree is to grow a very large, complex tree and then prune it back to find an optimal subtree. We find the optimal subtree by using a cost complexity parameter (\(\alpha\)) that penalizes our objective function in Equation (9.1) for the number of terminal nodes of the tree (\(T\)) as in Equation (9.2).

\[\begin{equation} \tag{9.2} \texttt{minimize} \left\{ SSE + \alpha \vert T \vert \right\} \end{equation}\]

For a given value of \(\alpha\) we find the smallest pruned tree that has the lowest penalized error. You may recognize the close association to the lasso penalty discussed in Chapter 6. As with the regularization methods, smaller penalties tend to produce more complex models, which result in larger trees, whereas larger penalties result in much smaller trees. Consequently, as a tree grows larger, the reduction in the SSE must be greater than the cost complexity penalty. Typically, we evaluate multiple models across a spectrum of \(\alpha\) and use CV to identify the optimal value and, therefore, the optimal subtree that generalizes best to unseen data.
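
As a rough sketch of this grow-then-prune workflow (train_df and the response y are placeholder names, not objects used elsewhere in this book), we can grow an unpenalized tree and then prune it back at the cp value with the lowest cross-validated error:

# Grow a full tree (no complexity penalty), then prune at the CV-optimal cp;
# rpart calls the cost complexity parameter alpha "cp"
full_tree <- rpart(y ~ ., data = train_df, method = "anova",
                   control = rpart.control(cp = 0, xval = 10))

# cptable stores the cross-validated error (xerror) for each candidate cp
best_cp <- full_tree$cptable[which.min(full_tree$cptable[, "xerror"]), "CP"]
pruned  <- prune(full_tree, cp = best_cp)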

Figure 9.8: To prune a tree, we grow an overly complex tree (left) and then use a cost complexity parameter to identify the optimal subtree (right).

9.5 Ames housing example

We can fit a regression tree using rpart and then visualize it using rpart.plot. The fitting process and the visual output of regression trees and classification trees are very similar. Both use the formula method for expressing the model (similar to lm()). However, when fitting a regression tree, we need to set method = "anova". By default, rpart() will make an intelligent guess as to what method to use based on the data type of your response column, but it’s good practice to set this explicitly.

ames_dt1 <- rpart(
  formula = Sale_Price ~ .,
  data    = ames_train,
  method  = "anova"
)

Once we’ve fit our model we can take a peek at the decision tree output. This prints various information about the different splits. For example, we start with 2054 observations at the root node and the first variable we split on (i.e., the first variable that gave the largest reduction in SSE) is Overall_Qual. We see that at the first split all observations with Overall_Qual \(\in\) \(\{\)Very_Poor, Poor, Fair, Below_Average, Average, Above_Average, Good\(\}\) go to the 2nd (2)) branch. The total number of observations that follow this branch (1708), their average sale price (156195) and SSE (3.964e+12) are listed. If you look for the 3rd branch (3)) you will see that 346 observations with Overall_Qual \(\in\) \(\{\)Very_Good, Excellent, Very_Excellent\(\}\) follow this branch; their average sale price is 304593 and the SSE in this region is 1.036e+12. Basically, this is telling us that Overall_Qual is an important predictor of sale price, with homes on the upper end of the quality spectrum having almost double the average sale price.

ames_dt1
## n= 2054 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 2054 13216450000000 181192.80  
##    2) Overall_Qual=Very_Poor,Poor,Fair,Below_Average,Average,Above_Average,Good 1708  3963616000000 156194.90  
##      4) Neighborhood=North_Ames,Old_Town,Edwards,Sawyer,Mitchell,Brookside,Iowa_DOT_and_Rail_Road,South_and_West_of_Iowa_State_University,Meadow_Village,Briardale,Northpark_Villa,Blueste,Landmark 1022  1251428000000 131978.70  
##        8) Overall_Qual=Very_Poor,Poor,Fair,Below_Average 195   167094500000  98535.99 *
##        9) Overall_Qual=Average,Above_Average,Good 827   814819400000 139864.20  
##         18) First_Flr_SF< 1214.5 631   383938300000 132177.10 *
##         19) First_Flr_SF>=1214.5 196   273557300000 164611.70 *
##      5) Neighborhood=College_Creek,Somerset,Northridge_Heights,Gilbert,Northwest_Ames,Sawyer_West,Crawford,Timberland,Northridge,Stone_Brook,Clear_Creek,Bloomington_Heights,Veenker,Green_Hills 686  1219988000000 192272.10  
##       10) Gr_Liv_Area< 1725 492   517806100000 177796.00  
##         20) Total_Bsmt_SF< 1334.5 353   233343200000 166929.30 *
##         21) Total_Bsmt_SF>=1334.5 139   136919100000 205392.70 *
##       11) Gr_Liv_Area>=1725 194   337602800000 228984.70 *
##    3) Overall_Qual=Very_Good,Excellent,Very_Excellent 346  2916752000000 304593.10  
##      6) Overall_Qual=Very_Good 249   955363000000 272321.20  
##       12) Gr_Liv_Area< 1969 152   313458900000 244124.20 *
##       13) Gr_Liv_Area>=1969 97   331677500000 316506.30 *
##      7) Overall_Qual=Excellent,Very_Excellent 97  1036369000000 387435.20  
##       14) Total_Bsmt_SF< 1903 65   231940700000 349010.80 *
##       15) Total_Bsmt_SF>=1903 32   513524700000 465484.70  
##         30) Year_Built>=2003.5 25   270259300000 429760.40 *
##         31) Year_Built< 2003.5 7    97411210000 593071.40 *

We can visualize our tree model with rpart.plot(). The rpart.plot() function has many plotting options, which we’ll leave to the reader to explore. However, in the default print it will show the percentage of data that fall in each node and the predicted outcome for that node. One thing you may notice is that this tree contains 10 internal nodes resulting in 11 terminal nodes. In other words, this tree uses only 10 splits, drawing on a small subset of the 80 available features in the training data. Why is that?

rpart.plot(ames_dt1)

Figure 9.9: Diagram displaying the pruned decision tree for the Ames Housing data.

Behind the scenes rpart() is automatically applying a range of cost complexity values (\(\alpha\)) to prune the tree. To compare the error for each \(\alpha\) value, rpart() performs 10-fold CV (by default). In this example we find diminishing returns after 12 terminal nodes as illustrated in Figure 9.10 (the \(y\)-axis is the CV error, the lower \(x\)-axis is the cost complexity (\(\alpha\)) value, and the upper \(x\)-axis is the number of terminal nodes, i.e., tree size = \(\vert T \vert\)). You may also notice the dashed line which goes through the point \(\vert T \vert = 8\). Breiman (1984) suggested that in actual practice, it’s common to instead use the smallest tree within 1 standard error (SE) of the minimum CV error (this is called the 1-SE rule). Thus, we could use a tree with 8 terminal nodes and reasonably expect to experience similar results within a small margin of error.

plotcp(ames_dt1)

Figure 9.10: Pruning complexity parameter (cp) plot illustrating the relative cross validation error (y-axis) for various cp values (lower x-axis). Smaller cp values lead to larger trees (upper x-axis). Using the 1-SE rule, a tree size of 10-12 provides optimal cross validation results.

To illustrate the point of selecting a tree with 11 terminal nodes (or 8 if you go by the 1-SE rule), we can force rpart() to generate a full tree by setting cp = 0 (no penalty results in a fully grown tree). Figure 9.11 shows that after 11 terminal nodes, we see diminishing returns in error reduction as the tree grows deeper. Thus, we can significantly prune our tree and still achieve minimal expected error.

ames_dt2 <- rpart(
    formula = Sale_Price ~ .,
    data    = ames_train,
    method  = "anova", 
    control = list(cp = 0, xval = 10)
)

plotcp(ames_dt2)
abline(v = 11, lty = "dashed")

Figure 9.11: Pruning complexity parameter plot for a fully grown tree. Significant reduction in the cross validation error is achieved with tree sizes 6-20 and then the cross validation error levels off with minimal or no additional improvements.

So, by default, rpart() is performing some automated tuning, with an optimal subtree of 10 total splits, 11 terminal nodes, and a cross-validated error of 0.292 (the xerror value, reported relative to the error of the root node). Although rpart() does not directly provide the RMSE or other metrics, you can use caret. In both cases, smaller penalties (deeper trees) are providing better CV results.

# rpart cross validation results
ames_dt1$cptable
##            CP nsplit rel error    xerror       xstd
## 1  0.47940879      0 1.0000000 1.0014737 0.06120398
## 2  0.11290476      1 0.5205912 0.5226036 0.03199501
## 3  0.06999005      2 0.4076864 0.4098819 0.03111581
## 4  0.02758522      3 0.3376964 0.3572726 0.02222507
## 5  0.02347276      4 0.3101112 0.3339952 0.02184348
## 6  0.02201070      5 0.2866384 0.3301630 0.02446178
## 7  0.02039233      6 0.2646277 0.3244948 0.02421833
## 8  0.01190364      7 0.2442354 0.3062031 0.02641595
## 9  0.01116365      8 0.2323317 0.3025968 0.02708786
## 10 0.01103581      9 0.2211681 0.2971663 0.02704837
## 11 0.01000000     10 0.2101323 0.2920442 0.02704791
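
If you want to apply the 1-SE rule yourself, one possible approach (a sketch, not the book’s code) is to work directly off the cptable:

# 1-SE rule: take the largest cp whose CV error is within one standard error
# of the minimum, then prune to that cp
cp_tbl <- as.data.frame(ames_dt1$cptable)
best   <- which.min(cp_tbl$xerror)
thresh <- cp_tbl$xerror[best] + cp_tbl$xstd[best]
cp_1se <- cp_tbl$CP[min(which(cp_tbl$xerror <= thresh))]  # rows are ordered by decreasing cp
ames_dt1_pruned <- prune(ames_dt1, cp = cp_1se)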

# caret cross validation results
ames_dt3 <- train(
  Sale_Price ~ .,
  data = ames_train,
  method = "rpart",
  trControl = trainControl(method = "cv", number = 10),
  tuneLength = 20
)

ggplot(ames_dt3)

Figure 9.12: Cross-validated RMSE for the 20 different \(\alpha\) parameter values in our grid search. Lower \(\alpha\) values (deeper trees) help to minimize errors.

9.6 Feature interpretation

To measure feature importance, the reduction in the loss function (e.g., SSE) attributed to each variable at each split is tabulated. In some instances, a single variable may be used multiple times in a tree; consequently, the total reduction in the loss function across all splits by a variable is summed up and used as the total feature importance. When using caret, these values are standardized so that the most important feature has a value of 100 and the remaining features are scored based on their relative reduction in the loss function. Also, since there may be candidate variables that are important but are not used in a split, the top competing variables are also tabulated at each split.

Figure 9.13 illustrates the top 40 features in the Ames housing decision tree. Similar to MARS (Chapter 7), decision trees perform automated feature selection where uninformative features are not used in the model. We can see this in Figure 9.13 where the bottom four features in the plot have zero importance.

vip(ames_dt3, num_features = 40, bar = FALSE)

Figure 9.13: Variable importance based on the total reduction in MSE for the Ames Housing decision tree.

If we look at the same partial dependence plots that we created for the MARS models (Section 7.5), we can see the similarity in how decision trees model the relationship between the features and target. In Figure 9.14, we see that Gr_Liv_Area has a non-linear relationship with the predicted sale price: it has increasingly stronger effects for Gr_Liv_Area values between 1000–2500 but then has little, if any, influence when it exceeds 2500. However, the 3-D plot of the interaction effect between Gr_Liv_Area and Year_Built illustrates a key difference: decision trees have rigid, non-smooth prediction surfaces compared to MARS; in fact, MARS was developed as an improvement to CART for regression problems.

# Construct partial dependence plots
p1 <- partial(ames_dt3, pred.var = "Gr_Liv_Area") %>% autoplot()
p2 <- partial(ames_dt3, pred.var = "Year_Built") %>% autoplot()
p3 <- partial(ames_dt3, pred.var = c("Gr_Liv_Area", "Year_Built")) %>% 
  plotPartial(levelplot = FALSE, zlab = "yhat", drape = TRUE, 
              colorkey = TRUE, screen = list(z = -20, x = -60))

# Display plots side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)

Figure 9.14: Partial dependence plots to understand the relationship between sale price and the living space, and year built features.

9.7 Final thoughts

Decision trees have a number of advantages. Trees require very little pre-processing. This is not to say feature engineering may not improve upon a decision tree, but rather, that there are no pre-processing requirements. Monotonic transformations (e.g., \(\log\), \(\exp\), and \(\sqrt{}\)) are not required to meet algorithm assumptions as in many parametric models; instead, they only shift the location of the optimal split points. Outliers typically do not bias the results as much since the binary partitioning simply looks for a single location to make a split within the distribution of each feature.
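
As a small illustration with simulated data (a sketch; the object names are hypothetical), a tree fit to a feature and a tree fit to its log transform partition the observations identically, so the fitted values match:

# Monotonic transformations only shift the split locations; the resulting
# partitions (and therefore predictions) are the same
set.seed(42)
x <- runif(300, 1, 100)
y <- ifelse(x < 30, 10, 20) + rnorm(300)
fit_raw <- rpart(y ~ x, data = data.frame(x = x, y = y))
fit_log <- rpart(y ~ x, data = data.frame(x = log(x), y = y))
all.equal(predict(fit_raw), predict(fit_log))  # TRUE: identical fitted values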

Decision trees can easily handle categorical features without preprocessing. For unordered categorical features with more than two levels, the classes are ordered based on the outcome (for regression problems, the mean of the response is used and for classification problems, the proportion of the positive outcome class is used). For more details see J. Friedman, Hastie, and Tibshirani (2001), Breiman and Ihaka (1984), Ripley (2007), Fisher (1958), and Loh and Vanichsetakul (1988).

Missing values often cause problems with statistical models and analyses. Most procedures deal with them by refusing to deal with them—incomplete observations are tossed out. However, most decision tree implementations can easily handle missing values in the features and do not require imputation. This is handled in various ways but most commonly by creating a new “missing” class for categorical variables or using surrogate splits (see Therneau, Atkinson, and others (1997) for details).
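
As a small illustration with simulated data (again a sketch; surrogate behavior is governed by the usesurrogate and maxsurrogate options in rpart.control()), rpart() fits directly on a predictor containing NAs and reports surrogate splits based on a correlated feature:

# x2 is highly correlated with x1, so it can serve as a surrogate when x1 is missing
set.seed(123)
x1 <- rnorm(200)
x2 <- x1 + rnorm(200, sd = 0.3)
y  <- 2 * x1 + rnorm(200, sd = 0.25)
df <- data.frame(y, x1, x2)
df$x1[sample(200, 40)] <- NA   # inject 20% missing values into x1

# No imputation required; rows with missing x1 follow surrogate splits on x2
fit_na <- rpart(y ~ x1 + x2, data = df, method = "anova")
summary(fit_na)  # surrogate splits are listed beneath each primary split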

However, individual decision trees generally do not achieve state-of-the-art predictive accuracy. In this chapter, we saw that the best pruned decision tree, although it performed better than linear regression (Chapter 4), had a very poor RMSE ($41,019) compared to some of the other models we’ve built. This is driven by the fact that decision trees are composed of simple yes-or-no rules that create rigid, non-smooth decision boundaries. Furthermore, we saw that deep trees tend to have high variance (and low bias) while shallow trees tend to be overly biased (but have low variance). In the chapters that follow, we’ll see how we can combine multiple trees together into very powerful prediction models called ensembles.

References

Breiman, Leo. 1984. Classification and Regression Trees. Routledge.

Breiman, Leo, and Ross Ihaka. 1984. Nonlinear Discriminant Analysis via Scaling and Ace. Department of Statistics, University of California.

Fisher, Ronald A. 1936. “The Use of Multiple Measurements in Taxonomic Problems.” Annals of Eugenics 7 (2). Wiley Online Library: 179–88.

Fisher, Walter D. 1958. “On Grouping for Maximum Homogeneity.” Journal of the American Statistical Association 53 (284). Taylor & Francis: 789–98.

Friedman, Jerome, Trevor Hastie, and Robert Tibshirani. 2001. The Elements of Statistical Learning. Vol. 1. Springer Series in Statistics. New York, NY: Springer.

Loh, Wei-Yin, and Nunta Vanichsetakul. 1988. “Tree-Structured Classification via Generalized Discriminant Analysis.” Journal of the American Statistical Association 83 (403). Taylor & Francis Group: 715–25.

Ripley, Brian D. 2007. Pattern Recognition and Neural Networks. Cambridge University Press.

Therneau, Terry M, Elizabeth J Atkinson, and others. 1997. “An Introduction to Recursive Partitioning Using the RPART Routines.” Mayo Foundation.


  26. Other decision tree algorithms include the Iterative Dichotomiser 3 (Quinlan 1986), C4.5 (Quinlan and others 1996), Chi-square automatic interaction detection (Kass 1980), Conditional inference trees (Hothorn, Hornik, and Zeileis 2006), and more.

  27. Gini index and cross-entropy are the two most commonly applied loss functions used for decision trees. Classification error is rarely used to determine partitions as it is less sensitive to poorly performing splits (J. Friedman, Hastie, and Tibshirani 2001).