Tree-based methods can be applied to both regression and classification problems. This article is about regression trees and how to prune them using cost complexity pruning.
What are regression trees?
Though the tree in this animation performs a classification task, it helps visualize how a decision tree (on the left) divides the predictor space into 4 non-overlapping regions. These 4 regions (on the right) are called the ‘terminal nodes’ or ‘leaves’ of the tree. The first point at which the tree splits the predictor space is called the ‘root node’, and every subsequent splitting point is called an ‘internal node’. Hence, this tree has one root node, two internal nodes, and four terminal nodes.
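To make the terminology concrete, here is a minimal Python sketch of a tree with the same shape as the one described above, stored as nested dictionaries. The dictionary layout, predictor indices, and cutpoints are hypothetical and not taken from the animation. Counting the nodes recovers one root node, two internal nodes, and four terminal nodes. In any binary tree, the number of splitting points is one less than the number of leaves.

```python
# Hypothetical tree: splitting nodes store a predictor index k and a cutpoint s;
# terminal nodes (leaves) just name the region they represent.
tree = {
    "k": 0, "s": 5.0,                                   # root node
    "left": {"k": 1, "s": 2.0,                          # internal node
             "left": {"region": "R1"}, "right": {"region": "R2"}},
    "right": {"k": 1, "s": 7.0,                         # internal node
              "left": {"region": "R3"}, "right": {"region": "R4"}},
}

def is_leaf(node):
    return "left" not in node

def count_splits(node):
    """Number of splitting points (the root node plus all internal nodes)."""
    return 0 if is_leaf(node) else 1 + count_splits(node["left"]) + count_splits(node["right"])

def count_leaves(node):
    """Number of terminal nodes, i.e. regions of the predictor space."""
    return 1 if is_leaf(node) else count_leaves(node["left"]) + count_leaves(node["right"])

splits = count_splits(tree)
print("root:", 1, "internal:", splits - 1, "terminal:", count_leaves(tree))
# prints: root: 1 internal: 2 terminal: 4
```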
Decision trees where the target variable takes continuous values are called regression trees.
The end goal of regression trees is no different from that of other machine learning models, i.e., to minimize prediction error. Specifically, the goal for regression trees is to minimize the RSS (residual sum of squares). In other words, we want to divide the predictor space into regions R1 to RK such that we minimize

$$\sum_{j=1}^{K} \sum_{i:\, x_i \in R_j} \left(y_i - \hat{y}_{R_j}\right)^2,$$

where $\hat{y}_{R_j}$ is the mean response of the training observations in the jth region.
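As a minimal Python sketch of this objective, assume each training observation has already been assigned a region label. The function names and the toy data here are made up for illustration. The RSS of a partition is then the sum, over regions, of each response's squared deviation from its region's mean:

```python
def region_rss(ys):
    """RSS of one region when every observation is predicted by the region's mean."""
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)

def partition_rss(y, region_of):
    """Total RSS over regions R1..RK.

    y         -- list of training responses
    region_of -- region label of each observation (hypothetical input format)
    """
    groups = {}
    for yi, r in zip(y, region_of):
        groups.setdefault(r, []).append(yi)
    return sum(region_rss(ys) for ys in groups.values())

y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
print(partition_rss(y, [1, 1, 1, 1, 1, 1]))  # one region:  125.5
print(partition_rss(y, [1, 1, 1, 2, 2, 2]))  # two regions: 4.0
```

Splitting the toy data into the two obvious groups cuts the RSS from 125.5 to 4.0, which is exactly the kind of reduction a split is chosen to achieve.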
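However, it is computationally infeasible to consider every possible partition of the predictor space into K regions, since the number of candidate partitions grows combinatorially with the number of observations and predictors (which oddly reminds me of Dr. Strange’s amazing superpowers!).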
Hence, we use an approach called “recursive binary splitting.”
Recursive binary splitting
In this process, we begin at the top of the tree and successively split the predictor space, with each split producing two new branches (hence, binary) further down. However, at every step of the tree-building process, the model looks for the best split at that particular step only, instead of looking ahead and picking a split that would lead to a better tree in some later step. For these reasons, this is considered a top-down and greedy approach.
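To decide the predictor and the cutpoint for the best split, we consider all the predictors X1 to Xp and all the possible cutpoint values for each of those predictors, and pick the combination for which the resulting tree has the smallest RSS. For example, for the kth predictor Xk and a cutpoint s, define the pair of regions

$$R_1(k, s) = \{X \mid X_k < s\}, \qquad R_2(k, s) = \{X \mid X_k \ge s\}.$$

We then seek the values of k and s that minimize

$$\sum_{i:\, x_i \in R_1(k,s)} \left(y_i - \hat{y}_{R_1}\right)^2 + \sum_{i:\, x_i \in R_2(k,s)} \left(y_i - \hat{y}_{R_2}\right)^2,$$

where $\hat{y}_{R_1}$ and $\hat{y}_{R_2}$ are the mean training responses in R1(k, s) and R2(k, s).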
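The greedy search for a single split can be sketched in plain Python as a brute-force loop. This is an illustrative version that assumes the data are lists of rows. It takes midpoints between consecutive distinct values of Xk as candidate cutpoints, because any s between two adjacent observed values yields the same partition. It uses 0-based predictor indices, so k = 0 corresponds to X1. The function names and data are hypothetical.

```python
def rss(ys):
    """RSS of a region predicted by its mean response."""
    m = sum(ys) / len(ys)
    return sum((v - m) ** 2 for v in ys)

def best_split(X, y):
    """Return (k, s, total_rss) minimizing RSS(R1(k, s)) + RSS(R2(k, s)).

    Brute force, O(p * n^2): for each predictor, try every midpoint between
    consecutive distinct values. Returns None if no predictor can be split.
    """
    best = None
    for k in range(len(X[0])):
        values = sorted(set(row[k] for row in X))
        for lo, hi in zip(values, values[1:]):
            s = (lo + hi) / 2
            left = [yi for row, yi in zip(X, y) if row[k] < s]    # R1: X_k < s
            right = [yi for row, yi in zip(X, y) if row[k] >= s]  # R2: X_k >= s
            total = rss(left) + rss(right)
            if best is None or total < best[2]:
                best = (k, s, total)
    return best

X = [[1, 5], [2, 3], [3, 8], [4, 1], [5, 9], [6, 2]]
y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
print(best_split(X, y))  # (0, 3.5, 4.0)
```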
Thereafter, we look for the best predictor and the best cutpoint to split either R1 or R2, so as to minimize the total RSS of the resulting regions. This gives us three regions. Again, we split one of the three regions using the predictor and cutpoint that minimize the RSS, and we continue this process until our stopping criterion is reached.
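Applying that split search recursively gives a minimal sketch of recursive binary splitting. The split search is repeated here so the block runs on its own. The stopping rule (split a region only if it holds at least `min_node_size` observations) is one illustrative choice of stopping criterion. The names and the dictionary-based node layout are assumptions. Under a per-region stopping rule like this one, each region's best split depends only on the observations inside it. So growing the tree depth-first yields the same regions as splitting them one at a time in any order.

```python
def rss(ys):
    """RSS of a region predicted by its mean response."""
    m = sum(ys) / len(ys)
    return sum((v - m) ** 2 for v in ys)

def best_split(X, y):
    """Greedy search over predictors k and midpoint cutpoints s."""
    best = None
    for k in range(len(X[0])):
        values = sorted(set(row[k] for row in X))
        for lo, hi in zip(values, values[1:]):
            s = (lo + hi) / 2
            left = [yi for row, yi in zip(X, y) if row[k] < s]
            right = [yi for row, yi in zip(X, y) if row[k] >= s]
            total = rss(left) + rss(right)
            if best is None or total < best[2]:
                best = (k, s, total)
    return best  # None when every predictor is constant in this region

def grow(X, y, min_node_size=5):
    """Recursive binary splitting; a region with fewer than min_node_size observations becomes a leaf."""
    node = {"value": sum(y) / len(y), "n": len(y)}  # mean training response of this region
    split = best_split(X, y) if len(y) >= min_node_size else None
    if split is None:
        return node  # terminal node
    k, s, _ = split
    left = [i for i, row in enumerate(X) if row[k] < s]
    right = [i for i, row in enumerate(X) if row[k] >= s]
    node.update(k=k, s=s,
                left=grow([X[i] for i in left], [y[i] for i in left], min_node_size),
                right=grow([X[i] for i in right], [y[i] for i in right], min_node_size))
    return node

X = [[1, 5], [2, 3], [3, 8], [4, 1], [5, 9], [6, 2]]
y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
tree = grow(X, y, min_node_size=4)
print(tree["k"], tree["s"], tree["left"]["value"], tree["right"]["value"])  # 0 3.5 2.0 11.0
```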
Once we have created the regions R1 to RK of the predictor space, we predict the outcome for a test observation using the mean of the training outcomes from the region to which the test observation belongs.
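Prediction then amounts to dropping a test observation down the tree until it reaches a terminal node. In this sketch, each leaf is assumed to store the mean training response of its region under the key `"value"`. The tree itself is a hypothetical hand-built example, with 0-based predictor indices.

```python
def predict(node, x):
    """Follow the splits for observation x and return its region's mean training response."""
    while "left" in node:  # splitting node: go left if x_k < s, otherwise right
        node = node["left"] if x[node["k"]] < node["s"] else node["right"]
    return node["value"]

# Hypothetical tree: split on X1 at 3.5, then split the right region on X2 at 5.0.
tree = {"k": 0, "s": 3.5,
        "left": {"value": 2.0},
        "right": {"k": 1, "s": 5.0,
                  "left": {"value": 10.5},
                  "right": {"value": 12.0}}}

print(predict(tree, [2.0, 7.0]))  # 2.0
print(predict(tree, [5.0, 1.0]))  # 10.5
print(predict(tree, [5.0, 9.0]))  # 12.0
```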
A drawback of the process described above is that the resulting tree may be overly complex and overfit the training data, leading to poor performance on the test data. Pruning is the process of converting some of the internal nodes of a large tree T0 into terminal nodes, in order to obtain a subtree with the lowest test error rate. Cost complexity pruning (a.k.a. weakest link pruning) is one way to do it.
In this method, each value of the non-negative regularization parameter α corresponds to a subtree T ⊂ T0 such that

$$\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left(y_i - \hat{y}_{R_m}\right)^2 + \alpha |T|$$

is minimized, where |T| is the number of terminal nodes of T, Rm is the region corresponding to the mth terminal node, and $\hat{y}_{R_m}$ is the mean training response in Rm. When α = 0, the expression reduces to the training RSS, so the subtree T is simply T0. As α increases, each additional terminal node carries a larger penalty, so the expression tends to be minimized by smaller subtrees. The best value of α can be selected using cross-validation. Finally, we select the subtree corresponding to the chosen value of α.
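One standard way to obtain the minimizing subtree for every α at once is weakest link pruning. It works by repeatedly collapsing the internal node t whose removal raises the training RSS the least per terminal node removed, i.e. the node with the smallest (RSS(t) − RSS(T_t)) / (|T_t| − 1). Here T_t is the branch rooted at t, and the quantity is computed by the hypothetical `link_strength` function. The resulting α thresholds are nondecreasing. For a given α, the best subtree is the last one in the sequence whose threshold does not exceed α. The sketch assumes every node stores, under `"rss"`, the RSS it would have as a leaf. The tree and its RSS values are hypothetical.

```python
import copy

def leaves(node):
    return [node] if "left" not in node else leaves(node["left"]) + leaves(node["right"])

def internal_nodes(node):
    return [] if "left" not in node else [node] + internal_nodes(node["left"]) + internal_nodes(node["right"])

def cost(tree, alpha):
    """Cost complexity criterion: RSS summed over terminal nodes plus alpha * |T|."""
    terminal = leaves(tree)
    return sum(leaf["rss"] for leaf in terminal) + alpha * len(terminal)

def link_strength(node):
    """Increase in training RSS per terminal node removed if this branch is collapsed."""
    branch = leaves(node)
    return (node["rss"] - sum(leaf["rss"] for leaf in branch)) / (len(branch) - 1)

def weakest_link_path(tree):
    """Return [(alpha_0, T_0), (alpha_1, T_1), ...], ending with the root-only tree."""
    tree = copy.deepcopy(tree)  # leave the caller's T0 untouched
    path = [(0.0, copy.deepcopy(tree))]
    while "left" in tree:
        weakest = min(internal_nodes(tree), key=link_strength)
        alpha = link_strength(weakest)
        del weakest["left"], weakest["right"]  # convert the internal node into a terminal node
        path.append((alpha, copy.deepcopy(tree)))
    return path

def subtree_for_alpha(path, alpha):
    """Best subtree for alpha: the last tree whose threshold does not exceed alpha."""
    return [t for a, t in path if a <= alpha][-1]

# Hypothetical T0: every node stores the RSS it would have as a leaf.
T0 = {"rss": 100,
      "left": {"rss": 30, "left": {"rss": 10}, "right": {"rss": 12}},
      "right": {"rss": 40, "left": {"rss": 5}, "right": {"rss": 6}}}

path = weakest_link_path(T0)
for alpha, subtree in path:
    print(alpha, len(leaves(subtree)))
# prints:
# 0.0 4
# 8.0 3
# 29.0 2
# 30.0 1
print(cost(subtree_for_alpha(path, 10.0), 10.0))  # 71.0
```

For α = 10, the chosen 3-leaf subtree has cost 41 + 10 × 3 = 71. The full tree (33 + 40 = 73) and the 2-leaf tree (70 + 20 = 90) both cost more.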
To summarize the algorithm for building a regression tree:
Step 1: Build a large tree using recursive binary splitting on the training data, stopping only once the stopping criterion is reached (for example: every terminal node has fewer than some pre-specified number of observations).
Step 2: Apply cost complexity pruning to the large tree to obtain the best subtree as a function of α.
Step 3: Choose the best value of α using K-fold cross-validation.
- Divide the training data into K folds (for example, K = 5).
- For each k = 1, 2, …, K, repeat Steps 1 and 2 on all folds except the kth fold.
- Calculate the mean squared prediction error on the data in the held-out kth fold, as a function of α. [This gives us a K × (number of α values) matrix of mean squared prediction errors.]
- Average the results for each value of α, and select the α that minimizes the average error.
Step 4: Select the subtree that corresponds to the chosen value of α.
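The four steps above can be combined into a compact, self-contained Python sketch. It repeats the growing and pruning helpers so it runs on its own. It makes several assumptions that are not part of the algorithm as stated: α is searched over a user-supplied grid `alphas`, folds are formed by shuffling with a fixed seed, the stopping rule is a minimum node size, and the data are synthetic. All function names are hypothetical. It illustrates the procedure and is not tuned for efficiency.

```python
import copy
import random

def rss(ys):
    m = sum(ys) / len(ys)
    return sum((v - m) ** 2 for v in ys)

def best_split(X, y):
    best = None
    for k in range(len(X[0])):
        vals = sorted(set(row[k] for row in X))
        for lo, hi in zip(vals, vals[1:]):
            s = (lo + hi) / 2
            left = [yi for row, yi in zip(X, y) if row[k] < s]
            right = [yi for row, yi in zip(X, y) if row[k] >= s]
            total = rss(left) + rss(right)
            if best is None or total < best[2]:
                best = (k, s, total)
    return best

def grow(X, y, min_node_size):
    """Step 1: recursive binary splitting; every node keeps its mean and RSS."""
    node = {"value": sum(y) / len(y), "rss": rss(y)}
    split = best_split(X, y) if len(y) >= min_node_size else None
    if split is not None:
        k, s, _ = split
        L = [i for i, row in enumerate(X) if row[k] < s]
        R = [i for i, row in enumerate(X) if row[k] >= s]
        node.update(k=k, s=s,
                    left=grow([X[i] for i in L], [y[i] for i in L], min_node_size),
                    right=grow([X[i] for i in R], [y[i] for i in R], min_node_size))
    return node

def leaves(t):
    return [t] if "left" not in t else leaves(t["left"]) + leaves(t["right"])

def internals(t):
    return [] if "left" not in t else [t] + internals(t["left"]) + internals(t["right"])

def link_strength(t):
    branch = leaves(t)
    return (t["rss"] - sum(b["rss"] for b in branch)) / (len(branch) - 1)

def prune_path(tree):
    """Step 2: weakest link pruning -> [(alpha_i, T_i)] with nondecreasing alpha_i."""
    tree = copy.deepcopy(tree)
    path = [(0.0, copy.deepcopy(tree))]
    while "left" in tree:
        weakest = min(internals(tree), key=link_strength)
        # Thresholds are nondecreasing in exact arithmetic; max() guards against float noise.
        alpha = max(link_strength(weakest), path[-1][0])
        for key in ("left", "right", "k", "s"):
            del weakest[key]
        path.append((alpha, copy.deepcopy(tree)))
    return path

def subtree_for_alpha(path, alpha):
    return [t for a, t in path if a <= alpha][-1]

def predict(t, x):
    while "left" in t:
        t = t["left"] if x[t["k"]] < t["s"] else t["right"]
    return t["value"]

def choose_alpha(X, y, alphas, K=5, min_node_size=5, seed=0):
    """Step 3: K-fold CV; errors[k][j] is the test MSE on fold k for alphas[j]."""
    idx = list(range(len(y)))
    random.Random(seed).shuffle(idx)
    folds = [idx[k::K] for k in range(K)]
    errors = []
    for fold in folds:
        held_out = set(fold)
        train = [i for i in idx if i not in held_out]
        path = prune_path(grow([X[i] for i in train], [y[i] for i in train], min_node_size))
        row = []
        for a in alphas:
            sub = subtree_for_alpha(path, a)
            row.append(sum((y[i] - predict(sub, X[i])) ** 2 for i in fold) / len(fold))
        errors.append(row)
    avg = [sum(r[j] for r in errors) / K for j in range(len(alphas))]
    return alphas[avg.index(min(avg))], avg

# Synthetic data (assumption): y jumps at X1 = 5, plus noise; X2 is irrelevant.
rng = random.Random(1)
X = [[rng.uniform(0, 10), rng.uniform(0, 10)] for _ in range(60)]
y = [(5.0 if x[0] < 5 else 15.0) + rng.gauss(0, 1) for x in X]

alphas = [0.0, 1.0, 5.0, 20.0, 100.0]
best_alpha, cv_mse = choose_alpha(X, y, alphas)
# Step 4: prune the tree grown on all the training data at the chosen alpha.
final_tree = subtree_for_alpha(prune_path(grow(X, y, 5)), best_alpha)
print(best_alpha, len(leaves(final_tree)))
```

Libraries implement the same minimal cost complexity pruning idea. For instance, scikit-learn exposes it through the `ccp_alpha` parameter of `DecisionTreeRegressor` and its `cost_complexity_pruning_path` method.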
Thanks for reading!
James, G., Witten, D., Hastie, T., and Tibshirani, R., An Introduction to Statistical Learning with Applications in R (Springer)