# Setup utils::data(models.reg) utils::data(X.reg) utils::data(Y.reg) utils::data(models.class) utils::data(X.class) utils::data(Y.class) set.seed(1234L) ens.reg <- caretEnsemble(models.reg) ens.class <- caretEnsemble(models.class) mod <- caret::train( X.reg, Y.reg, method = "lm", trControl = caret::trainControl(method = "none") ) # Helper function for testing expect_data_table_structure <- function(dt, expected_names) { testthat::expect_s3_class(dt, "data.table") testthat::expect_named(dt, expected_names) } ############################################################################# testthat::context("caretPredict and extractMetric") ############################################################################# testthat::test_that("caretPredict extracts best predictions correctly", { stacked_preds_class <- caretPredict(models.class[[1L]], excluded_class_id = 0L) stacked_preds_reg <- caretPredict(models.reg[[1L]]) expect_data_table_structure(stacked_preds_class, c("No", "Yes")) expect_data_table_structure(stacked_preds_reg, "pred") }) testthat::test_that("extractMetric works for different model types", { # Test for model with no resampling (no SD) metric <- extractMetric(mod) expect_data_table_structure(metric, c("model_name", "metric", "value", "sd")) testthat::expect_true(is.na(metric$value), is.na(metric$sd)) # Test for ensemble models for (ens in list(ens.class, ens.reg)) { metrics <- extractMetric(ens) expect_data_table_structure(metrics, c("model_name", "metric", "value", "sd")) testthat::expect_named(ens$models, metrics$model_name[-1L]) } }) ############################################################################# testthat::context("S3 methods and model operations") ############################################################################# testthat::test_that("c.train on 2 train objects", { testthat::expect_error(c.train(list()), "class of modelList1 must be 'caretList' or 'train'") combined_models <- c(models.class[[1L]], models.class[[1L]]) testthat::expect_s3_class(combined_models, "caretList") testthat::expect_length(combined_models, 2L) testthat::expect_identical(anyDuplicated(names(combined_models)), 0L) testthat::expect_length(unique(names(combined_models)), 2L) }) testthat::test_that("c.train on a train and a caretList", { bigList <- c(models.reg[[1L]], models.class) testthat::expect_is(bigList, "caretList") testthat::expect_identical(anyDuplicated(names(bigList)), 0L) testthat::expect_length(unique(names(bigList)), 5L) }) testthat::test_that("extractModelName handles different model types", { testthat::expect_identical(extractModelName(models.class[[1L]]), "rf") testthat::expect_identical(extractModelName(models.reg[[1L]]), "rf") custom_model <- models.class[[1L]] custom_model$method <- list(method = "custom_rf") testthat::expect_identical(extractModelName(custom_model), "custom_rf") mock_model <- list(method = list(method = "custom_method")) class(mock_model) <- "train" testthat::expect_identical(extractModelName(mock_model), "custom_method") mock_model <- list(method = "custom", modelInfo = list(method = "custom_method")) class(mock_model) <- "train" testthat::expect_identical(extractModelName(mock_model), "custom_method") }) ############################################################################# testthat::context("isClassifierAndValidate") ############################################################################# testthat::test_that("isClassifierAndValidate handles various model types", { models_multi <- caretList( iris[, 1L:2L], iris[, 5L], tuneLength = 1L, verbose = FALSE, methodList = c("rf", "gbm") ) models_multi_bin_reg <- c(models_multi, models.class, models.reg) testthat::expect_is(vapply(models_multi_bin_reg, isClassifierAndValidate, logical(1L)), "logical") # Test when predictions are missing model_list <- models.class model_list[[1L]]$pred <- NULL testthat::expect_is(vapply(model_list, isClassifierAndValidate, logical(1L)), "logical") testthat::expect_equivalent(unique(vapply(model_list, isClassifierAndValidate, logical(1L))), TRUE) # Test error cases model_list <- models.class model_list[[1L]]$modelInfo$prob <- FALSE testthat::expect_error( lapply(model_list, isClassifierAndValidate), "No probability function found. Re-fit with a method that supports prob." ) model_list <- models.class model_list[[1L]]$control$classProbs <- FALSE testthat::expect_error( lapply(model_list, isClassifierAndValidate, validate_for_stacking = TRUE), "classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl." ) # Test for non-caretList object testthat::expect_error( isClassifierAndValidate(list(model = lm(Y.reg ~ ., data = as.data.frame(X.reg)))), "is(object, \"train\") is not TRUE", fixed = TRUE ) # Test for models without savePredictions model <- models.class[[1L]] model$control$savePredictions <- NULL testthat::expect_error( isClassifierAndValidate(model), "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions." ) model$control$savePredictions <- "BAD_VALUE" testthat::expect_error( isClassifierAndValidate(model), "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions." ) }) ############################################################################# testthat::context("validateExcludedClass") ############################################################################# testthat::test_that("validateExcludedClass handles various inputs", { testthat::expect_error(validateExcludedClass("invalid"), "classification excluded level must be numeric: invalid") testthat::expect_warning( testthat::expect_error(validateExcludedClass(Inf), "classification excluded level must be finite: Inf"), "classification excluded level is not an integer: Inf" ) testthat::expect_warning( testthat::expect_error(validateExcludedClass(-1.0), "classification excluded level must be >= 0: -1"), "classification excluded level is not an integer:" ) testthat::expect_warning(validateExcludedClass(1.1), "classification excluded level is not an integer: 1.1") testthat::expect_identical(validateExcludedClass(3L), 3L) # Edge cases testthat::expect_identical(validateExcludedClass(0L), 0L) testthat::expect_identical(validateExcludedClass(1L), 1L) testthat::expect_identical(validateExcludedClass(4L), 4L) w <- "classification excluded level is not an integer:" testthat::expect_warning(testthat::expect_identical(validateExcludedClass(0.0), 0L), w) testthat::expect_warning(testthat::expect_identical(validateExcludedClass(1.0), 1L), w) testthat::expect_warning(testthat::expect_identical(validateExcludedClass(4.0), 4L), w) testthat::expect_error(validateExcludedClass(-1L), "classification excluded level must be >= 0: -1") # Additional tests testthat::expect_warning(validateExcludedClass(NULL), "No excluded_class_id set. Setting to 1L.") testthat::expect_error( validateExcludedClass(c(1L, 2L)), "classification excluded level must have a length of 1: length=2" ) testthat::expect_warning( testthat::expect_error(validateExcludedClass(-0.000001), "classification excluded level must be >= 0: -1e-06"), "classification excluded level is not an integer" ) })