Basic Bagging

Author

Brett Devine

Published

September 1, 2021

In this document we play around with a basic bagging technique on a simple regression problem. We have a target feature \(Y\) and a single predictor, \(X\). The data are simulated so that we can easily visualize the “true” relationship between \(Y\) and \(X\) that exists in the population. This relationship is superimposed on a scatter plot, which lets us clearly see the role of noise.
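For reference, the simulation in the code below draws 300 observations from the following data-generating process, where \(f\) corresponds to TRUE_Y and \(\varepsilon\) to ERROR in the code. Because the noise has mean zero and is drawn independently of \(X\), \(f\) is exactly the conditional mean \(E[Y|X]\) plotted as the blue line.

```latex
\begin{aligned}
X &\sim N(3,\, 1.75^2), \qquad \varepsilon \sim N(0,\, 1.5^2), \qquad Y = f(X) + \varepsilon, \\
f(x) &= E[Y \mid X = x] =
\begin{cases}
1.2\,x, & 4.1 < x < 7.3, \\
10 + 2.2\,x - 0.92\,x^2 + 0.075\,x^3, & \text{otherwise.}
\end{cases}
\end{aligned}
```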

The code below generates the simulated data and plots the true relationship over it.

library(tidyverse)   # tibble, dplyr verbs and pipes, ggplot2
library(ggthemes)    # theme_clean()
library(rpart)       # regression trees, used throughout

set.seed(1234)
sim.size = 300
X = rnorm(sim.size, mean = 3, sd = 1.75)
Z = rgamma(sim.size, shape = 1)  # not used below; drawing it keeps the random-number stream unchanged

# True conditional mean E[Y|X]: a cubic, except a linear piece for 4.1 < X < 7.3
TRUE_Y = 10 + 2.2*X - 0.92*X^2 + 0.075*X^3
TRUE_Y = ifelse(X > 4.1 & X < 7.3, 1.2*X, TRUE_Y)

# Observed outcome = truth + Gaussian noise
ERROR = rnorm(sim.size, mean = 0, sd = 1.5)
Y = TRUE_Y + ERROR
data = as_tibble(data.frame("Y"=Y, "X"=X))
Ybar = mean(data$Y)

colors = c("E[Y|X]" = "deepskyblue", "Mean Y" = "red")
data %>%
  ggplot(aes(x=X,y=Y)) + 
  geom_line(aes(x=X, y=TRUE_Y, color="E[Y|X]"), linewidth=3, alpha=0.4) +
  geom_line(aes(x=X, y=Ybar, color="Mean Y"), linewidth = 1) +
  geom_point(pch=21, color="black", bg="gray", size=2) +
  labs(y=expression(Y), x=expression(X)) +
  ggthemes::theme_clean() +
  scale_color_manual(values = colors)

In the plot above, the thick, transparent blue line is the true relationship between \(X\) and \(Y\) (the conditional mean \(E[Y|X]\)), which would let us make the best possible predictions, and the red line is the sample mean of \(Y\). However, there is always noise, so the observed outcomes are scattered all around the blue line.

We’re going to use a machine learning method known as a decision tree (called a regression tree when it is applied to regression problems). The algorithm gradually builds a set of if-then rules that let us predict \(Y\) from the value of \(X\). As the rules increase and the model gets more complex, the predicted relationship gets closer to the true one. However, if we allow the model to grow too complex (creating too many rules based on a single sample), it begins to “fit to the noise”, which is called overfitting: it starts making rules that describe quirks of this particular sample rather than the population.
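The overfitting behaviour described above can be checked numerically. The sketch below is our own illustration, not part of the original notebook: it re-simulates a training sample and an independent test sample from the same recipe (the seeds 1234 and 4321 and the helper names f, simulate, and rmse are our choices), grows unpruned rpart trees of increasing maxdepth, and reports training and test RMSE at each depth.

```r
library(rpart)

# True conditional mean E[Y|X], following the simulation recipe in this document
f = function(x) {
  ifelse(x > 4.1 & x < 7.3, 1.2 * x, 10 + 2.2 * x - 0.92 * x^2 + 0.075 * x^3)
}

# Hypothetical helper: draw n observations of Y = f(X) + noise
simulate = function(n, seed) {
  set.seed(seed)
  X = rnorm(n, mean = 3, sd = 1.75)
  data.frame(X = X, Y = f(X) + rnorm(n, mean = 0, sd = 1.5))
}

rmse = function(pred, obs) sqrt(mean((pred - obs)^2))

train_df = simulate(300, seed = 1234)
test_df  = simulate(300, seed = 4321)

# Grow trees of increasing depth with no complexity penalty (cp = 0)
# and no internal cross-validation (xval = 0)
depths = 1:12
errs = t(sapply(depths, function(d) {
  tree = rpart(Y ~ X, data = train_df,
               control = rpart.control(cp = 0, minsplit = 2, minbucket = 1,
                                       maxdepth = d, xval = 0))
  c(depth = d,
    train = rmse(predict(tree, train_df), train_df$Y),
    test  = rmse(predict(tree, test_df),  test_df$Y))
}))
print(round(errs, 3))
```

Training RMSE can only fall as the depth limit rises, because each deeper tree refines the leaves of the shallower one. Test RMSE typically bottoms out at a moderate depth and then climbs again as the extra rules start fitting the noise.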

Decision trees are based on rules. For example, if \(X > 5\) then predict \(Y = 10\), but if \(X \leq 5\) then predict \(Y = 7\). We stack many of these rules in a tree-like structure to build a prediction machine. The cut point in each question, e.g., the 5 in \(X > 5\), is called a threshold (or split point). We choose these thresholds (and thereby build the tree) so that each new rule reduces prediction error. Prediction error can be measured in a number of ways, but we’ll use the SSE (sum of squared errors), \(\text{SSE} = \sum_i (y_i - \hat{y}_i)^2\).

To find a threshold, we can try many candidate values of \(X\) and choose the one that decreases SSE the most.
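Here is a minimal base-R sketch of that search for a single split. The stand-in data follow this document's simulation recipe, and split_sse, candidates, and the other names are ours. Using midpoints between consecutive sorted values of \(X\) as the candidate thresholds is a common choice and an assumption here. A regression tree repeats this kind of search inside each node as it grows.

```r
# Stand-in data following the simulation recipe in this document
set.seed(1234)
X = rnorm(300, mean = 3, sd = 1.75)
true_y = ifelse(X > 4.1 & X < 7.3, 1.2 * X,
                10 + 2.2 * X - 0.92 * X^2 + 0.075 * X^3)
Y = true_y + rnorm(300, mean = 0, sd = 1.5)

# SSE of a single split at threshold t: each side predicts its own mean
split_sse = function(t, x, y) {
  left  = y[x <= t]
  right = y[x > t]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
}

# Candidate thresholds: midpoints between consecutive distinct X values,
# so both sides of every candidate split are non-empty
xs = sort(unique(X))
candidates = (head(xs, -1) + tail(xs, -1)) / 2
sse = vapply(candidates, split_sse, numeric(1), x = X, y = Y)

best = candidates[which.min(sse)]
sse_root = sum((Y - mean(Y))^2)  # SSE with no split (predict the overall mean)
print(c(threshold = best, sse_split = min(sse), sse_no_split = sse_root))
```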

In R, the rpart() function grows the tree and the rpart.control() function sets its hyperparameters. Changing these hyperparameters controls the tree’s growth, and a grown tree can also be “pruned” back afterwards (more on this later).

The notebook is set up so that you can edit the tree-growth hyperparameters, train the tree, and then view a picture of it along with the predictions it makes on the scatter plot.
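A minimal sketch of that workflow, assuming stand-in data built with this document's simulation recipe. The hyperparameter values below are illustrative, not the notebook's; edit them and re-run to see how the learned rules change. The final lines show pruning with rpart's prune() at an arbitrarily chosen cp of 0.05.

```r
library(rpart)

# Stand-in data following the simulation recipe in this document
set.seed(1234)
X = rnorm(300, mean = 3, sd = 1.75)
Y = ifelse(X > 4.1 & X < 7.3, 1.2 * X, 10 + 2.2 * X - 0.92 * X^2 + 0.075 * X^3) +
  rnorm(300, mean = 0, sd = 1.5)
sim_df = data.frame(X = X, Y = Y)

# Growth hyperparameters (illustrative values; edit and re-run):
#   cp        - a split must improve the overall fit by at least this fraction
#   minsplit  - a node needs at least this many observations to be split
#   minbucket - every leaf must keep at least this many observations
#   maxdepth  - maximum depth of any node, with the root at depth 0
ctrl = rpart.control(cp = 0.01, minsplit = 20, minbucket = 7, maxdepth = 3)
tree = rpart(Y ~ X, data = sim_df, method = "anova", control = ctrl)

print(tree)    # the learned if-then rules, one line per node
printcp(tree)  # cp table: splits, relative error, cross-validated error

# Pruning after growth: trim the tree back to complexity parameter cp = 0.05
pruned = prune(tree, cp = 0.05)
print(pruned)
```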

We can compare the model’s RMSE with the sample mean and sample standard deviation of \(Y\). Comparing the error with the mean gives a sense of its scale: if the RMSE is a very small percentage of the mean, our errors are relatively small. The standard deviation is itself a measure of spread and of error (essentially the RMSE of a model that always predicts the mean), so if our model’s RMSE is not much smaller than the standard deviation, the model is not reducing the uncertainty in our predictions. Each time we increase the complexity, we should check whether the RMSE decreases.
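A minimal sketch of those comparisons. The stand-in data, the depth-3 tree, and the rmse helper are our assumptions, and the numbers are computed on the training data, so they will flatter the tree compared with test data. The stopifnot line checks the claim that the standard deviation is essentially the RMSE of the mean: the two differ only because sd() divides by \(n - 1\) instead of \(n\).

```r
library(rpart)

# Stand-in data following the simulation recipe in this document
set.seed(1234)
X = rnorm(300, mean = 3, sd = 1.75)
Y = ifelse(X > 4.1 & X < 7.3, 1.2 * X, 10 + 2.2 * X - 0.92 * X^2 + 0.075 * X^3) +
  rnorm(300, mean = 0, sd = 1.5)
sim_df = data.frame(X = X, Y = Y)
n = nrow(sim_df)

tree = rpart(Y ~ X, data = sim_df, control = rpart.control(maxdepth = 3))

rmse = function(pred, obs) sqrt(mean((pred - obs)^2))
rmse_tree = rmse(predict(tree, sim_df), sim_df$Y)
rmse_mean = rmse(rep(mean(sim_df$Y), n), sim_df$Y)  # a "model" that always predicts the mean

# sd() divides by n - 1 rather than n, so it equals the mean's RMSE
# up to a factor of sqrt(n / (n - 1))
stopifnot(isTRUE(all.equal(rmse_mean * sqrt(n / (n - 1)), sd(sim_df$Y))))

print(c(mean_Y      = mean(sim_df$Y),
        sd_Y        = sd(sim_df$Y),
        rmse_tree   = rmse_tree,
        pct_of_mean = 100 * rmse_tree / mean(sim_df$Y),
        pct_of_sd   = 100 * rmse_tree / sd(sim_df$Y)))
```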

Next, the notebook shows an image of the learned tree along with its rules, and plots the tree’s predicted values (red line) against the actual values (points) and the true relationship (blue line).
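If you want to reproduce that view without the interactive notebook, here is a minimal sketch using only base graphics and rpart's own plot() and text() methods (the notebook may well use a different plotting package). The stand-in data, the helper f, and the depth-3 tree are our assumptions. Because the tree predicts one constant per leaf, its prediction curve is a step function.

```r
library(rpart)

# True conditional mean E[Y|X], following the simulation recipe in this document
f = function(x) {
  ifelse(x > 4.1 & x < 7.3, 1.2 * x, 10 + 2.2 * x - 0.92 * x^2 + 0.075 * x^3)
}
set.seed(1234)
X = rnorm(300, mean = 3, sd = 1.75)
sim_df = data.frame(X = X, Y = f(X) + rnorm(300, mean = 0, sd = 1.5))

tree = rpart(Y ~ X, data = sim_df, control = rpart.control(maxdepth = 3))

op = par(mfrow = c(1, 2))
# Left: the tree diagram, labelled with its split rules and leaf means
plot(tree, uniform = TRUE, margin = 0.1)
text(tree, use.n = TRUE, cex = 0.7)

# Right: data (points), true E[Y|X] (blue), tree predictions (red steps)
grid_x = data.frame(X = seq(min(sim_df$X), max(sim_df$X), by = 0.01))
plot(sim_df$X, sim_df$Y, pch = 21, bg = "gray", xlab = "X", ylab = "Y")
lines(grid_x$X, f(grid_x$X), col = "deepskyblue", lwd = 3)
lines(grid_x$X, predict(tree, grid_x), col = "red", lwd = 2, type = "s")
par(op)
```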

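Next, we create some test data: a new sample from the same population that none of the models are trained on.

# Create new, unseen test data from the same data-generating process.
# A seed different from the training data's (1234) keeps the two samples independent.
set.seed(4321)
sim.size = 300
X.test = rnorm(sim.size, mean = 3, sd = 1.75)

TRUE_Y.test = 10 + 2.2*X.test - 0.92*X.test^2 + 0.075*X.test^3
TRUE_Y.test = ifelse(X.test > 4.1 & X.test < 7.3, 1.2*X.test, TRUE_Y.test)

ERROR.test = rnorm(sim.size, mean = 0, sd = 1.5)
Y.test = TRUE_Y.test + ERROR.test
testing = as_tibble(data.frame("Y"=Y.test, "X"=X.test))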

Bagged Linear Regression

Here we fit a linear regression on a fourth-degree polynomial in \(X\), i.e., poly(X, 4), but employ bagging (bootstrap aggregation). We draw 50 bootstrap training sets, each sampled with replacement from the original data and of the same size, and fit the regression to each one. The predictions from each bootstrapped model are plotted against the data, along with the average (the aggregation) of those bootstrapped models.
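In symbols, if \(\hat{f}^{*b}\) is the model fit to the \(b\)-th bootstrap sample and \(B\) is the number of bootstrap samples (50 here), the bagged prediction at a point \(x\) is the plain average of the individual predictions. This is what the avg.mod step computes at each grid value of \(X\).

```latex
\hat{f}_{\text{bag}}(x) = \frac{1}{B} \sum_{b=1}^{B} \hat{f}^{*b}(x)
```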

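boots = data.frame()   # predictions from every bootstrap model, stacked
boot.rmse = c()        # test RMSE of each bootstrap model
num.boots = 50
X.vals = data.frame("X"=seq(min(data$X), max(data$X), 0.5))  # prediction grid
for(i in 1:num.boots) {
  # Bootstrap sample: nrow(data) rows drawn with replacement
  B = data[sample(nrow(data), nrow(data), replace = TRUE), ]
  # Fourth-degree polynomial regression fit to this bootstrap sample
  m = lm(Y ~ poly(X,4), data = B)
  boots.append = data.frame("X"=X.vals$X, "predictions"=predict(m, X.vals), "model"=i)
  boots = boots %>% bind_rows(boots.append)
  boot.test = predict(m, newdata = testing)
  boot.rmse[i] = caret::RMSE(boot.test, testing$Y)
}
# Aggregation: average the bootstrap predictions at each grid point
avg.mod = boots %>% group_by(X) %>% summarize(predictions = mean(predictions))
mean(boot.rmse)  # average test RMSE of the individual (un-aggregated) models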
[1] 1.688976
ggplot() +
  geom_point(data = data, aes(x=X,y=Y), pch=21, color="black", bg="gray", size=2) +
  geom_line(aes(x=X,y=TRUE_Y), color = "deepskyblue", alpha = 0.4, linewidth = 3) +
  geom_line(data = boots, aes(x=X,y=predictions, group=model), alpha=0.3) +
  geom_line(data = avg.mod, aes(x=X,y=predictions), color = "red", linewidth = 2, alpha = 0.8) +
  labs(title = "Bootstrap Aggregated Polynomial Regressions") +
  ggthemes::theme_clean()

Smaller, Less Flexible Trees

# Bag 100 shallow trees (maxdepth = 2): each tree is a stiff, high-bias fit
boots = data.frame()
boot.rmse = c()
num.boots = 100
X.vals = data.frame("X"=seq(min(data$X), max(data$X), 0.1))  # prediction grid
for(i in 1:num.boots) {
  # Bootstrap sample of the training data (rows drawn with replacement)
  B = data[sample(nrow(data), nrow(data), replace = TRUE), ]
  # rpart passes cp, minsplit, minbucket and maxdepth on to rpart.control()
  m = rpart(Y ~ X, data = B,
            cp = 0.0,
            minsplit = 1,
            minbucket = 1,
            maxdepth = 2)
  boots.append = data.frame("X"=X.vals$X, "predictions"=predict(m, X.vals), "model"=i)
  boots = boots %>% bind_rows(boots.append)
  boot.test = predict(m, newdata = testing)
  boot.rmse[i] = caret::RMSE(boot.test, testing$Y)
}
# Bagged prediction: the average over all trees at each grid point
avg.mod = boots %>%
  group_by(X) %>%
  summarize(predictions = mean(predictions))
mean(boot.rmse)  # average test RMSE of the individual trees
[1] 1.624439
ggplot() +
  geom_point(data = data, aes(x=X,y=Y), pch=21, color="black", bg="gray", size=2) +
  geom_line(aes(x=X,y=TRUE_Y), color = "deepskyblue", alpha = 0.4, linewidth = 3) +
  geom_line(data = boots, aes(x=X,y=predictions, group=model), alpha=0.3) +
  geom_line(data = avg.mod, aes(x=X, y=predictions), linewidth=2, color="red", alpha=0.7) +
  labs(title = "Bootstrap Aggregated Shallow/Stiff Trees") +
  theme_clean()

Highly Flexible Overfitted Trees

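# Bag 100 deep, flexible trees (maxdepth = 15, small cp): each individual tree
# tends to overfit its bootstrap sample, and averaging many of them reduces that variance
boots = data.frame()
boot.rmse = c()
num.boots = 100
X.vals = data.frame("X"=seq(min(data$X), max(data$X), 0.1))  # prediction grid
for(i in 1:num.boots) {
  # Bootstrap sample of the training data (rows drawn with replacement)
  B = data[sample(nrow(data), nrow(data), replace = TRUE), ]
  m = rpart(Y ~ X, data = B,
            cp = 0.007,
            minsplit = 1,
            minbucket = 1,
            maxdepth = 15)
  boots.append = data.frame("X"=X.vals$X, "predictions"=predict(m, X.vals), "model"=i)
  boots = boots %>% bind_rows(boots.append)
  boot.test = predict(m, newdata = testing)
  boot.rmse[i] = caret::RMSE(boot.test, testing$Y)
}
# Bagged prediction: the average over all trees at each grid point
avg.mod = boots %>%
  group_by(X) %>%
  summarize(predictions = mean(predictions))
mean(boot.rmse)  # average test RMSE of the individual trees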
[1] 1.389182
ggplot() +
  geom_point(data = data, aes(x=X,y=Y), pch=21, color="black", bg="gray", size=2) +
  geom_line(aes(x=X,y=TRUE_Y), color = "deepskyblue", alpha = 0.4, linewidth = 3) +
  geom_line(data = boots, aes(x=X,y=predictions, group=model), alpha=0.3) +
  geom_line(data = avg.mod, aes(x=X, y=predictions), linewidth=2, color="red", alpha=0.7) +
  labs(title = "Bootstrap Aggregated Deep/Flexible Trees") +
  theme_clean()

Randomly Altering cp with Each Bootstrap Iteration

# Bag 200 trees, drawing a random cp in [0, 0.075] for each bootstrap tree:
# larger cp values give smaller trees, so the ensemble mixes stiff and flexible fits
boots = data.frame()
num.boots = 200
X.vals = data.frame("X"=seq(min(data$X), max(data$X), 0.1))  # prediction grid
for(i in 1:num.boots) {
  B = data[sample(nrow(data), nrow(data), replace = TRUE), ]  # bootstrap sample
  m = rpart(Y ~ X, data = B,
            cp = sample(seq(0.0, 0.075, 0.0001), 1),  # random complexity penalty
            minsplit = 1,
            minbucket = 1,
            maxdepth = 15)
  boots.append = data.frame("X"=X.vals$X, "predictions"=predict(m, X.vals), "model"=i)
  boots = boots %>% bind_rows(boots.append)
}
# Bagged prediction: the average over all trees at each grid point
avg.mod = boots %>%
  group_by(X) %>%
  summarize(predictions = mean(predictions))
ggplot() +
  geom_point(data = data, aes(x=X,y=Y), pch=21, color="black", bg="gray", size=2) +
  geom_line(aes(x=X,y=TRUE_Y), color = "deepskyblue", alpha = 0.4, linewidth = 3) +
  geom_line(data = boots, aes(x=X,y=predictions, group=model), alpha=0.3) +
  geom_line(data = avg.mod, aes(x=X, y=predictions), linewidth=2, color="red", alpha=0.7) +
  labs(title = "Bootstrap Aggregated Trees with random cp-pruning") +
  theme_clean()

Boosting Decision Stumps

Boosting takes a different route from bagging. Instead of averaging independent trees grown on bootstrap samples, boost_train fits a sequence of depth-1 trees (stumps), each one to the residuals left by the stumps before it, and boost_predict adds up the stumps’ predictions, each scaled by the shrinkage factor shrink.

library(rpart)
boost_train = function(df, number, shrink=1) {
  # Trains a sequence of <number> weak learners (decision stumps),
  # each fit to the residuals left by the previous ones, and
  # returns the list of weak learners.
  #
  # Parameters
  # -----------------------------------
  #       df : the data frame containing Y, X.
  #   number : the number of weak learners
  #            to use in sequence
  #   shrink : the shrinkage (learning rate) applied to
  #            each learner's contribution, in (0, 1].
  #            Default value is 1.
  #
  # Returns
  # -----------------------------------
  #   wlearn : list of weak learners trained
  #            on the data df.

  df$error = df$Y  # initialize the residuals to Y
  wlearn = list()  # empty list to hold the weak learners
  for (k in 1:number) {
    # Fit a depth-1 regression tree (a stump) to the current residuals.
    wlearn[[k]] = rpart(formula = error ~ X,
                        data = df,
                        maxdepth = 1, cp = 0)
    # Update the residuals for the next weak learner.
    df$error = df$error - shrink * predict(wlearn[[k]], df, type = "vector")
  }
  # Return the list of weak learners.
  wlearn
}
boost_predict = function(learners, data, shrink=1) {
  # Applies the learners to the data to get
  # the strong learner prediction.
  
  # Parameters
  # -----------------------------------
  #   learners : list of weak learners
  #       data : data to input into the 
  #              strong learner
  #     shrink : shrinkage factor; should match
  #              the value used in boost_train()
  #
  # Returns
  # -----------------------------------
  #   pred : vector of predictions from 
  #          the strong learner.
  
  N = nrow(data) # number of rows (length() would count columns)
  NBoost = length(learners) # number of learners
  pred = rep(0, N) # initialize predictions at 0.
  for (k in 1:NBoost) {
    # Use predict function iteratively
    # to obtain combined prediction.
    pred = pred + shrink*predict(learners[[k]],
                          newdata = data,
                          type="vector")
  }
  # Return the prediction vector.
  pred
}
learners = boost_train(data, 500, 1)
boost.pred = boost_predict(learners, X.vals, 1)
ggplot() +
  geom_line(aes(x=X, y=TRUE_Y), color="deepskyblue", alpha = 0.4, linewidth = 3) +
  geom_point(data = data, aes(x=X,y=Y),pch=21, color="black", bg="gray", size=2) +
  geom_line(aes(x=X.vals$X, y=boost.pred), color="red", linewidth = 2, alpha = 0.7) +
  labs(title = "Boosted Decision Stumps") +
  ggthemes::theme_clean()
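In symbols, with \(\lambda\) = shrink, \(K\) = number, and \(\hat{f}_k\) the \(k\)-th stump, boost_train and boost_predict compute the following residual-fitting recursion and sum:

```latex
\begin{aligned}
r_i^{(0)} &= y_i, \\
\hat{f}_k &= \text{stump fit to } \{(x_i,\, r_i^{(k-1)})\}_{i=1}^{n}, \qquad k = 1, \dots, K, \\
r_i^{(k)} &= r_i^{(k-1)} - \lambda\, \hat{f}_k(x_i), \\
\hat{F}(x) &= \lambda \sum_{k=1}^{K} \hat{f}_k(x).
\end{aligned}
```

Telescoping the residual updates gives \(\hat{F}(x_i) = y_i - r_i^{(K)}\) on the training data, so the strong learner’s training error is exactly the final residuals. Because each stump is a least-squares fit to the current residuals, any \(\lambda \in (0, 1]\) can only lower the training SSE as stumps are added; with \(\lambda = 1\) and 500 stumps, as above, overfitting is a real risk.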

Finally, we use caret to tune a regression tree on the simulated data, choosing cp by 5-fold cross-validation, and then tune a random forest (its mtry parameter) on R’s built-in USArrests data.

library(caret)
ctrl = caret::trainControl(method = "cv", number = 5)
# Large cp values prune the tree back to a single root node, whose constant
# predictions leave R-squared undefined (caret then warns about missing
# resampled performance measures), so the grid stops at cp = 0.1.
dt = train(Y ~ .,
           data = data,
           method = "rpart",
           trControl = ctrl,
           tuneGrid = expand.grid(cp = seq(0, 0.1, 0.002)),
           control = rpart.control(maxdepth = 9, minsplit = 1))
plot(dt)

data("USArrests")
ctrl = caret::trainControl(method = "cv", number = 10)
rf = caret::train(Murder ~ .,
                  data = USArrests,
                  method = 'rf',
                  trControl = ctrl,
                  tuneGrid = expand.grid(mtry = c(1,2,3)))
rf
Random Forest 

50 samples
 3 predictor

No pre-processing
Resampling: Bootstrapped (25 reps) 
Summary of sample sizes: 50, 50, 50, 50, 50, 50, ... 
Resampling results across tuning parameters:

  mtry  RMSE      Rsquared   MAE     
  1     2.786967  0.6287161  2.305333
  2     2.730548  0.6380341  2.229039
  3     2.874331  0.6040413  2.323361

RMSE was used to select the optimal model using the smallest value.
The final value used for the model was mtry = 2.