test_that("recipe merges", { data("Chicago", package = "modeldata") spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% recipes::step_date(date) %>% recipes::step_holiday(date) %>% recipes::step_rm(date, dplyr::ends_with("away")) %>% recipes::step_impute_knn(recipes::all_predictors(), neighbors = tune("imputation")) %>% recipes::step_other(recipes::all_nominal(), threshold = tune()) %>% recipes::step_dummy(recipes::all_nominal()) %>% recipes::step_normalize(recipes::all_numeric_predictors()) %>% recipes::step_bs(recipes::all_predictors(), deg_free = tune(), degree = tune()) spline_grid <- tibble::tribble( ~imputation, ~threshold, ~deg_free, ~degree, 3L, 0.088, 14L, 1L, 6L, 0.058, 8L, 1L, 8L, 0.051, 14L, 1L, 9L, 0.007, 10L, 1L, 1L, 0.032, 15L, 2L, 8L, 0.018, 9L, 2L, 1L, 0.036, 5L, 1L, 10L, 0.1, 10L, 2L, 9L, 0.094, 12L, 1L, 4L, 0.025, 12L, 1L ) expect_error( spline_updated <- merge(spline_rec, spline_grid), NA ) check_merged_tibble(spline_updated) for (i in 1:nrow(spline_grid)) { expect_equal( spline_updated$x[[i]]$steps[[4]]$neighbors, spline_grid$imputation[[i]] ) expect_equal( spline_updated$x[[i]]$steps[[5]]$threshold, spline_grid$threshold[[i]] ) expect_equal( spline_updated$x[[i]]$steps[[8]]$deg_free, spline_grid$deg_free[[i]] ) expect_equal( spline_updated$x[[i]]$steps[[8]]$degree, spline_grid$degree[[i]] ) } }) test_that("partially recipe merge", { data("Chicago", package = "modeldata") spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% recipes::step_date(date) %>% recipes::step_holiday(date) %>% recipes::step_rm(date, dplyr::ends_with("away")) %>% recipes::step_impute_knn(recipes::all_predictors(), neighbors = tune("imputation")) %>% recipes::step_other(recipes::all_nominal(), threshold = tune()) %>% recipes::step_dummy(recipes::all_nominal()) %>% recipes::step_normalize(recipes::all_numeric_predictors()) %>% recipes::step_bs(recipes::all_predictors(), deg_free = tune(), degree = tune()) spline_grid <- tibble::tribble( ~imputation, ~threshold, ~deg_free, ~degree, 3L, 0.088, 14L, 1L, 6L, 0.058, 8L, 1L, 8L, 0.051, 14L, 1L, 9L, 0.007, 10L, 1L, 1L, 0.032, 15L, 2L, 8L, 0.018, 9L, 2L, 1L, 0.036, 5L, 1L, 10L, 0.1, 10L, 2L, 9L, 0.094, 12L, 1L, 4L, 0.025, 12L, 1L ) expect_error( spline_updated <- merge(spline_rec, spline_grid[, -1]), NA ) check_merged_tibble(spline_updated, complete = FALSE) for (i in 1:nrow(spline_grid)) { expect_equal( spline_updated$x[[i]]$steps[[4]]$neighbors, tune("imputation") ) expect_equal( spline_updated$x[[i]]$steps[[5]]$threshold, spline_grid$threshold[[i]] ) expect_equal( spline_updated$x[[i]]$steps[[8]]$deg_free, spline_grid$deg_free[[i]] ) expect_equal( spline_updated$x[[i]]$steps[[8]]$degree, spline_grid$degree[[i]] ) } }) test_that("umerged recipe merge", { data("Chicago", package = "modeldata") spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% recipes::step_date(date) %>% recipes::step_holiday(date) %>% recipes::step_rm(date, dplyr::ends_with("away")) %>% recipes::step_impute_knn(recipes::all_predictors(), neighbors = tune("imputation")) %>% recipes::step_other(recipes::all_nominal(), threshold = tune()) %>% recipes::step_dummy(recipes::all_nominal()) %>% recipes::step_normalize(recipes::all_numeric_predictors()) %>% recipes::step_bs(recipes::all_predictors(), deg_free = tune(), degree = tune()) bst_grid <- tibble::tibble("funky name \n" = 1:4, rules = rep(c(TRUE, FALSE), each = 2)) expect_error( spline_updated <- merge(spline_rec, bst_grid), NA ) check_merged_tibble(spline_updated, complete = FALSE) for (i in 1:nrow(bst_grid)) { expect_equal( spline_updated$x[[i]]$steps[[4]]$neighbors, tune("imputation") ) expect_equal( spline_updated$x[[i]]$steps[[5]]$threshold, tune() ) expect_equal( spline_updated$x[[i]]$steps[[8]]$deg_free, tune() ) expect_equal( spline_updated$x[[i]]$steps[[8]]$degree, tune() ) } }) # ------------------------------------------------------------------------------ test_that("model spec merges", { bst_model <- parsnip::boost_tree(mode = "classification", trees = tune("funky name \n")) %>% parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE) bst_grid <- tibble::tibble("funky name \n" = 1:4, rules = rep(c(TRUE, FALSE), each = 2)) expect_error( bst_updated <- merge(bst_model, bst_grid), NA ) check_merged_tibble(bst_updated, "model_spec") for (i in 1:nrow(bst_grid)) { expect_equal( bst_updated$x[[i]]$args$trees, rlang::as_quosure(bst_grid[["funky name \n"]][[i]], empty_env()) ) expect_equal( bst_updated$x[[i]]$eng_args$rules, rlang::as_quosure(bst_grid$rules[[i]], empty_env()) ) } # ensure that `grid` can handle list-columns bst_model_obj <- boost_tree(mode = "classification") %>% set_args(objective = tune()) bst_grid_obj <- tibble::tibble(objective = list("hey", "there")) bst_updated_obj <- merge(bst_model_obj, bst_grid_obj) expect_equal( rlang::quo_get_expr(bst_updated_obj$x[[1]]$eng_args$objective), "hey" ) expect_equal( rlang::quo_get_expr(bst_updated_obj$x[[2]]$eng_args$objective), "there" ) }) test_that("partially model spec merge", { bst_model <- parsnip::boost_tree(mode = "classification", trees = tune("funky name \n")) %>% parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE) bst_grid <- tibble::tibble("funky name \n" = 1:4, rules = rep(c(TRUE, FALSE), each = 2)) expect_error( bst_updated <- merge(bst_model, bst_grid[, -1]), NA ) check_merged_tibble(bst_updated, "model_spec", complete = FALSE) for (i in 1:nrow(bst_grid)) { expect_equal( rlang::get_expr(bst_updated$x[[i]]$args$trees), tune("funky name \n") ) expect_equal( bst_updated$x[[i]]$eng_args$rules, rlang::as_quosure(bst_grid$rules[[i]], empty_env()) ) } }) test_that("umerged model spec merge", { bst_model <- parsnip::boost_tree(mode = "classification", trees = tune("funky name \n")) %>% parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE) bst_grid <- tibble::tibble("funky name \n" = 1:4, rules = rep(c(TRUE, FALSE), each = 2)) other_grid <- bst_grid names(bst_grid) <- letters[1:2] expect_error( bst_not_updated <- merge(bst_model, other_grid), NA ) check_merged_tibble(bst_not_updated, "model_spec", complete = FALSE) # for (i in 1:nrow(other_grid)) { # if (i == 1) { # print(rlang::get_expr(bst_not_updated$x[[i]]$args$trees)) # print(rlang::get_expr(bst_not_updated$x[[i]]$eng_args$rules)) # } # expect_equal( # rlang::get_expr(bst_not_updated$x[[i]]$args$trees), tune("funky name \n") # ) # expect_equal( # rlang::get_expr(bst_not_updated$x[[i]]$eng_args$rules), tune() # ) # } })