Given a set of training data, this function builds the Linear Discriminant
Analysis (LDA) classifier, where the distributions of each class are assumed
to be multivariate normal and share a common covariance matrix. When the
pooled sample covariance matrix is singular (for example, when the number of
features exceeds the number of training observations), the linear discriminant
function cannot be computed. This function replaces the inverse of the pooled
sample covariance matrix with an estimator proposed by Schafer and Strimmer
(2005). The estimator is calculated via
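The replacement of the inverse pooled covariance matrix can be illustrated in base R. This is a minimal sketch under stated assumptions, not the package's implementation. It centers each class at its own mean and stacks the centered rows. It then shrinks the sample correlations toward zero, using the analytic shrinkage intensity that Schafer and Strimmer (2005) give for their diagonal, unequal-variance target, applied to standardized data. The variances are left unshrunk. The function name `shrink_inv_cov_pooled()` is hypothetical, and the pooled estimate here uses n - 1 rather than n - K degrees of freedom.

```r
# Hypothetical helper: a base-R sketch of a Schafer-Strimmer-style shrinkage
# estimate of the inverse pooled covariance matrix.
# Assumptions: correlations are shrunk toward 0 (diagonal, unequal-variance
# target); variances are NOT shrunk; n - 1 degrees of freedom are used.
shrink_inv_cov_pooled <- function(x, y) {
  x <- as.matrix(x)
  y <- as.factor(y)
  # Center each class at its own mean so the stacked rows share a common mean
  for (k in levels(y)) {
    idx <- which(y == k)
    class_rows <- x[idx, , drop = FALSE]
    x[idx, ] <- sweep(class_rows, 2, colMeans(class_rows))
  }
  n <- nrow(x)
  xs <- scale(x)                 # standardize each feature
  cp <- crossprod(xs)
  r <- cp / (n - 1)              # sample correlation matrix
  wbar <- cp / n                 # mean of the products w_kij = xs_ki * xs_kj
  # Estimated Var(r_ij) = n / (n - 1)^3 * sum_k (w_kij - wbar_ij)^2,
  # using sum_k (w_kij - wbar_ij)^2 = sum_k w_kij^2 - n * wbar_ij^2
  var_r <- n / (n - 1)^3 * (crossprod(xs^2) - n * wbar^2)
  off <- row(r) != col(r)
  denom <- sum(r[off]^2)
  # Shrinkage intensity, truncated to [0, 1]
  lambda <- if (denom > 0) min(1, max(0, sum(var_r[off]) / denom)) else 1
  r_shrunk <- (1 - lambda) * r
  diag(r_shrunk) <- 1
  sds <- apply(x, 2, sd)
  # Back to the covariance scale, then invert
  cov_inv <- solve(r_shrunk * outer(sds, sds))
  attr(cov_inv, "lambda") <- lambda
  cov_inv
}
```

Because the shrunken correlation matrix is a convex combination of the sample correlation matrix and the identity, it is positive definite whenever the intensity is positive. As a result, `solve()` succeeds even when there are more features than observations. At an intensity of 1 the sketch reduces to a diagonal matrix of inverse variances.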
```r
lda_schafer(x, ...)

# S3 method for default
lda_schafer(x, y, prior = NULL, ...)

# S3 method for formula
lda_schafer(formula, data, prior = NULL, ...)

# S3 method for lda_schafer
predict(object, newdata, type = c("class", "prob", "score"), ...)
```
`x`: Matrix or data frame containing the training data. The rows are the sample observations, and the columns are the features. Only complete data are retained.
`...`: Options passed to
`y`: Vector of class labels for each training observation. Only complete data are retained.
`prior`: Vector with prior probabilities for each class. If NULL (default), the prior probabilities are estimated by the sample proportions of the classes. See details.
`formula`: A formula of the form
`data`: Data frame from which variables specified in
`object`: Fitted model object.
`newdata`: Matrix or data frame of observations to predict. Each row corresponds to a new observation.
`type`: Prediction type: either `"class"`, `"prob"`, or `"score"`.
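The three prediction types can be related through the usual linear discriminant score $\hat\delta_k(x) = x^\top \hat\Sigma^{-1}\bar{x}_k - \tfrac{1}{2}\bar{x}_k^\top \hat\Sigma^{-1}\bar{x}_k + \log \pi_k$, where $\hat\Sigma^{-1}$ is the inverse covariance estimate, $\bar{x}_k$ the class mean and $\pi_k$ the prior. Posterior probabilities follow from a softmax of the scores, because the quadratic term $x^\top \hat\Sigma^{-1} x$ is shared by every class and cancels, and the predicted class is the one with the largest score. This is a sketch under assumptions: the helper `lda_predict_sketch()` and its arguments are hypothetical, and whether `predict()` reports exactly these scores for `type = "score"` is not confirmed here.

```r
# Hypothetical helper: maps class means, an inverse covariance estimate and
# priors to the three prediction types. The score is the textbook linear
# discriminant; that predict() uses this exact definition is an assumption.
lda_predict_sketch <- function(newdata, means, cov_inv, prior,
                               type = c("class", "prob", "score")) {
  type <- match.arg(type)
  newdata <- as.matrix(newdata)
  # scores[i, k] = x_i' S^-1 m_k - 0.5 * m_k' S^-1 m_k + log(prior_k)
  linear <- newdata %*% cov_inv %*% t(means)
  const <- -0.5 * rowSums((means %*% cov_inv) * means) + log(prior)
  scores <- sweep(linear, 2, const, "+")
  colnames(scores) <- rownames(means)
  if (type == "score") return(scores)
  # Softmax over classes; subtracting each row's max avoids overflow in exp()
  expd <- exp(scores - apply(scores, 1, max))
  probs <- expd / rowSums(expd)
  if (type == "prob") return(probs)
  # Predicted class: largest score (first class wins ties)
  factor(rownames(means)[max.col(scores, ties.method = "first")],
         levels = rownames(means))
}
```

For example, with identity covariance, means `(0, 0)` and `(2, 0)`, and equal priors, the point `(1, 0)` lies on the decision boundary and receives probability 0.5 for each class.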
The function returns an `lda_schafer` object that contains the trained classifier.
The matrix of training observations is given in `x`. The rows of `x` contain the sample observations, and the columns contain the features for each training observation. The vector of class labels given in `y` is coerced to a `factor`. The length of `y` should match the number of rows in `x`. An error is thrown if a given class has fewer than 2 observations, because the variance of each feature within a class cannot be estimated from fewer than 2 observations.
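The input rules above can be expressed as a short validation step. This is a sketch, not the package's code: the helper `prepare_training_data()` is hypothetical, it assumes numeric features, and it drops incomplete rows before counting observations per class, which is one plausible reading of "only complete data are retained".

```r
# Hypothetical helper: applies the documented input rules before fitting.
# Assumes x holds numeric features only.
prepare_training_data <- function(x, y) {
  x <- as.matrix(x)
  y <- as.factor(y)                       # class labels are coerced to a factor
  if (length(y) != nrow(x)) {
    stop("length(y) must equal nrow(x)")
  }
  # Retain only complete observations (no NA in the features or the label)
  keep <- stats::complete.cases(x) & !is.na(y)
  x <- x[keep, , drop = FALSE]
  y <- droplevels(y[keep])
  # Within-class variances need at least 2 observations per class
  counts <- table(y)
  if (any(counts < 2)) {
    stop("each class needs at least 2 observations; too few in: ",
         paste(names(counts)[counts < 2], collapse = ", "))
  }
  list(x = x, y = y)
}
```

Calling `droplevels()` means a class whose rows are all incomplete disappears instead of triggering the error. That is a design choice of this sketch.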
The vector `prior` contains the a priori class membership probabilities for
each class. If `prior` is NULL (default), the class membership
probabilities are estimated as the sample proportion of observations belonging
to each class. Otherwise, `prior` should be a vector with the same length
as the number of classes in `y`. The prior probabilities should be
nonnegative and sum to one.
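The prior rules can be sketched with a hypothetical helper `resolve_prior()`. It returns priors in the order of `levels(factor(y))`, which is an assumption about how entries of `prior` are matched to classes. It checks the sum-to-one condition with a small tolerance, because floating-point sums of probabilities may not be exactly 1.

```r
# Hypothetical helper: resolves the class priors following the rules above.
# Priors are ordered like levels(factor(y)) (an assumption).
resolve_prior <- function(y, prior = NULL) {
  y <- as.factor(y)
  if (is.null(prior)) {
    # Default: sample proportion of training observations in each class
    return(as.vector(table(y)) / length(y))
  }
  if (length(prior) != nlevels(y)) {
    stop("prior must have one entry per class in y")
  }
  if (any(prior < 0)) {
    stop("prior probabilities must be nonnegative")
  }
  # Tolerant comparison instead of exact equality with 1
  if (abs(sum(prior) - 1) > sqrt(.Machine$double.eps)) {
    stop("prior probabilities must sum to one")
  }
  prior
}
```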
Schafer, J., and Strimmer, K. (2005). "A shrinkage approach to large-scale covariance estimation and implications for functional genomics," Statist. Appl. Genet. Mol. Biol. 4, 32.
```r
library(sparsediscrim)  # package providing lda_schafer()
library(modeldata)
data(penguins)

# Hold out every 20th row for prediction
pred_rows <- seq(1, 344, by = 20)
penguins <- penguins[, c("species", "body_mass_g", "flipper_length_mm")]

# Formula interface
lda_schafer_out <- lda_schafer(species ~ ., data = penguins[-pred_rows, ])
predicted <- predict(lda_schafer_out, penguins[pred_rows, -1], type = "class")

# Default interface: features and labels passed separately
lda_schafer_out2 <- lda_schafer(x = penguins[-pred_rows, -1], y = penguins$species[-pred_rows])
predicted2 <- predict(lda_schafer_out2, penguins[pred_rows, -1], type = "class")

all.equal(predicted, predicted2)
#> [1] TRUE
```