Supervised Learning

Supervised Learning

  • Supervised learning has the goal of making predictions with a set of known labels for the response variable.
  • In unsupervised learning, there are no predetermined labels (no response variable); instead, we try to find structure in the data itself.

Goal: predict the personality type of each character in Animal Crossing

Data set: Animal Crossing

Source: VillagerDB, MetaCritic, and TidyTuesday

library("caret")
library("randomForest")
library("tidymodels")
library("tidyverse")

# Other tables from the same TidyTuesday folder (2020-05-05), left disabled:
# critic <- readr::read_tsv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-05/critic.tsv')
# user_reviews <- readr::read_tsv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-05/user_reviews.tsv')
# items <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-05/items.csv')

# Villager table: one row per villager, read straight from GitHub.
villagers <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-05/villagers.csv"
)

Exploratory Data Analysis

# Number of rows and columns in the villager table
dim(villagers)
## [1] 391  11
# Column names, column types, and the first few values of each column
str(villagers)
## tibble [391 x 11] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ row_n      : num [1:391] 2 3 4 6 7 8 9 10 11 13 ...
##  $ id         : chr [1:391] "admiral" "agent-s" "agnes" "al" ...
##  $ name       : chr [1:391] "Admiral" "Agent S" "Agnes" "Al" ...
##  $ gender     : chr [1:391] "male" "female" "female" "male" ...
##  $ species    : chr [1:391] "bird" "squirrel" "pig" "gorilla" ...
##  $ birthday   : chr [1:391] "1-27" "7-2" "4-21" "10-18" ...
##  $ personality: chr [1:391] "cranky" "peppy" "uchi" "lazy" ...
##  $ song       : chr [1:391] "Steep Hill" "DJ K.K." "K.K. House" "Steep Hill" ...
##  $ phrase     : chr [1:391] "aye aye" "sidekick" "snuffle" "Ayyeeee" ...
##  $ full_id    : chr [1:391] "villager-admiral" "villager-agent-s" "villager-agnes" "villager-al" ...
##  $ url        : chr [1:391] "https://villagerdb.com/images/villagers/thumb/admiral.98206ee.png" "https://villagerdb.com/images/villagers/thumb/agent-s.96c789b.png" "https://villagerdb.com/images/villagers/thumb/agnes.9f51f32.png" "https://villagerdb.com/images/villagers/thumb/al.1e17090.png" ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   row_n = col_double(),
##   ..   id = col_character(),
##   ..   name = col_character(),
##   ..   gender = col_character(),
##   ..   species = col_character(),
##   ..   birthday = col_character(),
##   ..   personality = col_character(),
##   ..   song = col_character(),
##   ..   phrase = col_character(),
##   ..   full_id = col_character(),
##   ..   url = col_character()
##   .. )
# Disabled checks on the villager names (tally per name, number of distinct
# names); uncomment to run them.
#table(villagers$name)
#length(unique(villagers$name))
# Number of villagers per personality type (the response variable)
table(villagers$personality)
## 
## cranky   jock   lazy normal  peppy   smug snooty   uchi 
##     55     55     60     59     49     34     55     24

Personality across Gender

# Side-by-side bars: villager count per personality, one bar per gender.
# geom_bar() counts rows by default, so no explicit stat is needed.
ggplot(data = villagers, mapping = aes(x = personality, fill = gender)) +
  geom_bar(position = "dodge") +
  ggtitle(
    label = "Getting to Know the Animal Crossing Villagers",
    subtitle = "and judging their personalities"
  ) +
  labs(caption = "Source: VillagerDB") +
  theme_minimal()

Personality across Species

# Stacked bars: villager count per personality, segments coloured by species.
ggplot(data = villagers, mapping = aes(x = personality, fill = species)) +
  geom_bar(position = "stack") +
  ggtitle(
    label = "Getting to Know the Animal Crossing Villagers",
    subtitle = "and judging their personalities"
  ) +
  labs(caption = "Source: VillagerDB") +
  theme_minimal()

Predictor Variables

# Number of villagers of each gender
table(villagers$gender)
## 
## female   male 
##    187    204
# Number of villagers of each species
table(villagers$species)
## 
## alligator  anteater      bear      bird      bull       cat   chicken       cow 
##         7         7        15        13         6        23         9         4 
##       cub      deer       dog      duck     eagle  elephant      frog      goat 
##        16        10        16        17         9        11        18         8 
##   gorilla   hamster     hippo     horse  kangaroo     koala      lion    monkey 
##         9         8         7        15         8         9         7         8 
##     mouse   octopus   ostrich   penguin       pig    rabbit     rhino     sheep 
##        15         3        10        13        15        20         6        13 
##  squirrel     tiger      wolf 
##        18         7        11

Extracting Birth Month

# Split the "month-day" birthday string into two new character columns while
# keeping the original `birthday` column in place.
villagers <- tidyr::separate(
  data = villagers,
  col = birthday,
  into = c("birth_month", "birth_day"),
  remove = FALSE
)

# Factor copy of the month with levels in calendar order (1 through 12).
villagers[["birth_month_factor"]] <- factor(
  villagers[["birth_month"]],
  levels = seq_len(12)
)

# Number of villagers born in each month
table(villagers$birth_month_factor)
## 
##  1  2  3  4  5  6  7  8  9 10 11 12 
## 32 29 33 29 31 33 35 36 32 37 30 34
# Stacked bars: villager count per personality, segments coloured by birth
# month, with a black outline around each segment.
ggplot(
  data = villagers,
  mapping = aes(x = personality, fill = birth_month_factor)
) +
  geom_bar(colour = "black", position = "stack") +
  ggtitle(
    label = "Getting to Know the Animal Crossing Villagers",
    subtitle = "and judging their personalities"
  ) +
  labs(caption = "Source: VillagerDB") +
  theme_minimal()

# Same chart as proportions: position = "fill" scales every bar to height 1.
# The twelve palette colours are shuffled with sample(), so which colour lands
# on which month changes from run to run.
ggplot(
  data = villagers,
  mapping = aes(x = personality, fill = birth_month_factor)
) +
  geom_bar(colour = "black", position = "fill") +
  ggtitle(
    label = "Getting to Know the Animal Crossing Villagers",
    subtitle = "and judging their personalities"
  ) +
  labs(caption = "Source: VillagerDB", y = "proportion") +
  scale_fill_manual(
    name = "Month of Birth",
    # month.name is base R's "January" ... "December" vector
    labels = month.name,
    values = sample(c(
      "#4b48c9", "#5c90a8", "#86d5fe", "#c0d1ef",
      "#edd1a2", "#948572", "#f8c79e", "#eec688",
      "#642200", "#764936", "#fd974d", "#7b4a3c"
    ))
  ) +
  theme_minimal()

(color palette generated at CSS Drive)

model formula

  • response variable: personality
  • predictor variables: gender, species, birth_month
  • model formula: personality ~ gender + species + birth_month
# Names of the columns used as predictors.
predictor_variables <- c("gender", "species", "birth_month")
# Formula as a single string: each predictor is wrapped in backticks and the
# terms are joined with "+", after the "personality~" left-hand side.
model_formula <- paste(
  "personality~",
  paste0("`", predictor_variables, "`", collapse = "+")
)

Data Split

# Fix the RNG seed so the random train/test partition -- and every model
# result that depends on it -- is the same each time the script is run.
set.seed(20200505)

# Randomly partition the rows (initial_split()'s default proportion), then
# pull out the training and the held-out test rows.
villagers_split <- initial_split(villagers)
villagers_train <- training(villagers_split)
villagers_test  <- testing(villagers_split)

Random Forests

“Random forest models are ensembles of decision trees. A large number of decision tree models are created for the ensemble based on slightly different versions of the training set. When creating the individual decision trees, the fitting process encourages them to be as diverse as possible. The collection of trees are combined into the random forest model and, when a new sample is predicted, the votes from each tree are used to calculate the final predicted value for the new sample.” —tidymodels.org

Define the Forest

# Model specification only (nothing is fitted here): a 1000-tree random
# forest for classification, to be fitted with the ranger engine.
random_forest_model <-
  rand_forest(mode = "classification", trees = 1000) %>%
  set_engine("ranger")

Fitting the Forest

# NOTE: this chunk originally failed with an error. `personality` is read in
# as a character column (see the str() output), while parsnip expects a factor
# outcome for classification models -- the likely cause of that failure.
# Convert the outcome to a factor before fitting.
# The "ranger" engine set on the model spec needs the ranger package installed.
random_forest_fit <-
  random_forest_model %>%
  fit(
    personality ~ gender + species + birth_month,
    data = villagers_train %>% mutate(personality = factor(personality))
  )

# Print the fitted model summary
random_forest_fit

Visualizing the Forest

# The older, caret-based route: train() fits a random forest (method = "rf")
# and, with its default settings, tunes it by resampling the training set.
model_rf <- caret::train(
  personality ~ gender + species + birth_month,
  data = villagers_train,
  method = "rf"
)
# Print the resampling results and the selected tuning value
model_rf
## Random Forest 
## 
## 294 samples
##   3 predictor
##   8 classes: 'cranky', 'jock', 'lazy', 'normal', 'peppy', 'smug', 'snooty', 'uchi' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 294, 294, 294, 294, 294, 294, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.2550533  0.1357101
##   24    0.2889225  0.1759381
##   46    0.2892118  0.1774835
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 46.
# Same formula and training data, but a single CART tree (method = "rpart")
# instead of a forest.
model_rpart <- caret::train(
  personality ~ gender + species + birth_month,
  data = villagers_train,
  method = "rpart"
)
# Print the resampling results and the selected tuning value
model_rpart
## CART 
## 
## 294 samples
##   3 predictor
##   8 classes: 'cranky', 'jock', 'lazy', 'normal', 'peppy', 'smug', 'snooty', 'uchi' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 294, 294, 294, 294, 294, 294, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa     
##   0.02024291  0.2975188  0.18582216
##   0.02429150  0.2876263  0.17374970
##   0.19028340  0.2017836  0.07541971
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.02024291.
#source:  https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph

library("ggraph")
library("igraph")

#' Draw one tree of a random forest as a dendrogram
#'
#' @param final_model Fitted forest handed to `randomForest::getTree()`
#'   (this file passes `model_rf$finalModel`).
#' @param tree_num Index of the tree to extract (passed on as `k`).
#' @return The value of `print()` applied to the ggraph plot; the function is
#'   called for its side effect of drawing the tree.
tree_func <- function(final_model,
                      tree_num) {

  # Extract the requested tree with labelled split variables and predictions,
  # and turn the node ids (row names) into a regular column.
  nodes <- randomForest::getTree(final_model, k = tree_num, labelVar = TRUE)
  nodes <- tibble::rownames_to_column(nodes)
  # Rows that carry a prediction are leaves: blank their split point so the
  # placeholder 0s are not drawn.
  nodes <- mutate(
    nodes,
    `split point` = ifelse(is.na(prediction), `split point`, NA)
  )

  # Edge list: every node points to its left and to its right daughter.
  edges <- data.frame(
    from = rep(nodes$rowname, 2),
    to = c(nodes$`left daughter`, nodes$`right daughter`)
  )

  # Build the graph, then drop vertex "0", which should not be plotted.
  tree_graph <- delete_vertices(graph_from_data_frame(edges), "0")

  # Vertex attributes used as text in the plot.
  V(tree_graph)$node_label <- gsub("_", " ", as.character(nodes$`split var`))
  V(tree_graph)$leaf_label <- as.character(nodes$prediction)
  V(tree_graph)$split <- as.character(round(nodes$`split point`, digits = 2))

  # Remove grid, border, axes and axis titles so only the tree is visible.
  bare_theme <- theme(
    panel.grid.minor = element_blank(),
    panel.grid.major = element_blank(),
    panel.background = element_blank(),
    plot.background = element_rect(fill = "white"),
    panel.border = element_blank(),
    axis.line = element_blank(),
    axis.text.x = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks = element_blank(),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    plot.title = element_text(size = 18)
  )

  # Dendrogram layout: split variable as text, split point as a white label,
  # predicted class as a coloured label on the leaves.
  tree_plot <- ggraph(tree_graph, "dendrogram") +
    theme_bw() +
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
    geom_node_label(
      aes(label = split),
      vjust = 2.5, na.rm = TRUE, fill = "white"
    ) +
    geom_node_label(
      aes(label = leaf_label, fill = leaf_label),
      na.rm = TRUE, repel = TRUE, colour = "white",
      fontface = "bold", show.legend = FALSE
    ) +
    bare_theme

  print(tree_plot)
}

# Plot the smallest tree in the forest (lowest `ndbigtree` value) so the
# figure stays readable. which.min() always returns one index, even when
# several trees tie for the minimum; which(x == min(x)) would return all of
# them, and getTree() expects a single tree index.
tree_num <- which.min(model_rf$finalModel$forest$ndbigtree)

tree_func(final_model = model_rf$finalModel, tree_num)

Predictions

# Predict the personality of every held-out villager.
predictions <- predict(model_rf, newdata = villagers_test)
# Compare predictions with the true labels. The reference factor is built on
# the same level set as the predictions, so the comparison still works if some
# personality type happens to be absent from the test rows.
confusionMatrix(
  data = predictions,
  reference = factor(villagers_test$personality, levels = levels(predictions))
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction cranky jock lazy normal peppy smug snooty uchi
##     cranky      5    1    0      0     0    1      0    0
##     jock        1   10    2      0     0    2      0    0
##     lazy        0    5   13      0     0    2      0    0
##     normal      0    0    0     12     1    0      1    2
##     peppy       0    0    0      2     9    0      3    0
##     smug        0    1    0      0     0    3      0    0
##     snooty      0    0    0      3     0    0     13    2
##     uchi        0    0    0      0     0    0      0    3
## 
## Overall Statistics
##                                           
##                Accuracy : 0.701           
##                  95% CI : (0.5996, 0.7898)
##     No Information Rate : 0.1753          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.6502          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: cranky Class: jock Class: lazy Class: normal
## Sensitivity                0.83333      0.5882      0.8667        0.7059
## Specificity                0.97802      0.9375      0.9146        0.9500
## Pos Pred Value             0.71429      0.6667      0.6500        0.7500
## Neg Pred Value             0.98889      0.9146      0.9740        0.9383
## Prevalence                 0.06186      0.1753      0.1546        0.1753
## Detection Rate             0.05155      0.1031      0.1340        0.1237
## Detection Prevalence       0.07216      0.1546      0.2062        0.1649
## Balanced Accuracy          0.90568      0.7629      0.8907        0.8279
##                      Class: peppy Class: smug Class: snooty Class: uchi
## Sensitivity               0.90000     0.37500        0.7647     0.42857
## Specificity               0.94253     0.98876        0.9375     1.00000
## Pos Pred Value            0.64286     0.75000        0.7222     1.00000
## Neg Pred Value            0.98795     0.94624        0.9494     0.95745
## Prevalence                0.10309     0.08247        0.1753     0.07216
## Detection Rate            0.09278     0.03093        0.1340     0.03093
## Detection Prevalence      0.14433     0.04124        0.1856     0.03093
## Balanced Accuracy         0.92126     0.68188        0.8511     0.71429

Related