library(testthat)
library(recipes)

# Snapshot variant label of the form "R<major>.<minor>"; passed as the
# `variant` of a snapshot whose output may differ between R versions.
r_version <- function() paste0("R", getRversion()[, 1:2])

# Training data: x1 is the sequence 1..100, x2 repeats 1..5 in runs of 20,
# and x3 is a two-level factor.
ex_tr <- data.frame(
  x1 = 1:100,
  x2 = rep(1:5, each = 20),
  x3 = factor(rep(letters[1:2], each = 50))
)

# Copy of the training data with one missing value in x1 and one in x3.
# NOTE(review): not referenced by the tests in this file; confirm it is
# still needed.
ex_tr_mis <- ex_tr
ex_tr_mis$x1[2] <- NA
ex_tr_mis$x3[10] <- NA

# New data: two values inside the training range, one above the training
# maximum (101), and one missing value.
ex_te <- data.frame(x1 = c(1, 50, 101, NA))

# Expected levels for the default 4-bin split of ex_tr$x1 with prefix = NULL.
# The space in "(75.2, Inf]" is intentional: it is the label text these
# tests expect, unlike "[-Inf,25.8]" which has none.
lvls_breaks_4 <- c(
  "[missing]", "[-Inf,25.8]", "(25.8,50.5]", "(50.5,75.2]", "(75.2, Inf]"
)
# The same levels when the default "bin" prefix is used.
lvls_breaks_4_bin <- c("bin_missing", "bin1", "bin2", "bin3", "bin4")

test_that("default args", {
  bin_1 <- discretize(ex_tr$x1, prefix = NULL)
  pred_1 <- predict(bin_1, ex_te$x1)
  exp_1 <- factor(lvls_breaks_4[c(2, 3, 5, 1)], levels = lvls_breaks_4)
  expect_equal(pred_1, exp_1)

  bin_1 <- discretize(ex_tr$x1)
  pred_1 <- predict(bin_1, ex_te$x1)
  exp_1 <- factor(
    c("bin1", "bin2", "bin4", "bin_missing"),
    levels = lvls_breaks_4_bin
  )
  expect_equal(pred_1, exp_1)
})

test_that("NA values", {
  # With keep_na = FALSE there is no missing level; NA stays NA.
  bin_2 <- discretize(ex_tr$x1, keep_na = FALSE, prefix = NULL)
  pred_2 <- predict(bin_2, ex_te$x1)
  exp_2 <- factor(
    lvls_breaks_4[c(2, 3, 5, NA)],
    levels = lvls_breaks_4[-1]
  )
  expect_equal(pred_2, exp_2)

  bin_2 <- discretize(ex_tr$x1, keep_na = FALSE)
  pred_2 <- predict(bin_2, ex_te$x1)
  exp_2 <- factor(c("bin1", "bin2", "bin4", NA), levels = lvls_breaks_4_bin[-1])
  expect_equal(pred_2, exp_2)
})

test_that("bad values", {
  expect_snapshot(error = TRUE, discretize(letters))
})

test_that("printing of discretize()", {
  expect_snapshot(discretize(1:100))
  expect_snapshot(discretize(1:100, cuts = 6))
  expect_snapshot(discretize(1:100, keep_na = FALSE))
  expect_snapshot(
    res <- discretize(1:2)
  )
  expect_snapshot(res)
})

test_that("NA values from out of range", {
  # With infs = FALSE the outer breaks are the training range, so the
  # out-of-range value (101) and the missing value both map to NA.
  bin_3 <- discretize(ex_tr$x1, keep_na = FALSE, infs = FALSE, prefix = NULL)
  pred_3 <- predict(bin_3, ex_te$x1)
  exp_3 <- factor(
    c("[1,25.8]", "(25.8,50.5]", NA, NA),
    levels = c("[1,25.8]", "(25.8,50.5]", "(50.5,75.2]", "(75.2,100]")
  )
  expect_equal(pred_3, exp_3)

  bin_3 <- discretize(ex_tr$x1, keep_na = FALSE, infs = FALSE)
  pred_3 <- predict(bin_3, ex_te$x1)
  exp_3 <- factor(c("bin1", "bin2", NA, NA), levels = lvls_breaks_4_bin[-1])
  expect_equal(pred_3, exp_3)
})

test_that("NA values with step_discretize (issue #127)", {
  iris_na <- iris
  iris_na$sepal_na <- iris_na$Sepal.Length
  iris_na$sepal_na[1:5] <- NA

  disc_values <- discretize(
    iris_na$sepal_na,
    min.unique = 2,
    cuts = 2,
    keep_na = TRUE,
    na.rm = TRUE
  )

  # We expect na.rm to be overwritten: the user passes na.rm = FALSE, yet the
  # prepped object must equal discretize() called with na.rm = TRUE.
  opts <- list(min.unique = 2, cuts = 2, keep_na = TRUE, na.rm = FALSE)

  rec <- recipe(~., data = iris_na) %>%
    step_discretize(sepal_na, options = opts) %>%
    prep(training = iris_na)

  expect_equal(rec$steps[[1]]$objects$sepal_na, disc_values)
})

test_that("tidys", {
  rec <- recipe(~., data = ex_tr) %>%
    step_discretize(x1, id = "")

  # Before prep, tidy() reports the selected term with no break values.
  tidy_exp_un <- tibble(
    terms = "x1",
    value = NA_real_,
    id = ""
  )
  expect_equal(tidy(rec, 1), tidy_exp_un)

  # After prep, tidy() reports one row per estimated break.
  rec_trained <- prep(rec, training = ex_tr)
  br <- rec_trained$steps[[1]]$objects$x1$breaks
  tidy_exp_tr <- tibble(
    terms = rep("x1", length(br)),
    value = br,
    id = ""
  )
  expect_equal(tidy(rec_trained, 1), tidy_exp_tr)
})

test_that("multiple column prefix", {
  set.seed(1234)
  example_data <- tibble(
    x1 = rnorm(1000),
    x2 = rnorm(1000)
  )

  expect_snapshot(
    recipe(~., data = example_data) %>%
      step_discretize(x1, x2, options = list(prefix = "hello")) %>%
      prep()
  )

  expect_snapshot(
    error = TRUE,
    recipe(~., data = example_data) %>%
      step_discretize(x1, x2, options = list(labels = "hello")) %>%
      prep(),
    variant = r_version()
  )
})

test_that("bad args", {
  expect_snapshot(
    error = TRUE,
    recipe(~., data = ex_tr) %>%
      step_discretize(x1, num_breaks = 1) %>%
      prep()
  )
  expect_snapshot(
    recipe(~., data = ex_tr) %>%
      step_discretize(x1, num_breaks = 100) %>%
      prep()
  )
  expect_snapshot(
    recipe(~., data = ex_tr) %>%
      step_discretize(x1, options = list(prefix = "@$")) %>%
      prep()
  )
})

test_that("tunable", {
  rec <- recipe(~., data = iris) %>%
    step_discretize(all_predictors())
  rec_param <- tunable.step_discretize(rec$steps[[1]])
  expect_equal(rec_param$name, c("min_unique", "num_breaks"))
  expect_true(all(rec_param$source == "recipe"))
  # expect_type()/expect_named() report the actual value on failure, unlike
  # expect_true(is.list(...)) or comparing names() by hand.
  expect_type(rec_param$call_info, "list")
  expect_equal(nrow(rec_param), 2)
  expect_named(
    rec_param,
    c("name", "call_info", "source", "component", "component_id")
  )
})

# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
  rec <- recipe(cyl ~ ., mtcars)
  rec <- step_discretize(rec, mpg, min_unique = 3) %>%
    update_role(mpg, new_role = "potato") %>%
    update_role_requirements(role = "potato", bake = FALSE)
  rec <- prep(rec, mtcars)

  # mtcars[, 2:ncol(mtcars)] drops the first column, mpg, which the step
  # still needs at bake time.
  expect_error(
    bake(rec, new_data = mtcars[, 2:ncol(mtcars)]),
    class = "new_data_missing_column"
  )
})

test_that("empty printing", {
  rec <- recipe(mpg ~ ., mtcars)
  rec <- step_discretize(rec)

  expect_snapshot(rec)

  rec <- prep(rec, mtcars)

  expect_snapshot(rec)
})

test_that("empty selection prep/bake is a no-op", {
  rec1 <- recipe(mpg ~ ., mtcars)
  rec2 <- step_discretize(rec1)

  rec1 <- prep(rec1, mtcars)
  rec2 <- prep(rec2, mtcars)

  baked1 <- bake(rec1, mtcars)
  baked2 <- bake(rec2, mtcars)

  expect_identical(baked1, baked2)
})

test_that("empty selection tidy method works", {
  rec <- recipe(mpg ~ ., mtcars)
  rec <- step_discretize(rec)

  expect <- tibble(terms = character(), value = double(), id = character())

  expect_identical(tidy(rec, number = 1), expect)

  rec <- prep(rec, mtcars)

  expect_identical(tidy(rec, number = 1), expect)
})

test_that("printing", {
  rec <- recipe(~., data = ex_tr) %>%
    step_discretize(x1)

  expect_snapshot(print(rec))
  expect_snapshot(prep(rec))
})

test_that("tunable is setup to work with extract_parameter_set_dials", {
  skip_if_not_installed("dials")
  rec <- recipe(~., data = mtcars) %>%
    step_discretize(
      all_predictors(),
      min_unique = hardhat::tune(),
      num_breaks = hardhat::tune()
    )

  params <- extract_parameter_set_dials(rec)

  expect_s3_class(params, "parameters")
  expect_identical(nrow(params), 2L)
})