# Set OMP_THREAD_LIMIT to 1 while these tests run; withr undoes the change
# when the enclosing scope exits.
# NOTE(review): presumably this keeps the engine single-threaded on check
# machines -- confirm the motivation.
withr::local_envvar(OMP_THREAD_LIMIT = 1)

# Every parsnip fit in this test is compared against a direct
# partykit::ctree() call made with the equivalent arguments: first the
# `fitted` component of the two model objects, then their predictions on the
# training data.
test_that("condition inference trees", {
  skip_if_not_installed("partykit")
  skip_if_not_installed("modeldata")

  suppressPackageStartupMessages(library(partykit))

  # Snapshots of the model specifications themselves.
  expect_snapshot(
    decision_tree() |> set_engine("partykit") |> set_mode("regression")
  )
  expect_snapshot(
    decision_tree() |>
      set_engine("partykit", teststat = "maximum") |>
      set_mode("classification")
  )

  # ----------------------------------------------------------------------------
  # regression

  # Default settings.
  expect_no_error({
    tree_spec_default <- decision_tree() |>
      set_engine("partykit") |>
      set_mode("regression")
    parsnip_tree_default <- fit(tree_spec_default, mpg ~ ., data = mtcars)
  })
  direct_tree_default <- ctree(mpg ~ ., data = mtcars)
  expect_equal(direct_tree_default$fitted, parsnip_tree_default$fit$fitted)

  expect_no_error(
    parsnip_pred_default <- predict(parsnip_tree_default, mtcars)$.pred
  )
  direct_pred_default <- predict(direct_tree_default, mtcars) |> unname()
  expect_equal(direct_pred_default, parsnip_pred_default)

  # parsnip's `tree_depth` main argument vs. `maxdepth` in ctree_control().
  expect_no_error({
    tree_spec_shallow <- decision_tree(tree_depth = 1) |>
      set_engine("partykit") |>
      set_mode("regression")
    parsnip_tree_shallow <- fit(tree_spec_shallow, mpg ~ ., data = mtcars)
  })
  direct_tree_shallow <- ctree(
    mpg ~ .,
    data = mtcars,
    control = ctree_control(maxdepth = 1)
  )
  expect_equal(direct_tree_shallow$fitted, parsnip_tree_shallow$fit$fitted)

  expect_no_error(
    parsnip_pred_shallow <- predict(parsnip_tree_shallow, mtcars)$.pred
  )
  direct_pred_shallow <- predict(direct_tree_shallow, mtcars) |> unname()
  expect_equal(direct_pred_shallow, parsnip_pred_shallow)

  # Engine-specific argument (`mincriterion`) supplied through set_engine().
  expect_no_error({
    tree_spec_strict <- decision_tree() |>
      set_engine("partykit", mincriterion = 0.99) |>
      set_mode("regression")
    parsnip_tree_strict <- fit(tree_spec_strict, mpg ~ ., data = mtcars)
  })
  direct_tree_strict <- ctree(
    mpg ~ .,
    data = mtcars,
    control = ctree_control(mincriterion = 0.99)
  )
  expect_equal(direct_tree_strict$fitted, parsnip_tree_strict$fit$fitted)

  expect_no_error(
    parsnip_pred_strict <- predict(parsnip_tree_strict, mtcars)$.pred
  )
  direct_pred_strict <- predict(direct_tree_strict, mtcars) |> unname()
  expect_equal(direct_pred_strict, parsnip_pred_strict)

  # ----------------------------------------------------------------------------
  # classification

  data(ad_data, package = "modeldata")

  expect_no_error({
    tree_spec_class <- decision_tree() |>
      set_engine("partykit") |>
      set_mode("classification")
    parsnip_tree_class <- fit(tree_spec_class, Class ~ ., data = ad_data)
  })
  direct_tree_class <- ctree(Class ~ ., data = ad_data)
  expect_equal(direct_tree_class$fitted, parsnip_tree_class$fit$fitted)

  # Hard class predictions.
  expect_no_error(
    parsnip_class_pred <- predict(parsnip_tree_class, ad_data)$.pred_class
  )
  direct_class_pred <- predict(direct_tree_class, ad_data) |> unname()
  expect_equal(direct_class_pred, parsnip_class_pred)

  # Probabilities: second column of each prediction object.
  expect_no_error(
    parsnip_class_prob <-
      predict(parsnip_tree_class, ad_data, type = "prob")[[2]]
  )
  direct_class_prob <-
    predict(direct_tree_class, ad_data, type = "prob")[, 2] |> unname()
  expect_equal(direct_class_prob, parsnip_class_prob)
})

# Same comparisons as for the trees, but against partykit::cforest(). The seed
# is reset to 1 immediately before each parsnip fit and again before the
# matching direct fit, so both start from the same RNG state.
test_that("condition inference forests", {
  skip_if_not_installed("partykit")
  skip_if_not_installed("modeldata")

  suppressPackageStartupMessages(library(partykit))

  # Snapshots of the model specifications themselves.
  expect_snapshot(
    rand_forest() |> set_engine("partykit") |> set_mode("regression")
  )
  expect_snapshot(
    rand_forest() |>
      set_engine("partykit", teststat = "maximum") |>
      set_mode("classification")
  )

  # ----------------------------------------------------------------------------
  # regression

  # parsnip's `trees` main argument vs. `ntree` in cforest().
  expect_no_error({
    set.seed(1)
    forest_spec_default <- rand_forest(trees = 5) |>
      set_engine("partykit") |>
      set_mode("regression")
    parsnip_forest_default <- fit(forest_spec_default, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  direct_forest_default <- cforest(mpg ~ ., data = mtcars, ntree = 5)
  expect_equal(direct_forest_default$fitted, parsnip_forest_default$fit$fitted)

  expect_no_error(
    parsnip_pred_default <- predict(parsnip_forest_default, mtcars)$.pred
  )
  direct_pred_default <- predict(direct_forest_default, mtcars) |> unname()
  expect_equal(direct_pred_default, parsnip_pred_default)

  # `mtry` given on both sides.
  expect_no_error({
    set.seed(1)
    forest_spec_mtry <- rand_forest(trees = 5, mtry = 2) |>
      set_engine("partykit") |>
      set_mode("regression")
    parsnip_forest_mtry <- fit(forest_spec_mtry, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  direct_forest_mtry <- cforest(mpg ~ ., data = mtcars, ntree = 5, mtry = 2)
  expect_equal(direct_forest_mtry$fitted, parsnip_forest_mtry$fit$fitted)

  expect_no_error(
    parsnip_pred_mtry <- predict(parsnip_forest_mtry, mtcars)$.pred
  )
  direct_pred_mtry <- predict(direct_forest_mtry, mtcars) |> unname()
  expect_equal(direct_pred_mtry, parsnip_pred_mtry)

  # Engine-specific argument (`mincriterion`) supplied through set_engine().
  expect_no_error({
    set.seed(1)
    forest_spec_strict <- rand_forest(trees = 5) |>
      set_engine("partykit", mincriterion = 0.99) |>
      set_mode("regression")
    parsnip_forest_strict <- fit(forest_spec_strict, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  direct_forest_strict <- cforest(
    mpg ~ .,
    data = mtcars,
    ntree = 5,
    control = ctree_control(mincriterion = 0.99)
  )
  expect_equal(direct_forest_strict$fitted, parsnip_forest_strict$fit$fitted)

  expect_no_error(
    parsnip_pred_strict <- predict(parsnip_forest_strict, mtcars)$.pred
  )
  direct_pred_strict <- predict(direct_forest_strict, mtcars) |> unname()
  expect_equal(direct_pred_strict, parsnip_pred_strict)

  # ----------------------------------------------------------------------------
  # classification

  data(ad_data, package = "modeldata")

  expect_no_error({
    set.seed(1)
    forest_spec_class <- rand_forest(trees = 5) |>
      set_engine("partykit") |>
      set_mode("classification")
    parsnip_forest_class <- fit(forest_spec_class, Class ~ ., data = ad_data)
  })
  set.seed(1)
  direct_forest_class <- cforest(Class ~ ., data = ad_data, ntree = 5)
  expect_equal(direct_forest_class$fitted, parsnip_forest_class$fit$fitted)

  # Hard class predictions.
  expect_no_error(
    parsnip_class_pred <- predict(parsnip_forest_class, ad_data)$.pred_class
  )
  direct_class_pred <- predict(direct_forest_class, ad_data) |> unname()
  expect_equal(direct_class_pred, parsnip_class_pred)

  # Probabilities: second column of each prediction object.
  expect_no_error(
    parsnip_class_prob <-
      predict(parsnip_forest_class, ad_data, type = "prob")[[2]]
  )
  direct_class_prob <-
    predict(direct_forest_class, ad_data, type = "prob")[, 2] |> unname()
  expect_equal(direct_class_prob, parsnip_class_prob)
})