In this lab we will go through the model building, validation, and interpretation of tree models. The focus will be on the rpart package.
CART stands for classification and regression trees:
- For the regression tree example, we will use the Boston Housing data. Recall that the response variable is the housing price.
- For the classification tree example, we will use the credit scoring data. The response variable is whether the loan defaulted.
Note that unlike logistic regression, the response variable of a classification tree does not have to be binary. We can use classification trees on classification problems with more than 2 outcomes.
Let us load the data set. Note that a randomly sampled training/test split will give slightly different results each time the code is run.
library(MASS) #this data is in MASS package
data(Boston) # load the Boston housing data frame
sample_index <- sample(nrow(Boston),nrow(Boston)*0.90)
boston_train <- Boston[sample_index,]
boston_test <- Boston[-sample_index,]
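If you want the split (and therefore the results below) to be reproducible, you can fix the random seed before sampling. A minimal sketch; the seed value here is arbitrary:
set.seed(2023) # any fixed seed makes the split reproducible
sample_index <- sample(nrow(Boston), nrow(Boston)*0.90)
boston_train <- Boston[sample_index,]
boston_test <- Boston[-sample_index,]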
We will use the ‘rpart’ library for model building and ‘rpart.plot’ for plotting.
install.packages('rpart')
install.packages('rpart.plot')
library(rpart)
library(rpart.plot)
The simple form of the rpart function is similar to lm and glm. It takes a formula argument in which you specify the response and predictor variables, and a data argument in which you specify the data frame.
boston_rpart <- rpart(formula = medv ~ ., data = boston_train)
boston_rpart
## n= 455
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 455 38652.3900 22.44593
## 2) rm< 6.9715 393 16978.1700 19.96387
## 4) lstat>=14.4 159 3089.0480 14.80314
## 8) crim>=6.99237 70 987.7749 11.79143 *
## 9) crim< 6.99237 89 966.9598 17.17191 *
## 5) lstat< 14.4 234 6777.0670 23.47051
## 10) lstat>=5.41 204 3722.1120 22.43627
## 20) lstat>=9.725 102 611.0287 20.65784 *
## 21) lstat< 9.725 102 2465.8680 24.21471
## 42) nox< 0.589 95 1054.1060 23.58105 *
## 43) nox>=0.589 7 855.9486 32.81429 *
## 11) lstat< 5.41 30 1352.9300 30.50333
## 22) tax< 364 23 295.2365 28.05652 *
## 23) tax>=364 7 467.5571 38.54286 *
## 3) rm>=6.9715 62 3906.2030 38.17903
## 6) rm< 7.437 35 696.3760 33.02000 *
## 7) rm>=7.437 27 1070.7200 44.86667 *
prp(boston_rpart,digits = 4, extra = 1)
Make sure you know how to interpret this tree model!
Exercise: What is the predicted median housing price (in thousands of dollars) given the following information?
crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat | medv |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0.09 | 0 | 25.65 | 0 | 0.58 | 5.96 | 92.9 | 2.09 | 2 | 188 | 19.1 | 378.09 | 17.93 | 20.5 |
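One way to check your answer (a sketch) is to put these values into a one-row data frame, leaving out the response medv, and pass it to predict():
# hypothetical new observation built from the table above (medv omitted)
new_obs <- data.frame(crim = 0.09, zn = 0, indus = 25.65, chas = 0, nox = 0.58,
                      rm = 5.96, age = 92.9, dis = 2.09, rad = 2, tax = 188,
                      ptratio = 19.1, black = 378.09, lstat = 17.93)
predict(boston_rpart, newdata = new_obs)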
In-sample and out-of-sample prediction for regression trees works in the same way as for lm and glm models.
boston_train_pred_tree = predict(boston_rpart)
boston_test_pred_tree = predict(boston_rpart,boston_test)
When the sample size is large, we often refer to the MSE (computed on the training set) as the training error, and the MSPE (computed on the test set) as the testing error.
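For reference, the definitions used in this lab are the standard ones, where \(\hat{y}_i\) denotes the model's prediction for observation \(i\):
\[
\text{MSE} = \frac{1}{n_{\text{train}}} \sum_{i \in \text{train}} (y_i - \hat{y}_i)^2,
\qquad
\text{MSPE} = \frac{1}{n_{\text{test}}} \sum_{i \in \text{test}} (y_i - \hat{y}_i)^2 .
\]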
Exercise: Calculate the mean squared error (MSE) and the mean squared prediction error (MSPE) for this tree model.
MSE.tree<-
MSPE.tree <-
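If you want to check your work, a minimal sketch of one way to fill these in, using the predictions computed above:
MSE.tree <- mean((boston_train$medv - boston_train_pred_tree)^2) # training error
MSPE.tree <- mean((boston_test$medv - boston_test_pred_tree)^2) # testing error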
We can compare this model’s out-of-sample performance with the linear regression model with all variables in it.
boston.reg = lm(medv~., data = boston_train)
boston_test_pred_reg = predict(boston.reg, boston_test)
mean((boston_test_pred_reg - boston_test$medv)^2)
## [1] 21.7341
Calculate the mean squared error (MSE) and mean squared prediction error (MSPE) for the linear regression model using all variables, then compare the results with the tree model. What is your conclusion? Further, try to compare the regression tree with the best linear regression model obtained from a variable selection procedure.
boston_lm<-
boston_train_pred_lm<-
boston_test_pred_lm<-
MSE_lm<-
MSPE_lm<-
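Again, a sketch of one possible way to complete the comparison, using the full linear model (substitute your selected model as needed):
boston_lm <- lm(medv ~ ., data = boston_train)
boston_train_pred_lm <- predict(boston_lm) # fitted values on the training set
boston_test_pred_lm <- predict(boston_lm, boston_test)
MSE_lm <- mean((boston_train$medv - boston_train_pred_lm)^2)
MSPE_lm <- mean((boston_test$medv - boston_test_pred_lm)^2)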
In rpart(), the cp (complexity parameter) argument is one of the parameters used to control the complexity of the tree. The help document for rpart says that "any split that does not decrease the overall lack of fit by a factor of cp is not attempted". For a regression tree, this means the overall R-squared must increase by cp at each step. Basically, the smaller the cp value, the larger (more complex) the tree rpart will attempt to fit. The default value of cp is 0.01.
What happens when you have a large tree? The following tree has 27 splits.
boston_largetree <- rpart(formula = medv ~ ., data = boston_train, cp = 0.001)
Try plotting it yourself to see its structure.
prp(boston_largetree)
The plotcp() function shows the relationship between the 10-fold cross-validation error (computed on the training set) and the size of the tree.
plotcp(boston_largetree)
You can observe from the above graph that the cross-validation error (x-val) does not always go down as the tree becomes more complex. The analogy is that adding more variables to a regression model does not necessarily improve its ability to predict future observations.
A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line. In the Boston housing example, you may conclude that a tree model with more than 10 splits is not helpful.
To examine the error versus the size of the tree more carefully, you can look at the following table:
printcp(boston_largetree)
##
## Regression tree:
## rpart(formula = medv ~ ., data = boston_train, cp = 0.001)
##
## Variables actually used in tree construction:
## [1] age crim dis lstat nox ptratio rm tax
##
## Root node error: 38652/455 = 84.95
##
## n= 455
##
## CP nsplit rel error xerror xstd
## 1 0.4596875 0 1.00000 1.00268 0.088248
## 2 0.1840003 1 0.54031 0.61664 0.060139
## 3 0.0553422 2 0.35631 0.44374 0.049993
## 4 0.0440342 3 0.30097 0.37817 0.047625
## 5 0.0293465 4 0.25694 0.36334 0.048169
## 6 0.0166928 5 0.22759 0.31131 0.044247
## 7 0.0152678 6 0.21090 0.28782 0.041987
## 8 0.0143798 7 0.19563 0.28442 0.041996
## 9 0.0087457 8 0.18125 0.27679 0.041338
## 10 0.0079478 9 0.17250 0.26690 0.041899
## 11 0.0061900 10 0.16456 0.26761 0.041865
## 12 0.0059678 11 0.15837 0.26538 0.041927
## 13 0.0049986 12 0.15240 0.26441 0.041942
## 14 0.0042142 13 0.14740 0.25229 0.041068
## 15 0.0038490 14 0.14318 0.24916 0.040708
## 16 0.0034799 15 0.13934 0.24782 0.040654
## 17 0.0032086 16 0.13586 0.23767 0.037795
## 18 0.0021502 17 0.13265 0.23588 0.037810
## 19 0.0020504 18 0.13050 0.23732 0.037840
## 20 0.0017577 19 0.12845 0.23777 0.037846
## 21 0.0014051 21 0.12493 0.23877 0.037922
## 22 0.0012637 22 0.12353 0.24300 0.037957
## 23 0.0012431 23 0.12226 0.24211 0.037956
## 24 0.0011411 24 0.12102 0.24185 0.037940
## 25 0.0010000 27 0.11760 0.24218 0.037946
The root node error is the error you would make without using any of the predictors; in the regression case, it is the average (mean) squared error (MSE) obtained by using the average of medv as the prediction for every observation. Note that it is the same as
sum((boston_train$medv - mean(boston_train$medv))^2)/nrow(boston_train)
## [1] 84.95031
The first two columns, CP and nsplit, tell you how large the tree is. rel error \(\times\) root node error gives you the in-sample error.
xerror gives you the cross-validation error (10-fold by default). You can see that the rel error (in-sample error) always decreases as the model becomes more complex, while the cross-validation error (a measure of performance on future observations) does not. That is why we prune the tree: to avoid overfitting the training data.
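For example, from the last row of the table above, the in-sample MSE of the 27-split tree is roughly \(0.1176 \times 84.95 \approx 9.99\). The short sketch below shows one simple way to pick a candidate cp programmatically from the cp table stored in the fitted object (taking the cp with the smallest xerror; the plotcp rule described earlier may favor a slightly simpler tree):
cp_table <- boston_largetree$cptable # the same numbers printcp() displays
best_cp <- cp_table[which.min(cp_table[, "xerror"]), "CP"]
best_cp # candidate cp value for pruning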
To avoid fitting an unnecessarily large tree, rpart() uses some default control parameters, mainly to save computation time. For example, by default rpart sets cp = 0.01 and requires at least 20 observations in a node before attempting a split. Use ?rpart.control to view these parameters. Sometimes we wish to change these parameters to see how more complex trees perform, as we did above. If we have a larger than necessary tree, we can use the prune() function and specify a new cp:
prune(boston_largetree, cp = 0.008)
## n= 455
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 455 38652.3900 22.44593
## 2) rm< 6.9715 393 16978.1700 19.96387
## 4) lstat>=14.4 159 3089.0480 14.80314
## 8) crim>=6.99237 70 987.7749 11.79143 *
## 9) crim< 6.99237 89 966.9598 17.17191 *
## 5) lstat< 14.4 234 6777.0670 23.47051
## 10) lstat>=5.41 204 3722.1120 22.43627
## 20) lstat>=9.725 102 611.0287 20.65784 *
## 21) lstat< 9.725 102 2465.8680 24.21471
## 42) nox< 0.589 95 1054.1060 23.58105 *
## 43) nox>=0.589 7 855.9486 32.81429 *
## 11) lstat< 5.41 30 1352.9300 30.50333
## 22) tax< 364 23 295.2365 28.05652 *
## 23) tax>=364 7 467.5571 38.54286 *
## 3) rm>=6.9715 62 3906.2030 38.17903
## 6) rm< 7.437 35 696.3760 33.02000 *
## 7) rm>=7.437 27 1070.7200 44.86667
## 14) ptratio>=17.6 7 465.9686 38.88571 *
## 15) ptratio< 17.6 20 266.7080 46.96000 *
Exercise: Prune a classification tree. Start with cp = 0.001, find a reasonable cp value, and then obtain the pruned tree.
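A possible starting point, given as a sketch only: it assumes the credit scoring data has already been loaded into a data frame named credit_train with a binary response named default (placeholder names; replace them with the ones in your data).
credit_largetree <- rpart(default ~ ., data = credit_train,
                          method = "class", cp = 0.001) # placeholder data/variable names
plotcp(credit_largetree) # inspect cross-validation error versus cp
printcp(credit_largetree) # pick a reasonable cp from the table
credit_pruned <- prune(credit_largetree, cp = 0.01) # replace 0.01 with your chosen value
prp(credit_pruned, extra = 1)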
Some software packages can prune the tree automatically.