data(two_class_dat, package = "modeldata")

# ------------------------------------------------------------------------------

set.seed(6735)
rep_folds <- rsample::vfold_cv(mtcars, v = 2, repeats = 2)

spline_rec <-
  recipes::recipe(mpg ~ ., data = mtcars) %>%
  recipes::step_ns(disp, deg_free = 3)

lin_mod <- parsnip::linear_reg() %>%
  parsnip::set_engine("lm")

lm_splines <- fit_resamples(
  lin_mod, spline_rec, rep_folds,
  control = control_resamples(save_pred = TRUE)
)

set.seed(93114)
rep_folds_class <- rsample::vfold_cv(two_class_dat, v = 2, repeats = 3)

svm_mod <- parsnip::svm_rbf(cost = tune("cost value")) %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("classification")

suppressMessages(
  svm_tune <- tune_bayes(
    svm_mod,
    Class ~ .,
    rep_folds_class,
    initial = 2,
    iter = 2,
    control = control_bayes(save_pred = TRUE)
  )
)

svm_tune_class <- svm_tune
svm_tune_class$.predictions <- purrr::map(
  svm_tune_class$.predictions,
  ~ .x %>% dplyr::select(-.pred_Class1, -.pred_Class2)
)
attr(svm_tune_class, "metrics") <- yardstick::metric_set(yardstick::kap)

svm_grd <- show_best(svm_tune, metric = "roc_auc") %>%
  dplyr::select(`cost value`)

# ------------------------------------------------------------------------------

test_that("`collect_predictions()` errors informatively if there is no `.predictions` column", {
  expect_snapshot(error = TRUE, {
    collect_predictions(lm_splines %>% dplyr::select(-.predictions))
  })
})

test_that("`collect_predictions()` errors informatively applied to unsupported class", {
  expect_snapshot(
    error = TRUE,
    collect_predictions(lm(mpg ~ disp, mtcars))
  )
})

# ------------------------------------------------------------------------------

test_that("`collect_predictions()`, un-averaged", {
  res <- collect_predictions(lm_splines)
  exp_res <-
    tidyr::unnest(
      lm_splines %>% dplyr::select(.predictions, starts_with("id")),
      cols = c(.predictions)
    ) %>%
    dplyr::select(all_of(names(res)))
  expect_equal(res, exp_res)

  res <- collect_predictions(svm_tune)
  exp_res <-
    tidyr::unnest(
      svm_tune %>% dplyr::select(.predictions, starts_with("id"), .iter),
      cols = c(.predictions)
    ) %>%
    dplyr::select(all_of(names(res)))
  expect_equal(res, exp_res)

  res_subset <- collect_predictions(svm_tune, parameters = svm_grd[1, ])
  exp_res_subset <- dplyr::filter(exp_res, `cost value` == svm_grd$`cost value`[[1]])
  expect_equal(res_subset, exp_res_subset)
})

# ------------------------------------------------------------------------------

test_that("bad filter grid", {
  expect_snapshot(
    error = TRUE,
    collect_predictions(svm_tune, parameters = tibble::tibble(wrong = "value"))
  )
  expect_true(
    nrow(collect_predictions(svm_tune, parameters = tibble::tibble(`cost value` = 1))) == 0
  )
})

# ------------------------------------------------------------------------------

test_that("regression predictions, averaged", {
  all_res <- collect_predictions(lm_splines)
  res <- collect_predictions(lm_splines, summarize = TRUE)
  expect_equal(nrow(res), nrow(mtcars))
  expect_false(dplyr::is_grouped_df(res))

  # pull out an example to test
  all_res_subset <- dplyr::filter(all_res, .row == 3)
  res_subset <- dplyr::filter(res, .row == 3)
  expect_equal(mean(all_res_subset$.pred), res_subset$.pred)
})

# ------------------------------------------------------------------------------

test_that("classification class predictions, averaged", {
  all_res <- collect_predictions(svm_tune_class)
  res <- collect_predictions(svm_tune_class, summarize = TRUE)
  expect_equal(nrow(res), nrow(two_class_dat) * nrow(svm_grd))
  expect_false(dplyr::is_grouped_df(res))
  expect_named(
    collect_predictions(svm_tune, summarize = TRUE),
    c(".pred_class",
".pred_Class1", ".pred_Class2", ".row", "cost value", "Class", ".config", ".iter") ) # pull out an example to test all_res_subset <- dplyr::filter(all_res, .row == 5 & `cost value` == svm_grd$`cost value`[1]) mode_val <- names(sort(table(all_res_subset$.pred_class)))[2] exp_val <- factor(mode_val, levels = levels(all_res_subset$Class)) res_subset <- dplyr::filter(res, .row == 5 & `cost value` == svm_grd$`cost value`[1]) expect_equal(exp_val, res_subset$.pred_class) }) # ------------------------------------------------------------------------------ test_that("classification class and prob predictions, averaged", { all_res <- collect_predictions(svm_tune) res <- collect_predictions(svm_tune, summarize = TRUE) expect_equal(nrow(res), nrow(two_class_dat) * nrow(svm_grd)) expect_false(dplyr::is_grouped_df(res)) # pull out an example to test all_res_subset <- dplyr::filter(all_res, .row == 5 & `cost value` == svm_grd$`cost value`[1]) .pred_Class1 <- mean(all_res_subset$.pred_Class1) .pred_Class2 <- 1 - .pred_Class1 .pred_class <- ifelse(.pred_Class2 > .pred_Class1, "Class2", "Class1") .pred_class <- factor(.pred_class, levels = levels(all_res_subset$Class)) res_subset <- dplyr::filter(res, .row == 5 & `cost value` == svm_grd$`cost value`[1]) expect_equal(.pred_Class1, res_subset$.pred_Class1) expect_equal(.pred_Class2, res_subset$.pred_Class2) expect_equal(.pred_class, res_subset$.pred_class) }) # ------------------------------------------------------------------------------ test_that("collecting notes - fit_resamples", { skip_if(new_rng_snapshots) skip_if(rankdeficient_version) mtcars2 <- mtcars %>% mutate(wt2 = wt) set.seed(1) flds <- rsample::bootstraps(mtcars2, times = 2) lin_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm") expect_snapshot( lm_splines <- fit_resamples(lin_mod, mpg ~ ., flds) ) expect_snapshot(lm_splines) nts <- collect_notes(lm_splines) expect_true(all(nts$type == "warning")) expect_true(all(grepl("rank", nts$note))) expect_equal(names(nts), c("id", "location", "type", "note")) }) test_that("collecting notes - last_fit", { skip_if(rankdeficient_version) options(pillar.advice = FALSE, pillar.min_title_chars = Inf) mtcars2 <- mtcars %>% mutate(wt2 = wt) set.seed(1) split <- rsample::initial_split(mtcars2) lin_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm") expect_snapshot( lst <- last_fit(lin_mod, mpg ~ ., split) ) expect_snapshot(lst) nts <- collect_notes(lst) expect_true(all(nts$type == "warning")) expect_true(all(grepl("rank", nts$note))) expect_equal(names(nts), c("location", "type", "note")) }) test_that("`collect_notes()` errors informatively applied to unsupported class", { expect_snapshot( error = TRUE, collect_notes(lm(mpg ~ disp, mtcars)) ) }) test_that("collecting extracted objects - fit_resamples", { # skip pre-R-4.0.0 so that snaps aren't affected by stringsAsFactors change skip_if(R.Version()$major < "4") spec <- parsnip::linear_reg() form <- mpg ~ . set.seed(1) boots <- rsample::bootstraps(mtcars, 5) ctrl_fit <- control_resamples(extract = extract_fit_engine) ctrl_err <- control_resamples(extract = function(x) {stop("eeeep! 
eep!")}) res_fit <- fit_resamples(spec, form, boots, control = ctrl_fit) res_nothing <- fit_resamples(spec, form, boots) suppressMessages({ res_error <- fit_resamples(spec, form, boots, control = ctrl_err) }) expect_snapshot(collect_extracts(res_fit)) expect_snapshot(collect_extracts(res_nothing), error = TRUE) expect_snapshot(collect_extracts(res_error)) }) test_that("`collect_extracts()` errors informatively applied to unsupported class", { expect_snapshot( error = TRUE, collect_extracts(lm(mpg ~ disp, mtcars)) ) }) test_that("`collect_metrics()` errors informatively applied to unsupported class", { expect_snapshot( error = TRUE, collect_metrics(lm(mpg ~ disp, mtcars)) ) }) test_that("`collect_metrics(type)` errors informatively with bad input", { skip_on_cran() expect_snapshot( error = TRUE, collect_metrics(ames_grid_search, type = "boop") ) expect_snapshot( error = TRUE, collect_metrics(ames_grid_search, type = NULL) ) }) test_that("`pivot_metrics()`, grid search, typical metrics, summarized", { expect_equal( pivot_metrics(ames_grid_search, collect_metrics(ames_grid_search)) %>% dplyr::slice(), tibble::tibble( K = integer(0), weight_func = character(0), dist_power = numeric(0), lon = integer(0), lat = integer(0), .config = character(0), rmse = numeric(0), rsq = numeric(0) ) ) }) test_that("`pivot_metrics()`, grid search, typical metrics, unsummarized", { expect_equal( pivot_metrics( ames_grid_search, collect_metrics(ames_grid_search, summarize = FALSE) ) %>% dplyr::slice(), tibble::tibble( K = integer(0), weight_func = character(0), dist_power = numeric(0), lon = integer(0), lat = integer(0), .config = character(0), id = character(0), rmse = numeric(0), rsq = numeric(0) ) ) }) test_that("`pivot_metrics()`, iterative search, typical metrics, summarized", { expect_equal( pivot_metrics(ames_iter_search, collect_metrics(ames_iter_search)) %>% dplyr::slice(), tibble::tibble( K = integer(0), weight_func = character(0), dist_power = numeric(0), lon = integer(0), lat = integer(0), .config = character(0), .iter = integer(0), rmse = numeric(0), rsq = numeric(0) ) ) }) test_that("`pivot_metrics()`, resampled fits, fairness metrics, summarized", { mtcars_fair <- mtcars mtcars_fair$vs <- as.factor(mtcars_fair$vs) mtcars_fair$cyl <- as.factor(mtcars_fair$cyl) mtcars_fair$am <- as.factor(mtcars_fair$am) set.seed(4400) ms <- yardstick::metric_set( yardstick::demographic_parity(cyl), yardstick::demographic_parity(am) ) res <- fit_resamples( nearest_neighbor("classification"), vs ~ mpg + hp + cyl, rsample::bootstraps(mtcars_fair, 3), metrics = ms ) expect_equal( pivot_metrics(res, collect_metrics(res)) %>% slice(), tibble::tibble( .config = character(0), `demographic_parity(am)` = integer(0), `demographic_parity(cyl)` = numeric(0), ) ) })