Three new things in tidymodels

Max Kuhn (Posit PBC)
Hannah Frick (Posit PBC)
Emil Hvitfeldt (Posit PBC)
Qiushi Yan (Vanderbilt University)

Censored regression in tidymodels

Specification of survival models

  • New model type proportional_hazards()
  • New mode "censored regression"
  • New engines for
    • Parametric models
    • Semi-parametric models
    • Tree-based models
  • Formula interface for all models, including stratification (sketched below)
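
For example, a parametric model specification and fit might look like the following. This is a minimal sketch using survival::lung rather than the movie data below; the strata() term follows the survival package's stratification convention.

library(censored)  # also attaches parsnip and survival

weibull_spec <- survival_reg(dist = "weibull") %>%
  set_engine("survival") %>%
  set_mode("censored regression")

# strata(sex) estimates a separate scale parameter per stratum
weibull_fit <- weibull_spec %>%
  fit(Surv(time, status) ~ age + ph.ecog + strata(sex), data = survival::lung)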

Make a Surv() object at the start

library(tidymodels)
library(censored)

data("time_to_million")

time_to_million <- time_to_million %>%
  mutate(surv = Surv(time, event), .keep = "unused") %>% 
  select(-title, -released)
glimpse(time_to_million)
#> Rows: 551
#> Columns: 46
#> $ released_theaters <dbl> 3427, 102, 3018, 349, 2471, 838, 899, 115, 3615, 523, 384, 661, 3589, 158, 3102, 3708, 2139,…
#> $ distributor       <fct> paramount_pi, sony_pictures, warner_bros, lionsgate, entertainmen, focus_features, samuel_go…
#> $ year              <dbl> 2016, 2018, 2018, 2017, 2017, 2018, 2015, 2016, 2017, 2016, 2015, 2017, 2018, 2015, 2018, 20…
#> $ rated             <fct> pg_13, pg, r, pg_13, pg_13, pg_13, pg_13, not_rated, r, r, pg_13, pg, pg_13, pg_13, r, r, r,…
#> $ runtime           <dbl> 103, 102, 130, 106, 89, 107, 121, 152, 104, 98, 99, 104, 90, 97, 117, 136, 104, 109, 91, 120…
#> $ action            <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
#> $ drama             <dbl> 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,…
#> $ horror            <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,…
#> $ mystery           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,…
#> $ sci_fi            <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,…
#> $ thriller          <dbl> 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0,…
#> $ comedy            <dbl> 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,…
#> $ history           <dbl> 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ war               <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
#> $ family            <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
#> $ adventure         <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
#> $ romance           <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,…
#> $ crime             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,…
#> $ music             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,…
#> $ biography         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,…
#> $ fantasy           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
#> $ musical           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ animation         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ documentary       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ sport             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ western           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ short             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ english           <dbl> 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,…
#> $ hindi             <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ russian           <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
#> $ spanish           <dbl> 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ german            <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
#> $ french            <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
#> $ japanese          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ italian           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,…
#> $ mandarin          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ usa               <dbl> 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,…
#> $ india             <dbl> 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ mexico            <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ uk                <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,…
#> $ france            <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ china             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,…
#> $ canada            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ japan             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ australia         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
#> $ surv              <Surv> <Surv[40 x 2]>

Regularized Cox model

cox_spec <- proportional_hazards(penalty = 0.01) %>%
  set_engine("glmnet")

cox_fit <- 
  cox_spec %>% fit(surv ~ ., data = time_to_million)
tidy(cox_fit)
#> # A tibble: 3,873 × 5
#>    term               step estimate lambda dev.ratio
#>    <chr>             <dbl>    <dbl>  <dbl>     <dbl>
#>  1 released_theaters     2 0.000100  0.747    0.0194
#>  2 released_theaters     3 0.000193  0.680    0.0358
#>  3 released_theaters     4 0.000280  0.620    0.0497
#>  4 released_theaters     5 0.000361  0.565    0.0616
#>  5 released_theaters     6 0.000437  0.515    0.0718
#>  6 released_theaters     7 0.000509  0.469    0.0804
#>  7 released_theaters     8 0.000576  0.427    0.0879
#>  8 released_theaters     9 0.000639  0.389    0.0942
#>  9 released_theaters    10 0.000698  0.355    0.0996
#> 10 released_theaters    11 0.000753  0.323    0.104 
#> # … with 3,863 more rows
autoplot(cox_fit, best_penalty = 0.01)

Available in censored

All for the mode "censored regression".

  model                    engine
  bag_tree()               rpart
  boost_tree()             mboost
  decision_tree()          rpart
  decision_tree()          partykit
  proportional_hazards()   survival
  proportional_hazards()   glmnet
  rand_forest()            partykit
  survival_reg()           survival
  survival_reg()           flexsurv
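
Any of these drops into the same interface. As a sketch, a conditional random forest via the partykit engine (the number of trees here is illustrative):

rf_spec <- rand_forest(trees = 500) %>%
  set_engine("partykit") %>%
  set_mode("censored regression")

rf_fit <- rf_spec %>% fit(surv ~ ., data = time_to_million)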

Prediction with survival models

  • For all models:
    • Survival time via type = "time"
    • Survival probability via type = "survival"
  • Depending on the engine: types "hazard", "quantile", and "linear_pred" (example below)

All adhere to the tidymodels prediction guarantee:

  • The predictions are always inside a tibble.
  • The column names and types are unsurprising and predictable.
  • The number of rows in new_data and the output are the same.
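
For instance, the regularized Cox model fit above supports the linear predictor; a small sketch (the three rows are arbitrary):

# An engine-dependent type: the Cox model's linear predictor
predict(cox_fit, time_to_million[1:3, ], type = "linear_pred")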

Predict survival time

# Pull out some specific movies:
selected_rows <- c(40, 372, 126)

# Evaluate them at these weeks:
at_weeks <- seq(1, 50, by = 1 / 2)

predict(cox_fit, time_to_million[selected_rows, ], type = "time")
#> # A tibble: 3 × 1
#>   .pred_time
#>        <dbl>
#> 1     0.0534
#> 2   110.    
#> 3    28.9

Predict survival probability


pred <- predict(cox_fit, time_to_million[selected_rows, ], type = "survival", time = at_weeks)
pred
#> # A tibble: 3 × 1
#>   .pred            
#>   <list>           
#> 1 <tibble [99 × 2]>
#> 2 <tibble [99 × 2]>
#> 3 <tibble [99 × 2]>
pred$.pred[[1]]
#> # A tibble: 99 × 2
#>    .time .pred_survival
#>    <dbl>          <dbl>
#>  1   1        1.07e- 59
#>  2   1.5      8.73e-101
#>  3   2        7.69e-175
#>  4   2.5      3.81e-220
#>  5   3        3.32e-272
#>  6   3.5      1.37e-317
#>  7   4        0        
#>  8   4.5      0        
#>  9   5        0        
#> 10   5.5      0        
#> # … with 89 more rows

Approximating the survival curve

pred %>%
  mutate(id = factor(selected_rows)) %>%
  tidyr::unnest(cols = .pred)
#> # A tibble: 297 × 3
#>    .time .pred_survival id   
#>    <dbl>          <dbl> <fct>
#>  1   1        1.07e- 59 40   
#>  2   1.5      8.73e-101 40   
#>  3   2        7.69e-175 40   
#>  4   2.5      3.81e-220 40   
#>  5   3        3.32e-272 40   
#>  6   3.5      1.37e-317 40   
#>  7   4        0         40   
#>  8   4.5      0         40   
#>  9   5        0         40   
#> 10   5.5      0         40   
#> # … with 287 more rows

Approximating the survival curve

pred %>%
  mutate(id = factor(selected_rows)) %>%
  tidyr::unnest(cols = .pred) %>%
  ggplot(
    aes(x = .time, y = .pred_survival,
        col = id)
  ) +
  geom_step() +
  labs(x = "Time", y = "Pr[survival]") +
  scale_y_continuous(limits = c(0,1)) +
  theme(legend.position = "top")

Unblinded IDs

  • 40: Ant-Man
  • 126: Elvis & Nixon
  • 372: Sorry to Bother You

What’s next?

So far, we’ve done a lot with individual models.

We have more to add:

  • Performance metrics
    • time-dependent ROC curves
    • Brier scores (time-dependent and integrated)
    • and so on.
  • Tighter tidymodels integration for model tuning and resampling

H2O Integration

Changes

  • New parsnip "h2o" engine for many models.

    • See the agua documentation for a complete list.

  • Interactions between H2O and tidymodels are kept to a minimum, so there are few expensive data transfer penalties.

  • We’ve added some wrappers around H2O functions that make them easier to use within tidymodels.

Fit models on the H2O server

library(agua)
data(ad_data)

set.seed(1)
ad_split <- initial_split(ad_data, strata = Class)
ad_train <- training(ad_split)
ad_test  <- testing(ad_split)

# start h2o server
h2o_start()

lr_spec <- logistic_reg() %>%
  set_engine("h2o")

lr_fit <- lr_spec %>% fit(Class ~ .,  data = ad_train)
lr_fit
#> parsnip model object
#> 
#> Model Details:
#> ==============
#> 
#> H2OBinomialModel: glm
#> Model ID:  GLM_model_R_1667314396292_36193 
#> GLM Model: summary
#>     family  link                               regularization number_of_predictors_total number_of_active_predictors
#> 1 binomial logit Elastic Net (alpha = 0.5, lambda = 0.03768 )                        134                          35
#>   number_of_iterations    training_frame
#> 1                    6 object_oqorahuxjg
#> 
#> Coefficients: glm coefficients
#>                             names coefficients standardized_coefficients
#> 1                       Intercept   -28.586754                 -1.433734
#> 2  ACE_CD143_Angiotensin_Converti     0.000000                  0.000000
#> 3 ACTH_Adrenocorticotropic_Hormon     0.153691                  0.040988
#> 4                             AXL     0.000000                  0.000000
#> 5                     Adiponectin     0.000000                  0.000000
#> 
#> ---
#>            names coefficients standardized_coefficients
#> 130         male     0.000000                  0.000000
#> 131 GenotypeE2E3     0.000000                  0.000000
#> 132 GenotypeE2E4     0.000000                  0.000000
#> 133 GenotypeE3E3     0.000000                  0.000000
#> 134 GenotypeE3E4     0.000000                  0.000000
#> 135 GenotypeE4E4     0.353067                  0.072697
#> 
#> H2OBinomialMetrics: glm
#> ** Reported on training data. **
#> 
#> MSE:  0.07387355
#> RMSE:  0.2717969
#> LogLoss:  0.2661139
#> Mean Per-Class Error:  0.1094004
#> AUC:  0.9583198
#> AUCPR:  0.9279918
#> Gini:  0.9166396
#> R^2:  0.6278653
#> Residual Deviance:  132.5247
#> AIC:  204.5247
#> 
#> Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
#>          Control Impaired    Error     Rate
#> Control      176        5 0.027624   =5/181
#> Impaired      13       55 0.191176   =13/68
#> Totals       189       60 0.072289  =18/249
#> 
#> Maximum Metrics: Maximum metrics at their respective thresholds
#>                         metric threshold      value idx
#> 1                       max f1  0.431653   0.859375  59
#> 2                       max f2  0.286331   0.859155  82
#> 3                 max f0point5  0.515650   0.910714  52
#> 4                 max accuracy  0.431653   0.927711  59
#> 5                max precision  0.901022   1.000000   0
#> 6                   max recall  0.094922   1.000000 159
#> 7              max specificity  0.901022   1.000000   0
#> 8             max absolute_mcc  0.431653   0.813859  59
#> 9   max min_per_class_accuracy  0.319595   0.882353  79
#> 10 max mean_per_class_accuracy  0.349492   0.893321  69
#> 11                     max tns  0.901022 181.000000   0
#> 12                     max fns  0.901022  67.000000   0
#> 13                     max fps  0.011287 181.000000 248
#> 14                     max tps  0.094922  68.000000 159
#> 15                     max tnr  0.901022   1.000000   0
#> 16                     max fnr  0.901022   0.985294   0
#> 17                     max fpr  0.011287   1.000000 248
#> 18                     max tpr  0.094922   1.000000 159
#> 
#> Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`

What’s different?

Compared with an ordinary tidymodels fit, only two lines change: h2o_start() launches the H2O server and set_engine("h2o") selects the engine. Everything else in the code above is unchanged.

Predict on the test set

predict(lr_fit, ad_test, type = "prob")
#> # A tibble: 84 × 2
#>    .pred_Control .pred_Impaired
#>            <dbl>          <dbl>
#>  1        0.932          0.0675
#>  2        0.821          0.179 
#>  3        0.973          0.0267
#>  4        0.941          0.0588
#>  5        0.383          0.617 
#>  6        0.836          0.164 
#>  7        0.0753         0.925 
#>  8        0.446          0.554 
#>  9        0.614          0.386 
#> 10        0.926          0.0744
#> # … with 74 more rows
library(vip)

# We can also use the native H2O object: 
lr_fit %>%
  extract_fit_engine() %>%
  vip()

Hyperparameter tuning

set.seed(2)
ad_folds <- vfold_cv(ad_train, v = 10, strata = Class)

lr_spec <- logistic_reg(penalty = tune()) %>%
  set_engine("h2o")

ad_rec <- recipe(Class ~ ., data = ad_train) %>%
  step_YeoJohnson(all_numeric_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lr_wflow <- workflow() %>%
  add_model(lr_spec) %>%
  add_recipe(ad_rec)

set.seed(3)
lr_res <- lr_wflow %>% tune_grid(resamples = ad_folds, grid = 10)

What’s different?

set.seed(2)
ad_folds <- vfold_cv(ad_train, v = 10, strata = Class)

lr_spec <- logistic_reg(penalty = tune()) %>%
  set_engine("h2o")

ad_rec <- recipe(Class ~ ., data = ad_train) %>%
  step_YeoJohnson(all_numeric_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lr_wflow <- workflow() %>%
  add_model(lr_spec) %>%
  add_recipe(ad_rec)

set.seed(3)
lr_res <- lr_wflow %>% tune_grid(resamples = ad_folds, grid = 10)

Parallel processing is easy to do via R and/or H2O tools; see the agua vignette on parallel processing.
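
As a minimal sketch of the R-side option, assuming tune's foreach-based parallel backend:

library(doParallel)

cl <- makePSOCKcluster(4)  # four local R workers
registerDoParallel(cl)

set.seed(3)
lr_res <- lr_wflow %>% tune_grid(resamples = ad_folds, grid = 10)

stopCluster(cl)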

Use parsnip’s new model type auto_ml()

auto_mod <- auto_ml() %>%
  set_engine("h2o", max_runtime_secs = 300) %>%
  set_mode("classification")

auto_fit <- auto_mod %>% fit(Class ~ ., data = ad_train)

Summarize model performance with get_leaderboard()

auto_fit %>%
  get_leaderboard() %>% 
  slice(1:5)
#> # A tibble: 5 × 7
#>   model_id                                                  auc logloss aucpr mean_per_class_error  rmse   mse
#>   <chr>                                                   <dbl>   <dbl> <dbl>                <dbl> <dbl> <dbl>
#> 1 StackedEnsemble_BestOfFamily_4_AutoML_5_20221101_160842 0.894   0.343 0.820                0.184 0.321 0.103
#> 2 DeepLearning_grid_1_AutoML_5_20221101_160842_model_26   0.884   1.13  0.807                0.182 0.374 0.140
#> 3 StackedEnsemble_BestOfFamily_6_AutoML_5_20221101_160842 0.883   0.355 0.819                0.177 0.323 0.104
#> 4 GBM_grid_1_AutoML_5_20221101_160842_model_9             0.882   0.389 0.759                0.180 0.352 0.124
#> 5 GBM_grid_1_AutoML_5_20221101_160842_model_21            0.880   0.376 0.772                0.177 0.340 0.116

Helper functions for visualizing member models, etc.

autoplot(auto_fit, type = "rank", metric = "roc_auc")

Tidy Clustering

Overview

  • tidymodels now supports clustering models with the addition of tidyclust (soon to be on CRAN)

  • Specify, fit, predict and evaluate clustering models

  • Learning tidyclust should be a breeze if you are already familiar with tidymodels

Specify clustering models

# remotes::install_github("EmilHvitfeldt/tidyclust")
library(tidyclust)

kmeans_spec <- k_means(num_clusters = 4) %>%
  set_engine("stats") %>%
  set_mode("partition")
kmeans_spec
#> K Means Cluster Specification (partition)
#> 
#> Main Arguments:
#>   num_clusters = 4
#> 
#> Computational engine: stats

Fitting clustering models

Notice how tidyclust works fluently with recipes and workflows (and note that we didn’t set an outcome in the recipe, since clustering is unsupervised).

# Make an unsupervised recipe with just the assay data
ad_rec <- ad_train %>% 
  select(-Class, -Genotype, -male) %>% 
  recipe(formula = ~ .) %>%
  step_YeoJohnson(all_numeric_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

A workflow object binds the pre-processor and the model together for one fitting interface.

kmeans_wflow <- workflow(ad_rec, kmeans_spec)
kmeans_fit <- fit(kmeans_wflow, data = ad_train)
kmeans_fit
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: k_means()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> 4 Recipe Steps
#> 
#> • step_YeoJohnson()
#> • step_dummy()
#> • step_zv()
#> • step_normalize()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> K-means clustering with 4 clusters of sizes 69, 66, 68, 46
#> 
#> Cluster means:
#>   ACE_CD143_Angiotensin_Converti ACTH_Adrenocorticotropic_Hormon        AXL Adiponectin Alpha_1_Antichymotrypsin
#> 1                      0.6836882                       0.1528054  0.6497034  -0.1241515               -0.1174433
#> 2                     -0.5493031                      -0.3274418 -0.5512176   0.4283056                0.4878523
#> 3                     -0.6224848                      -0.0762455 -0.6941297  -0.6646911               -0.9951585
#> 4                      0.6827932                       0.3533104  0.8424271   0.5542887                0.9473068
#>   Alpha_1_Antitrypsin Alpha_1_Microglobulin Alpha_2_Macroglobulin Angiopoietin_2_ANG_2 Angiotensinogen
#> 1          -0.2848140            -0.3098983             0.1847187            0.4919176      0.09626379
#> 2           0.5213351             0.5620315             0.1030472           -0.3798560     -0.35010590
#> 3          -0.6641991            -0.9399764            -0.9028277           -0.7969235      0.02288645
#> 4           0.6610779             1.0479847             0.9096863            0.9851953      0.32409803
#>   Apolipoprotein_A_IV Apolipoprotein_A1 Apolipoprotein_A2 Apolipoprotein_B Apolipoprotein_CI Apolipoprotein_CIII
#> 1          -0.3740499        -0.2580867        -0.4134386       -0.3347926        -0.2237942          -0.3815270
#> 2           0.7292428         0.6861236         0.7999672        0.6669666         0.6421302           0.6637349
#> 3          -0.7272732        -0.8551889        -0.8282364       -0.6583244        -0.8100441          -0.6887179
#> 4           0.5898694         0.6668842         0.6967284        0.5184121         0.6118307           0.6380799
#>   Apolipoprotein_D Apolipoprotein_E Apolipoprotein_H B_Lymphocyte_Chemoattractant_BL      BMP_6 Beta_2_Microglobulin
#> 1       -0.2172552        0.4510260       -0.5020307                     -0.04017782 -0.0536867            0.4860300
#> 2        0.3500162       -0.3286519        0.6368637                      0.31717109 -0.1234632           -0.2635367
#> 3       -0.7457368       -0.7601465       -0.6213225                     -0.76400837  0.3122190           -0.9266446
#> 4        0.9260791        0.9186999        0.7577616                      0.73459884 -0.2038683            1.0188953
#>   Betacellulin C_Reactive_Protein       CD40       CD5L  Calbindin  Calcitonin        CgA Clusterin_Apo_J Complement_3
#> 1 -0.003903972         -0.1518390  0.4049581 -0.4009738  0.4833949  0.18152964  0.6615849       0.3944845  -0.07213652
#> 2 -0.029453872          0.2118493 -0.2744039  0.4737180 -0.4407043  0.09218262 -0.6694636      -0.2018028   0.44782518
#> 3  0.267720151         -0.2942125 -0.7071246 -0.5831438 -0.4646311 -0.22569856 -0.4033751      -0.9562452  -0.97700689
#> 4 -0.347644361          0.3587236  0.8315875  0.7838171  0.5940685 -0.07091513  0.5644510       1.1113963   0.90994404
#>   Complement_Factor_H Connective_Tissue_Growth_Factor    Cortisol Creatine_Kinase_MB Cystatin_C      EGF_R    EN_RAGE
#> 1          -0.3425673                      -0.3963910  0.08981845         -0.1850983  0.6362843  0.5439616 -0.2025079
#> 2           0.3084882                       0.4891289  0.10248532          0.1608744 -0.4406589 -0.4002516  0.2100168
#> 3          -0.4311773                       0.2397263 -0.31233202          0.2024333 -0.7334024 -0.7162191 -0.1177400
#> 4           0.7086300                      -0.4615853  0.17993637         -0.2524216  0.7619833  0.8170904  0.1764840
#>        ENA_78   Eotaxin_3        FAS FSH_Follicle_Stimulation_Hormon   Fas_Ligand Fatty_Acid_Binding_Protein   Ferritin
#> 1 -0.03638214  0.17869977  0.2957729                    -0.007069386 -0.005420406                  0.5204176  0.3502047
#> 2  0.16887322 -0.04227573 -0.1235480                    -0.310849036 -0.092441566                 -0.3099420 -0.1565868
#> 3 -0.39149165 -0.77133156 -0.8582362                     0.011277465 -0.219933758                 -0.8841779 -0.7244796
#> 4  0.39100364  0.93283609  1.0023020                     0.439933835  0.465883628                  0.9711186  0.7703309
#>     Fetuin_A Fibrinogen    GRO_alpha Gamma_Interferon_induced_Monokin Glutathione_S_Transferase_alpha     HB_EGF
#> 1 -0.3919251 -0.3172756 -0.009909929                       -0.2097486                    -0.409480889  0.3261901
#> 2  0.7259104  0.5035059  0.056906133                        0.3075442                     0.466919295 -0.2177115
#> 3 -0.8853076 -0.6686406 -0.476696566                       -0.7073394                     0.006149705 -0.4633678
#> 4  0.8550796  0.7419172  0.637898843                        0.9189960                    -0.064797218  0.5080620
#>        HCC_4 Hepatocyte_Growth_Factor_HGF      I_309       ICAM_1     IGF_BP_2       IL_11       IL_13      IL_16
#> 1 -0.2752478                    0.4406145  0.4307925  0.068911718  0.140764658  0.08105526  0.02187974  0.1166058
#> 2  0.3365999                   -0.2745098 -0.2344247  0.001539618 -0.009467071 -0.23963765  0.11235664  0.1417507
#> 3 -0.5855593                   -0.8389536 -0.7610287 -0.757107051 -0.854498920  0.16160408 -0.45166305 -0.9542671
#> 4  0.7955334                    0.9731324  0.8151586  1.013625133  1.065608520 -0.01664793  0.47364884  1.0323655
#>         IL_17E  IL_1alpha         IL_3        IL_4       IL_5       IL_6 IL_6_Receptor       IL_7        IL_8
#> 1  0.177293327 -0.1571770 -0.005319973 -0.01412368  0.2655407 -0.2339475     0.4415787 -0.1362562 -0.06166829
#> 
#> ...
#> and 59 more lines.
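
Prediction uses the usual verbs. A brief sketch, assuming tidyclust's extract_cluster_assignment() helper and the ad_test split from earlier:

# Cluster memberships for the training data
extract_cluster_assignment(kmeans_fit)

# Assign new observations to the closest learned cluster
predict(kmeans_fit, new_data = ad_test)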

Tuning clustering parameters

Trying multiple hyperparameter values can be done with tune_cluster(), which has an interface similar to that of tune_grid().

kmeans_wflow <- kmeans_wflow %>%
  update_model(kmeans_spec %>% set_args(num_clusters = tune()))

grid <- tibble(num_clusters = 1:20)

kmeans_res <- tune_cluster(kmeans_wflow, resamples = ad_folds, grid = grid)
collect_metrics(kmeans_res)
#> # A tibble: 40 × 7
#>    num_clusters .metric .estimator   mean     n std_err .config              
#>           <int> <chr>   <chr>       <dbl> <int>   <dbl> <chr>                
#>  1            1 tot_sse standard   28557.    10    23.0 Preprocessor1_Model01
#>  2            1 tot_wss standard   28557.    10    23.0 Preprocessor1_Model01
#>  3            2 tot_sse standard   28557.    10    23.0 Preprocessor1_Model02
#>  4            2 tot_wss standard   23465.    10    52.0 Preprocessor1_Model02
#>  5            3 tot_sse standard   28557.    10    23.0 Preprocessor1_Model03
#>  6            3 tot_wss standard   21884.    10    45.6 Preprocessor1_Model03
#>  7            4 tot_sse standard   28557.    10    23.0 Preprocessor1_Model04
#>  8            4 tot_wss standard   20657.    10    52.8 Preprocessor1_Model04
#>  9            5 tot_sse standard   28557.    10    23.0 Preprocessor1_Model05
#> 10            5 tot_wss standard   20037.    10    47.7 Preprocessor1_Model05
#> # … with 30 more rows

Tuning clustering parameters

autoplot(kmeans_res, metric = "tot_wss")

Thanks!

We want to thank our outside collaborators: Qiushi Yan, Erin LeDell, Tomas Fryda, and Kelly Bodwin.

Thanks to the r-lib, tidymodels, and tidyverse teams at Posit PBC.

These slides are at https://topepo.github.io/2022-r-pharma.

Learn more about tidymodels at https://www.tidymodels.org.