predict.nn_adjust.Rd
Make adjusted predictions
# S3 method for nn_adjust
predict(object, new_data, neighbors = 3, eps = 1/2, cores = 1, ...)
object: An object of class `nn_adjust`, as returned by `nn_adjust()`.
new_data: A data frame with the original predictors in their original format.
neighbors: An integer for the number of neighbors. Zero indicates no adjustment.
eps: A small constant added to the distances to avoid division by zero.
cores: An integer for the number of cores that `gower::gower_topn()` should use.
...: Not currently used.
A tibble with a numeric column `.pred` containing the adjusted predictions.
# example code
if (rlang::is_installed(c("ggplot2", "parsnip", "rpart", "MASS"))) {
  library(workflows)
  library(dplyr)
  library(parsnip)
  library(ggplot2)
  # ------------------------------------------------------------------------------
  # Use the 1D motorcycle helmet data as an example
  data(mcycle, package = "MASS")
  # Hold out every fifth data point as a test point.
  # `in_test` is a logical vector, so the training rows are selected with
  # `!in_test`. Negating it instead (`-in_test`) coerces it to 0/-1 indices,
  # which drops only the first row and leaks the test points into training.
  in_test <- seq_len(nrow(mcycle)) %% 5 == 0
  cycl_train <- mcycle[!in_test, ]
  cycl_test <- mcycle[in_test, ]
  # ------------------------------------------------------------------------------
  # Fit a decision tree
  cart_spec <- decision_tree() %>% set_mode("regression")
  cart_fit <-
    workflow(accel ~ times, cart_spec) %>%
    fit(data = cycl_train)
  # Build the nearest-neighbor adjustment object from the training set
  adj_obj <- nn_adjust(cart_fit, cycl_train)
  # Raw predictions plus data:
  augment(cart_fit, head(cycl_test))
  # Adjusted predictions:
  predict(adj_obj, head(cycl_test), neighbors = 10)
  # Add the data too
  augment(adj_obj, head(cycl_test), neighbors = 10)
}
#> # A tibble: 6 × 4
#> .pred accel .resid times
#> <dbl> <dbl> <dbl> <dbl>
#> 1 -2.00 -2.7 -0.698 4
#> 2 -2.68 -2.7 -0.0193 8.2
#> 3 -2.85 -5.4 -2.55 10.2
#> 4 -7.98 -2.7 5.28 13.6
#> 5 -11.2 -9.3 1.86 14.6
#> 6 -36.7 -32.1 4.62 15.4