# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand
# NOTE(review): because this file is generated, any change made here should
# presumably be mirrored in dev/flat_teaching.Rmd -- confirm with the fusen
# workflow used by this project.

# Create test data ----

# Build the fixture shared by all tests: a named list of two data.tables
# (`data1` with 100 rows, `data2` with 50 rows). Each table has an integer
# row id `x` (1..n), a random numeric `y`, and a character `group` column
# that cycles through 4 (data1) or 2 (data2) letters.
create_test_data <- function() {
  dt1 <- data.table(
    x = 1:100,
    y = rnorm(100),
    group = rep(letters[1:4], 25)
  )
  dt2 <- data.table(
    x = 1:50,
    y = rnorm(50),
    group = rep(letters[1:2], 25)
  )
  list(data1 = dt1, data2 = dt2)
}

# Basic functionality test ----

test_that("split_cv returns correct structure", {
  test_data <- create_test_data()
  result <- split_cv(test_data, v = 5, repeats = 1)

  # The result is a list with one entry per input dataset, same names.
  expect_type(result, "list")
  expect_length(result, length(test_data))
  expect_named(result, names(test_data))

  # Each entry is a data.table with one row per fold.
  for (res in result) {
    expect_true(is.data.table(res))
    expect_true(all(c("splits", "id", "train", "validate") %in% names(res)))
    expect_equal(nrow(res), 5)
  }
})

# Error handling test ----

test_that("split_cv handles invalid inputs correctly", {
  expect_error(split_cv(NULL))
  expect_error(split_cv(list()))
  expect_error(split_cv(list(a = 1, b = 2)))

  # An unknown strata column is expected to warn rather than fail.
  test_data <- create_test_data()
  expect_warning(
    split_cv(test_data, v = 5, strata = "non_existent_column"),
    "Strata variable 'non_existent_column' not found"
  )
})

# Repeated cross-validation test ----

test_that("split_cv handles repeats correctly", {
  test_data <- create_test_data()

  # Single cross-validation: only an `id` column holding Fold1..Fold5.
  result_single <- split_cv(test_data, v = 5, repeats = 1)
  for (res in result_single) {
    expect_true("id" %in% names(res))
    expect_false("id2" %in% names(res))
    expect_equal(nrow(res), 5)
    expect_true(all(res$id %in% paste0("Fold", 1:5)))
  }

  # Multiple repeats: `id` holds the repeat label and `id2` the fold label.
  result_multiple <- split_cv(test_data, v = 5, repeats = 3)
  for (res in result_multiple) {
    expect_true(all(c("id", "id2") %in% names(res)))
    expect_equal(nrow(res), 15)
    expect_true(all(grepl("^Repeat\\d+$", res$id)))
    expect_true(all(grepl("^Fold\\d+$", res$id2)))
    expect_length(unique(res$id), 3)
    expect_length(unique(res$id2), 5)

    # Check number of folds in each repeat
    for (repeat_id in unique(res$id)) {
      expect_equal(res[id == repeat_id, .N], 5)
    }
  }
})

# Train and validation sets test ----

test_that("split_cv generates correct train and validate sets", {
  v <- 5
  test_data <- create_test_data()
  result <- split_cv(test_data, v = v, repeats = 2)

  for (i in seq_along(result)) {
    res <- result[[i]]
    n_rows <- nrow(test_data[[i]])

    # Check every split instead of a randomly sampled one, so the test
    # exercises the same cases on every run.
    for (split_idx in seq_len(nrow(res))) {
      # `x` is 1..n in the fixture, so it doubles as a row identifier.
      train_rows <- res[["train"]][[split_idx]]$x
      validate_rows <- res[["validate"]][[split_idx]]$x

      # Check mutual exclusivity
      expect_length(intersect(train_rows, validate_rows), 0)

      # Check sizes: allow an absolute deviation of at most one row.
      # (expect_equal's `tolerance` is relative, so `tolerance = 1` would
      # accept almost any size.)
      expect_lte(abs(length(validate_rows) - n_rows / v), 1)
      expect_lte(abs(length(train_rows) - n_rows * (v - 1) / v), 1)

      # Check completeness: train and validate together cover every row.
      all_rows <- sort(unique(c(train_rows, validate_rows)))
      expect_equal(all_rows, seq_len(n_rows))
    }
  }
})

# Stratification test ----

test_that("split_cv handles stratification correctly", {
  test_data <- create_test_data()
  result <- split_cv(test_data, v = 5, strata = "group")

  for (res in result) {
    first_split <- res$train[[1]]
    expect_true("group" %in% names(first_split))

    # The training set of the first fold must contain more than one group.
    unique_groups <- unique(first_split$group)
    expect_gt(length(unique_groups), 1)
  }
})

# Data type handling test ----

test_that("split_cv handles different input data types", {
  # Test with data.frame
  df_data <- list(
    data1 = as.data.frame(create_test_data()[[1]]),
    data2 = as.data.frame(create_test_data()[[2]])
  )
  result_df <- split_cv(df_data, v = 5)
  expect_true(all(vapply(result_df, is.data.table, logical(1))))

  # Test with mixed types
  mixed_data <- list(
    data1 = as.data.frame(create_test_data()[[1]]),
    data2 = create_test_data()[[2]]
  )
  result_mixed <- split_cv(mixed_data, v = 5)
  expect_true(all(vapply(result_mixed, is.data.table, logical(1))))
})