Applies cv_fun to the folds using future_lapply (or lapply when use_future = FALSE) and combines the results across folds using combine_results.

cross_validate(cv_fun, folds, ..., use_future = TRUE, .combine = TRUE,
  .combine_control = list(), .old_results = NULL)

Arguments

cv_fun

a function that takes a 'fold' as its first argument and returns a list of results from that fold. NOTE: input functions must not have an argument named 'X', because lapply and future.apply::future_lapply use 'X' as the name of their own first argument.

folds

a list of folds to loop over generated using make_folds.

...

other arguments passed to cv_fun.

use_future

(logical) - whether to run the main loop of cross-validation with future_lapply (TRUE) or with lapply (FALSE).

.combine

(logical) - should combine_results be called.

.combine_control

(list) - arguments to combine_results.

.old_results

(list) - the result returned by a previous call to this function. It will be combined with the current results, which is useful for adding additional CV folds to an existing results object.

Value

A list of results, combined across folds (when .combine = TRUE).

Examples

###############################################################################
# This example explains how to use the cross_validate function naively.
###############################################################################
data(mtcars)

# resubstitution MSE
r <- lm(mpg ~ ., data = mtcars)
mean(resid(r)^2)
#> [1] 4.609201

# function to calculate cross-validated squared error
cv_lm <- function(fold, data, reg_form) {
  # the outcome is the first whitespace-separated token of the formula string;
  # locate its column so validation residuals can be computed
  formula_tokens <- unlist(stringr::str_split(reg_form, " "))
  outcome_name <- as.character(formula_tokens[1])
  outcome_col <- as.numeric(which(colnames(data) == outcome_name))

  # training()/validation() pick up `fold` from this function's arguments
  fit_data <- training(data)
  holdout_data <- validation(data)

  # fit on the training rows, then predict the held-out rows
  fit <- lm(as.formula(reg_form), data = fit_data)
  holdout_pred <- predict(fit, newdata = holdout_data)

  # per-fold output: coefficients as a one-row data.frame plus the
  # squared prediction errors on the validation set
  list(
    coef = data.frame(t(coef(fit))),
    SE = ((holdout_pred - holdout_data[, outcome_col])^2)
  )
}

# replicate the resubstitution estimate
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(fold = resub, data = mtcars, reg_form = "mpg ~ .")
mean(resub_results$SE)
#> [1] 4.609201

# cross-validated estimate
folds <- make_folds(mtcars)
cv_results <- cross_validate(
  cv_fun = cv_lm, folds = folds, data = mtcars, reg_form = "mpg ~ ."
)
mean(cv_results$SE)
#> [1] 14.59195
###############################################################################
# This example explains how to use the cross_validate function with
# parallelization using the framework of the future package.
###############################################################################
suppressMessages(library(data.table))
library(future)
#> 
#> Attaching package: ‘future’
#> The following object is masked from ‘package:origami’:
#> 
#>     future_lapply

data(mtcars)
set.seed(1)

# make a lot of folds
folds <- make_folds(mtcars, fold_fun = folds_bootstrap, V = 1000)

# function to calculate cross-validated squared error for linear regression
cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula.
  # stringr is not attached in this example, so str_split() must be called
  # with its namespace; a bare str_split() makes every fold fail with an error
  out_var <- as.character(unlist(stringr::str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  list(
    coef = data.frame(t(coef(mod))),
    SE = ((preds - valid_data[, out_var_ind])^2)
  )
}

# time the sequential run
plan(sequential)
time_seq <- system.time({
  results_seq <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars, reg_form = "mpg ~ ."
  )
})

# time the same run with forked (multicore) workers
plan(multicore)
time_mc <- system.time({
  results_mc <- cross_validate(
    cv_fun = cv_lm, folds = folds, data = mtcars, reg_form = "mpg ~ ."
  )
})

if (availableCores() > 1) {
  time_mc["elapsed"] < 1.2 * time_seq["elapsed"]
}

# restore the sequential plan so code run after this example is unaffected
plan(sequential)