# Test file for variable fold weights functionality

# Setup test data
set.seed(42)
test_data <- data.frame(
  x1 = rnorm(50),
  x2 = rnorm(50),
  x3 = rnorm(50)
)
test_data$y <- 2 * test_data$x1 + 3 * test_data$x2 + rnorm(50, sd = 0.5)

set.seed(123)
folds <- rsample::vfold_cv(mtcars, v = 3)

# Helper function to create a simple model
create_test_model <- function() {
  parsnip::linear_reg() |>
    parsnip::set_engine("lm")
}

# Run `tune_grid()` with an RMSE-only metric set and a non-verbose control
# object, so each test only states what differs: the resamples and, when
# needed, the model, formula, and grid. Warnings raised while tuning are
# suppressed, as every tuning call in this file did before the refactor.
tune_rmse_grid <- function(resamples,
                           object = create_test_model(),
                           formula = mpg ~ .,
                           grid = 1) {
  suppressWarnings(
    tune_grid(
      object,
      formula,
      resamples = resamples,
      grid = grid,
      metrics = yardstick::metric_set(yardstick::rmse),
      control = control_grid(verbose = FALSE)
    )
  )
}

test_that("add_resample_weights() validates inputs correctly", {
  expect_snapshot(
    add_resample_weights("not_an_rset", c(0.5, 0.3, 0.2)),
    error = TRUE
  )

  expect_snapshot(
    add_resample_weights(folds, c("a", "b", "c")),
    error = TRUE
  )

  expect_snapshot(
    add_resample_weights(folds, c(0.5, 0.3)),
    error = TRUE
  )

  expect_snapshot(
    add_resample_weights(folds, c(-0.1, 0.5, 0.6)),
    error = TRUE
  )

  expect_snapshot(
    add_resample_weights(folds, c(0, 0, 0)),
    error = TRUE
  )
})

test_that("add_resample_weights() adds weights correctly", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  # Weights get normalized to sum to 1
  expected_weights <- weights / sum(weights)

  expect_s3_class(weighted_folds, "rset")
  expect_equal(attr(weighted_folds, ".resample_weights"), expected_weights)
  expect_equal(nrow(weighted_folds), nrow(folds))
})

test_that("calculate_resample_weights() works correctly", {
  auto_weights <- calculate_resample_weights(folds)

  expect_type(auto_weights, "double")
  expect_length(auto_weights, nrow(folds))
  expect_true(all(auto_weights > 0))
  expect_equal(sum(auto_weights), 1, tolerance = 1e-10)
})

test_that("weights are preserved through tuning pipeline", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  res <- tune_rmse_grid(weighted_folds)

  # The (normalized) weights should survive into the tuning results
  expect_equal(extract_resample_weights(res), weights / sum(weights))

  metrics <- collect_metrics(res)
  expect_equal(nrow(metrics), 1)
  expect_true("mean" %in% names(metrics))
  expect_true(is.numeric(metrics$mean))
})

test_that("weights affect metric aggregation", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  # Both runs use the same splits, so only the aggregation differs
  res_unweighted <- tune_rmse_grid(folds)
  res_weighted <- tune_rmse_grid(weighted_folds)

  unweighted_rmse <- collect_metrics(res_unweighted)$mean[1]
  weighted_rmse <- collect_metrics(res_weighted)$mean[1]

  expect_true(is.numeric(unweighted_rmse))
  expect_true(is.numeric(weighted_rmse))
  expect_false(is.na(unweighted_rmse))
  expect_false(is.na(weighted_rmse))

  # Unequal weights over the same per-fold estimates should move the summary
  expect_false(isTRUE(all.equal(unweighted_rmse, weighted_rmse)))
})

test_that("extreme weights show larger effect", {
  skip_if_not_installed("kknn")

  # Create folds for this specific test
  set.seed(42)
  test_folds <- rsample::vfold_cv(test_data, v = 3)

  # Regular weights
  weights <- c(0.6, 0.2, 0.2)
  weighted_folds <- add_resample_weights(test_folds, weights)

  # Extreme weights
  extreme_weights <- c(0.95, 0.025, 0.025)
  extreme_weighted_folds <- add_resample_weights(test_folds, extreme_weights)

  # Create a model with tuning parameter
  knn_spec <- parsnip::nearest_neighbor(neighbors = tune()) |>
    parsnip::set_engine("kknn") |>
    parsnip::set_mode("regression")

  param_grid <- data.frame(neighbors = c(3, 5))

  tune_knn <- function(resamples) {
    tune_rmse_grid(
      resamples,
      object = knn_spec,
      formula = y ~ .,
      grid = param_grid
    )
  }

  unweighted_metrics <- collect_metrics(tune_knn(test_folds))
  weighted_metrics <- collect_metrics(tune_knn(weighted_folds))
  extreme_metrics <- collect_metrics(tune_knn(extreme_weighted_folds))

  # Check that results exist and are sensible
  expect_equal(nrow(unweighted_metrics), 2)
  expect_equal(nrow(weighted_metrics), 2)
  expect_equal(nrow(extreme_metrics), 2)

  # Calculate differences (rows are compared positionally, per candidate)
  regular_diff <- max(abs(unweighted_metrics$mean - weighted_metrics$mean))
  extreme_diff <- max(abs(unweighted_metrics$mean - extreme_metrics$mean))

  expect_true(all(is.finite(c(regular_diff, extreme_diff))))

  # Assuming the summary is a weighted mean of per-fold estimates: for
  # weights (a, b, b) it moves away from the equal-weight mean by
  # (a - 1/3) * (r1 - (r2 + r3) / 2), so a = 0.95 must shift it further
  # than a = 0.6.
  expect_gt(extreme_diff, regular_diff)
})

test_that("weight normalization works correctly", {
  expect_equal(
    tune:::.validate_resample_weights(c(3, 6, 9), 3),
    c(1 / 6, 1 / 3, 1 / 2) # normalized to sum to 1
  )

  expect_equal(
    tune:::.validate_resample_weights(c(0.2, 0.3, 0.5), 3),
    c(0.2, 0.3, 0.5) # already normalized to sum to 1
  )
})

test_that("equal weights return NULL", {
  # Simplest integer match
  expect_null(tune:::.validate_resample_weights(c(2, 2, 2), 3))

  # Fractional match
  expect_null(tune:::.validate_resample_weights(c(1 / 3, 1 / 3, 1 / 3), 3))

  # Check more resamples
  expect_null(tune:::.validate_resample_weights(c(1, 1, 1, 1, 1), 5))
})

test_that("unequal weights do not return NULL", {
  # Decimal weights that already sum to 1 are returned unchanged
  result <- tune:::.validate_resample_weights(c(0.1, 0.5, 0.4), 3)
  expect_false(is.null(result))
  expect_equal(result, c(0.1, 0.5, 0.4))

  # Integer weights are normalized to fractions summing to 1
  result2 <- tune:::.validate_resample_weights(c(1, 2, 3), 3)
  expect_false(is.null(result2))
  expect_equal(result2, c(1 / 6, 2 / 6, 3 / 6))
})

test_that("add_resample_weights with equal weights returns NULL attribute", {
  # Adding equal weights should trigger NULL assignment
  equal_weighted_folds <- add_resample_weights(folds, c(1, 1, 1))
  expect_null(attr(equal_weighted_folds, ".resample_weights"))

  # Verify it's still an rset object
  expect_s3_class(equal_weighted_folds, "rset")
})

test_that("equal weights produce same results as no weights", {
  equal_weighted_folds <- add_resample_weights(folds, c(1, 1, 1))

  metrics_no_weights <- collect_metrics(tune_rmse_grid(folds))
  metrics_equal_weights <- collect_metrics(tune_rmse_grid(equal_weighted_folds))

  # Results should match
  expect_equal(metrics_no_weights$mean, metrics_equal_weights$mean)
  expect_equal(metrics_no_weights$std_err, metrics_equal_weights$std_err)
})

test_that("weighted statistics functions work correctly", {
  x <- c(1, 2, 3, 4, 5)
  w <- c(0.1, 0.2, 0.3, 0.2, 0.2)

  weighted_sd <- tune:::.weighted_sd(x, w)
  expect_true(is.numeric(weighted_sd))
  expect_false(is.na(weighted_sd))
  expect_true(weighted_sd >= 0)

  # NA values are dropped (together with their weights) before the call
  x_na <- c(1, 2, NA, 4, 5)
  keep <- !is.na(x_na)
  weighted_sd_na <- tune:::.weighted_sd(x_na[keep], w[keep])
  expect_true(is.numeric(weighted_sd_na))

  # Test edge cases
  expect_true(is.na(tune:::.weighted_sd(c(1), c(1)))) # single value
})

test_that("fold weight extraction works", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  # Weights get normalized to sum to 1
  expected_weights <- weights / sum(weights)

  res <- tune_rmse_grid(weighted_folds)

  extracted_weights <- tune:::.get_resample_weights(res)
  expect_equal(extracted_weights, expected_weights)
})

test_that("individual fold metrics can be collected", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  res <- tune_rmse_grid(weighted_folds)

  # Collect individual fold metrics
  individual_metrics <- collect_metrics(res, summarize = FALSE)

  expect_gte(nrow(individual_metrics), 3) # At least one metric per fold
  expect_true("id" %in% names(individual_metrics))
  expect_true(".estimate" %in% names(individual_metrics))
  expect_true(all(is.finite(individual_metrics$.estimate)))
})

test_that("backwards compatibility - no weights", {
  res <- tune_rmse_grid(folds) # No weights

  metrics <- collect_metrics(res)
  expect_equal(nrow(metrics), 1)
  expect_true("mean" %in% names(metrics))
  expect_true(is.numeric(metrics$mean))
  expect_false(is.na(metrics$mean))
})

test_that("rset tibble conversion includes fold weights", {
  weights <- c(0.1, 0.4, 0.5)
  weighted_folds <- add_resample_weights(folds, weights)

  # Fill the weight column from the weights stored on the rset rather than
  # from the input vector, so the check exercises the package instead of
  # restating `weights`.
  x_tbl <- tibble::as_tibble(weighted_folds)
  x_tbl$resample_weight <- extract_resample_weights(weighted_folds)

  # Verify the structure
  expect_true("resample_weight" %in% names(x_tbl))
  expect_equal(x_tbl$resample_weight, weights)
  expect_equal(nrow(x_tbl), 3)
})

test_that("extract_resample_weights() works with rset objects", {
  weights <- c(0.2, 0.3, 0.5)
  weighted_folds <- add_resample_weights(folds, weights)

  # Should return the weights
  extracted_weights <- extract_resample_weights(weighted_folds)
  expect_equal(extracted_weights, weights)

  # Should return NULL for unweighted rsets
  unweighted_result <- extract_resample_weights(folds)
  expect_null(unweighted_result)
})

test_that("extract_resample_weights() works with tune_results objects", {
  weights <- c(0.1, 0.5, 0.4)
  weighted_folds <- add_resample_weights(folds, weights)

  res <- tune_rmse_grid(weighted_folds)

  # Should extract weights from tune results
  extracted_weights <- extract_resample_weights(res)
  expected_weights <- weights / sum(weights) # normalized
  expect_equal(extracted_weights, expected_weights)
})

test_that("extract_resample_weights() validates input types", {
  expect_snapshot(
    extract_resample_weights("not_valid_input"),
    error = TRUE
  )

  expect_snapshot(
    extract_resample_weights(data.frame(x = 1:3)),
    error = TRUE
  )
})