Make adjusted predictions

# S3 method for nn_adjust
predict(object, new_data, neighbors = 3, eps = 1/2, cores = 1, ...)

Arguments

object

An object of class nn_adjust, as returned by nn_adjust().

new_data

A data frame with the original predictors in their original format.

neighbors

An integer for the number of neighbors. Zero indicates no adjustment.

eps

A small constant added to the distances to avoid division by zero.

cores

An integer for how many cores gower::gower_topn() should use.

...

Not currently used.

Value

A tibble with a numeric column .pred containing the adjusted predictions.

Examples

# example code

# Only run the example when the optional packages it relies on are available.
if (rlang::is_installed(c("ggplot2", "parsnip", "rpart", "MASS"))) {

  library(workflows)
  library(dplyr)
  library(parsnip)
  library(ggplot2)

  # ------------------------------------------------------------------------------
  # Use the 1D motorcycle helmet data as an example

  data(mcycle, package = "MASS")

  # Use every fifth data point as a test point. `in_test` is a logical mask
  # with one entry per row of `mcycle`.
  in_test <- seq_len(nrow(mcycle)) %% 5 == 0

  # Negate the mask with `!` to get the training rows. Using `-in_test` would
  # coerce the logical vector to 0/-1 and only drop the first row, leaving
  # every test point inside the training set.
  cycl_train <- mcycle[!in_test, ]
  cycl_test  <- mcycle[ in_test, ]

  # ------------------------------------------------------------------------------
  # Fit a decision tree

  cart_spec <- decision_tree() %>% set_mode("regression")

  cart_fit <-
    workflow(accel ~ times, cart_spec) %>%
    fit(data = cycl_train)

  # Build the adjustment object from the fitted workflow and its training data
  adj_obj <- nn_adjust(cart_fit, cycl_train)

  # Raw predictions plus data:
  augment(cart_fit, head(cycl_test))

  # Adjusted predictions:
  predict(adj_obj, head(cycl_test), neighbors = 10)

  # Add the data too
  augment(adj_obj, head(cycl_test), neighbors = 10)

}
#> # A tibble: 6 × 4
#>    .pred accel  .resid times
#>    <dbl> <dbl>   <dbl> <dbl>
#> 1  -2.00  -2.7 -0.698    4  
#> 2  -2.68  -2.7 -0.0193   8.2
#> 3  -2.85  -5.4 -2.55    10.2
#> 4  -7.98  -2.7  5.28    13.6
#> 5 -11.2   -9.3  1.86    14.6
#> 6 -36.7  -32.1  4.62    15.4