# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand

# Test suite for nest_cv(): builds a nested data.table (one row per group,
# each holding a data.table in a list column) and checks the structure,
# sizes, stratification, parameter handling, error messages, data
# consistency, reproducibility and column-type preservation of the
# cross-validation splits that nest_cv() returns.
test_that("nest_cv functions correctly", {
  # Setup test data: 100 rows, two groups of 50, with a balanced
  # two-level factor (`category`) used by the stratification test.
  base_dt <- data.table::data.table(
    id = 1:100,
    group = rep(c("A", "B"), each = 50),
    value = rnorm(100),
    category = factor(rep(c("X", "Y"), times = 50))
  )

  # Create nested test data: one list-column entry per group.
  test_dt <- data.table::data.table(
    name = c("group1", "group2"),
    data = list(
      base_dt[group == "A"],
      base_dt[group == "B"]
    )
  )

  # Test basic functionality
  test_that("basic functionality works", {
    result <- nest_cv(test_dt, v = 5, repeats = 2)

    # Check structure
    expect_true(data.table::is.data.table(result))
    expect_true(
      all(c("name", "id", "splits", "train", "validate") %in% names(result))
    )
    expect_equal(nrow(result), 5 * 2 * 2) # v * repeats * groups

    # Check content types; vapply() guarantees a logical vector regardless
    # of how many splits come back.
    expect_true(
      all(vapply(result$train, data.table::is.data.table, logical(1)))
    )
    expect_true(
      all(vapply(result$validate, data.table::is.data.table, logical(1)))
    )

    # Check split sizes
    first_fold <- result$train[[1]]
    expect_equal(nrow(first_fold), 40) # 80% of 50 for training
    expect_equal(nrow(result$validate[[1]]), 10) # 20% of 50 for validation
  })

  # Test with stratification
  test_that("stratification works correctly", {
    result <- nest_cv(test_dt, v = 5, repeats = 1, strata = "category")

    # Check if strata is maintained in splits
    first_train <- result$train[[1]]
    first_validate <- result$validate[[1]]

    # Check proportions in training and validation sets
    train_prop <- table(first_train$category) / nrow(first_train)
    validate_prop <- table(first_validate$category) / nrow(first_validate)

    # Compare the two partitions against each other: stratified splits
    # should give train and validate similar category proportions.
    expect_equal(train_prop, validate_prop, tolerance = 0.1)
  })

  # Test with different parameters
  test_that("parameter variations work", {
    # Test with different v
    result_v3 <- nest_cv(test_dt, v = 3, repeats = 1)
    expect_equal(nrow(result_v3), 3 * 2) # 3 folds * 2 groups

    # Test with different repeats
    result_r3 <- nest_cv(test_dt, v = 5, repeats = 3)
    expect_equal(nrow(result_r3), 5 * 3 * 2) # 5 folds * 3 repeats * 2 groups

    # Test with different breaks
    result_breaks <- nest_cv(test_dt, v = 5, breaks = 3)
    expect_false(is.null(result_breaks))

    # Test with different pool
    result_pool <- nest_cv(test_dt, v = 5, pool = 0.2)
    expect_false(is.null(result_pool))
  })

  # Test error handling
  test_that("error handling works correctly", {
    # Test empty input (zero-row copy of the nested table)
    empty_dt <- test_dt[0]
    expect_error(
      nest_cv(empty_dt),
      "Input 'nest_dt' cannot be empty"
    )

    # Test input without nested columns
    bad_dt <- data.table::data.table(a = 1:3, b = 4:6)
    expect_error(
      nest_cv(bad_dt),
      "Input 'nest_dt' must contain at least one nested column"
    )
  })

  # Test with multiple nested columns
  test_that("multiple nested columns work", {
    # Create test data with multiple nested columns
    multi_nest_dt <- data.table::data.table(
      name = c("group1", "group2"),
      data1 = list(
        base_dt[group == "A"],
        base_dt[group == "B"]
      ),
      data = list(
        base_dt[group == "A"],
        base_dt[group == "B"]
      )
    )

    result <- nest_cv(multi_nest_dt, v = 2)
    expect_false(is.null(result))
    expect_true(
      all(c("name", "splits", "train", "validate") %in% names(result))
    )
  })

  # Test data consistency
  test_that("data consistency is maintained", {
    result <- nest_cv(test_dt, v = 5, repeats = 1)

    # Check that all original columns are preserved in splits
    first_train <- result$train[[1]]
    expect_true(all(names(base_dt) %in% names(first_train)))

    # Check that no observations are lost or duplicated
    for (i in seq_len(nrow(result))) {
      train_set <- result$train[[i]]
      validate_set <- result$validate[[i]]

      # Total number of observations should equal original group size
      expect_equal(nrow(train_set) + nrow(validate_set), 50)

      # No overlap between train and validate ids
      expect_length(intersect(train_set$id, validate_set$id), 0)
    }
  })

  # Test reproducibility
  test_that("results are reproducible with seed", {
    set.seed(123)
    result1 <- nest_cv(test_dt, v = 5)

    set.seed(123)
    result2 <- nest_cv(test_dt, v = 5)

    expect_equal(result1, result2)
  })

  # Test with different data types
  test_that("handles different data types correctly", {
    # Create test data with various data types.
    # `letters` has only 26 elements, so it is recycled to 50 values;
    # indexing it with 1:50 would fill the tail with NA.
    complex_dt <- data.table::data.table(
      id = 1:50,
      num = rnorm(50),
      int = 1:50,
      fct = factor(rep(letters[1:5], 10)),
      date = seq(as.Date("2024-01-01"), by = "day", length.out = 50),
      char = rep_len(letters, 50)
    )

    nested_complex <- data.table::data.table(
      name = "group1",
      data = list(complex_dt)
    )

    result <- nest_cv(nested_complex, v = 5)

    # Check that data types are preserved
    first_train <- result$train[[1]]
    expect_type(first_train$num, "double")
    expect_type(first_train$int, "integer")
    expect_s3_class(first_train$fct, "factor")
    expect_s3_class(first_train$date, "Date")
    expect_type(first_train$char, "character")
  })
})