Augment data with predicted values

# S3 method for nn_adjust
augment(x, new_data, ...)

Arguments

x

An object of class `nn_adjust`, as returned by `nn_adjust()`.

new_data

A data frame with the original predictors in their original format.

...

Not currently used.

Value

The data being predicted, with an additional column `.pred` containing the adjusted predictions. If `new_data` contains the original outcome column, a `.resid` column is also added.

Examples

# example code

if (rlang::is_installed(c("ggplot2", "parsnip", "rpart", "MASS"))) {

  library(workflows)
  library(dplyr)
  library(parsnip)
  library(ggplot2)

  # ------------------------------------------------------------------------------
  # Use the 1D motorcycle helmet data as an example

  data(mcycle, package = "MASS")

  # Use every fifth data point as a test point.
  # `in_test` is a logical mask, so it must be negated with `!`, not `-`:
  # `-in_test` would become a vector of 0/-1 indices that drops only row 1,
  # leaving every test row inside the training set.
  in_test <- seq_len(nrow(mcycle)) %% 5 == 0
  cycl_train <- mcycle[!in_test, ]
  cycl_test  <- mcycle[in_test, ]

  # ------------------------------------------------------------------------------
  # Fit a decision tree

  cart_spec <- decision_tree() %>% set_mode("regression")

  cart_fit <-
    workflow(accel ~ times, cart_spec) %>%
    fit(data = cycl_train)

  # Build the nearest-neighbor adjustment from the training data
  adj_obj <- nn_adjust(cart_fit, cycl_train)

  # Raw predictions plus data:
  augment(cart_fit, head(cycl_test))

  # Adjusted predictions:
  predict(adj_obj, head(cycl_test), neighbors = 10)

  # Add the data too
  augment(adj_obj, head(cycl_test), neighbors = 10)

}
#> 
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#> 
#>     filter, lag
#> The following objects are masked from ‘package:base’:
#> 
#>     intersect, setdiff, setequal, union
#> # A tibble: 6 × 4
#>    .pred accel  .resid times
#>    <dbl> <dbl>   <dbl> <dbl>
#> 1  -2.00  -2.7 -0.698    4  
#> 2  -2.68  -2.7 -0.0193   8.2
#> 3  -2.85  -5.4 -2.55    10.2
#> 4  -7.98  -2.7  5.28    13.6
#> 5 -11.2   -9.3  1.86    14.6
#> 6 -36.7  -32.1  4.62    15.4