# Tests for the tree-structure, tree-visualization and tree-learner APIs of the
# iai package. Every test drives IAI through Julia (via JuliaCall) and is
# skipped on CRAN. Expected numbers are specific to the installed IAI version,
# so branches on iai:::iai_version_less_than() select the matching values.
context("IAITrees")

test_that("common structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.0.0")) {
    X <- JuliaCall::julia_eval(
      "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data())"
    )
    names(X) <- c("num_attempts", "score1", "score2", "score3",
                  "num_children", "region")
    y <- (X$score1 >= 50 & X$region %in% c("A", "B")) |
      (X$score1 < 50 &
         (X$score2 + 85 * X$score3 + 90 * (X$region == "E")) > 140)
    lnr <- iai::optimal_tree_classifier(
      random_seed = 1,
      max_depth = 2,
      cp = 0.04,
      hyperplane_config = list(sparsity = "all")
    )
  } else {
    if (iai:::iai_version_less_than("2.2.0")) {
      X <- JuliaCall::julia_eval(
        "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data(rng = IAI.IAIBase.make_rng(3)))"
      )
    } else {
      X <- JuliaCall::julia_eval(
        "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data(rng = IAI.IAIBase.make_rng(5)))"
      )
    }
    names(X) <- c("num_attempts", "score1", "score2", "score3",
                  "num_children", "region")
    y <- (X$score1 >= 60 & X$region %in% c("A", "B")) |
      (X$score1 < 60 &
         (X$score2 + 85 * X$score3 + 90 * (X$region == "E")) > 140)
    lnr <- iai::optimal_tree_classifier(
      random_seed = 1,
      max_depth = 2,
      cp = 0.01,
      hyperplane_config = list(sparsity = "all")
    )
  }

  iai::fit(lnr, X, y)

  expect_equal(iai::get_num_nodes(lnr), 7)
  expect_false(iai::is_leaf(lnr, 1))
  expect_equal(iai::get_depth(lnr, 6), 2)
  if (iai:::iai_version_less_than("2.0.0")) {
    expect_equal(iai::get_num_samples(lnr, 6), 97)
  } else if (iai:::iai_version_less_than("2.2.0")) {
    expect_equal(iai::get_num_samples(lnr, 6), 72)
  } else {
    expect_equal(iai::get_num_samples(lnr, 6), 78)
  }
  expect_equal(iai::get_parent(lnr, 2), 1)
  expect_equal(iai::get_lower_child(lnr, 1), 2)
  expect_equal(iai::get_upper_child(lnr, 1), 5)
  expect_true(iai::is_parallel_split(lnr, 1))
  expect_true(iai::is_hyperplane_split(lnr, 2))
  expect_true(iai::is_categoric_split(lnr, 5))
  expect_false(iai::is_ordinal_split(lnr, 1))
  expect_false(iai::is_mixed_parallel_split(lnr, 2))
  expect_false(iai::is_mixed_ordinal_split(lnr, 5))
  expect_false(iai::missing_goes_lower(lnr, 1))
  expect_equal(iai::get_split_feature(lnr, 1), as.symbol("score1"))
  if (iai:::iai_version_less_than("2.0.0")) {
    expect_equal(iai::get_split_threshold(lnr, 1), 50, tolerance = 0.5)
  } else {
    expect_equal(iai::get_split_threshold(lnr, 1), 60, tolerance = 0.5)
  }
  expect_mapequal(iai::get_split_categories(lnr, 5), list(
    A = TRUE, B = TRUE, C = FALSE, D = FALSE, E = FALSE
  ))

  weights <- iai::get_split_weights(lnr, 2)
  if (iai:::iai_version_less_than("2.0.0")) {
    expect_mapequal(weights$numeric, list(score2 = 0.010100076620278502,
                                          score3 = 2.0478494324868732))
    expect_mapequal(weights$categoric,
                    list(region = list(E = 1.5176596636410404)))
  } else if (iai:::iai_version_less_than("2.2.0")) {
    expect_mapequal(weights$numeric, list(score2 = 0.0012369248211116827,
                                          score3 = 0.09806740780674195))
    expect_mapequal(weights$categoric,
                    list(region = list(E = 0.10571515793193487)))
  } else {
    expect_mapequal(weights$numeric, list(score2 = 0.018901518025769143,
                                          score3 = 1.2041462082802483))
    expect_mapequal(weights$categoric,
                    list(region = list(E = 1.4792242450156097)))
  }
})

test_that("classification structure", {
  skip_on_cran()

  lnr <- JuliaCall::julia_eval(
    "IAI.OptimalTrees.load_iris_tree(random_seed=1)"
  )
  lnr <- iai:::set_obj_class(lnr)

  expect_equal(iai::get_classification_label(lnr, 2), "setosa")
  expect_mapequal(iai::get_classification_proba(lnr, 4), list(
    virginica = 0.09259259259259259,
    setosa = 0.0,
    versicolor = 0.9074074074074074
  ))

  # Node 1 is not a leaf, so leaf-only getters must refuse it unless
  # check_leaf = FALSE (supported from IAI 2.1.0).
  expect_error(iai::get_classification_label(lnr, 1))
  expect_error(iai::get_classification_proba(lnr, 1))
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::get_classification_label(lnr, 1, check_leaf = FALSE)
    iai::get_classification_proba(lnr, 1, check_leaf = FALSE)
  }

  if (iai:::iai_version_less_than("3.0.0")) {
    expect_error(iai::get_regression_constant(lnr, 2),
                 "requires IAI version 3.0.0")
    expect_error(iai::get_regression_weights(lnr, 2),
                 "requires IAI version 3.0.0")
  } else {
    expect_equal(iai::get_regression_constant(lnr, 2), NaN)
    weights <- iai::get_regression_weights(lnr, 2)
    expect_equal(length(weights$numeric), 0)
    expect_equal(length(weights$categoric), 0)
  }
})

test_that("regression structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.1.0")) {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_mtcars_tree(random_seed=1, regression_sparsity=\"all\", regression_lambda=0.2)"
    )
  } else if (iai:::iai_version_less_than("3.1.0")) {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_mtcars_tree(random_seed=1, regression_sparsity=\"all\", regression_lambda=0.02)"
    )
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_mtcars_tree(random_seed=1, regression_features=Set([\"All\"]), regression_lambda=0.02)"
    )
  }
  lnr <- iai:::set_obj_class(lnr)

  expect_equal(iai::get_regression_constant(lnr, 2), 30.879999999999995)

  weights <- iai::get_regression_weights(lnr, 3)
  if (iai:::iai_version_less_than("2.1.0")) {
    expect_equal(iai::get_regression_constant(lnr, 3), 26.56192034262967)
    expect_mapequal(weights$numeric, list(Disp = -0.021044493648366,
                                          HP = -0.018861409939436))
  } else {
    expect_equal(iai::get_regression_constant(lnr, 3), 30.887599089534906)
    expect_mapequal(weights$numeric, list(Cyl = -0.794565711367838,
                                          Gear = 0.058519556715652,
                                          HP = -0.012667192837728,
                                          WT = -1.649738918131852))
  }
  # No categoric features are used by the regression in node 3.
  expect_true(is.list(weights$categoric) && length(weights$categoric) == 0)

  expect_error(iai::get_regression_constant(lnr, 1))
  expect_error(iai::get_regression_weights(lnr, 1))
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::get_regression_constant(lnr, 1, check_leaf = FALSE)
    iai::get_regression_weights(lnr, 1, check_leaf = FALSE)
  }
})

test_that("survival structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.0.0")) {
    iai::set_julia_seed(4)
    lnr <- JuliaCall::julia_eval("IAI.OptimalTrees.load_survival_tree()")
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_survival_tree(random_seed=1, max_depth=1, cp=0)"
    )
  }
  lnr <- iai:::set_obj_class(lnr)

  curve <- iai::get_survival_curve(lnr, 2)
  expect_equal(class(curve), c(
    "survival_curve",
    "IAIObject",
    "JuliaObject"
  ))

  if (iai:::iai_version_less_than("2.2.0")) {
    expect_error(iai::predict_expected_survival_time(curve),
                 "requires IAI version 2.2.0")
  } else {
    expect_true(is.numeric(iai::predict_expected_survival_time(curve)))
  }

  curve_data <- iai::get_survival_curve_data(curve)
  if (iai:::iai_version_less_than("2.0.0")) {
    expect_equal(curve_data$coefs, c(
      0.00000000, 0.02380952, 0.02597403, 0.02813853, 0.03030303, 0.03246753,
      0.03463203, 0.03679654, 0.03896104, 0.04112554, 0.04329004, 0.04545455,
      0.06734007, 0.08922559, 0.11111111, 0.13299663, 0.15488215, 0.17676768,
      0.19865320, 0.22053872, 0.24242424))
    expect_equal(curve_data$times, c(
      00000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000,
      20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000,
      30000))
  } else if (iai:::iai_version_less_than("2.2.0")) {
    expect_equal(curve_data$coefs, c(
      0.000000, 0.003472, 0.024306, 0.066330, 0.098207, 0.112815, 0.123968,
      0.150682, 0.193501, 0.225786, 0.260612, 0.316929, 0.377445, 0.427421,
      0.462996, 0.501441, 0.547923, 0.575877, 0.618289, 0.685522, 0.749382,
      0.862943), tolerance = 1e-6)
    expect_equal(curve_data$times, c(
      00000, 06000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000,
      19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000,
      29000, 30000))
  } else {
    expect_equal(curve_data$coefs, c(
      0.000000, 0.005814, 0.023256, 0.052538, 0.076310, 0.094997, 0.114033,
      0.140796, 0.168202, 0.190038, 0.220994, 0.271124, 0.340665, 0.395205,
      0.433004, 0.482354, 0.523395, 0.546250, 0.597468, 0.669475, 0.727317,
      0.845066), tolerance = 1e-6)
    expect_equal(curve_data$times, c(
      00000, 7000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000,
      19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000,
      29000, 30000))
  }

  if (iai:::iai_version_less_than("2.1.0")) {
    expect_error(iai::get_survival_expected_time(),
                 "requires IAI version 2.1.0")
    expect_error(iai::get_survival_hazard(), "requires IAI version 2.1.0")
  } else if (iai:::iai_version_less_than("2.2.0")) {
    expect_equal(iai::get_survival_expected_time(lnr, 2), 22981.39)
    expect_equal(iai::get_survival_hazard(lnr, 2), 0.9541041, tolerance = 1e-6)
  } else {
    expect_equal(iai::get_survival_expected_time(lnr, 2), 23443.187)
    expect_equal(iai::get_survival_hazard(lnr, 2), 0.8880508, tolerance = 1e-6)
  }

  expect_error(iai::get_survival_curve(lnr, 1))
  expect_error(iai::get_survival_expected_time(lnr, 1))
  expect_error(iai::get_survival_hazard(lnr, 1))
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::get_survival_curve(lnr, 1, check_leaf = FALSE)
    iai::get_survival_expected_time(lnr, 1, check_leaf = FALSE)
    iai::get_survival_hazard(lnr, 1, check_leaf = FALSE)
  }

  if (iai:::iai_version_less_than("3.0.0")) {
    expect_error(iai::get_regression_constant(lnr, 2),
                 "requires IAI version 3.0.0")
    expect_error(iai::get_regression_weights(lnr, 2),
                 "requires IAI version 3.0.0")
  } else {
    expect_equal(iai::get_regression_constant(lnr, 2), NaN)
    weights <- iai::get_regression_weights(lnr, 2)
    expect_equal(length(weights$numeric), 0)
    expect_equal(length(weights$categoric), 0)
  }
})

test_that("prescription structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.0.0")) {
    iai::set_julia_seed(2)
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\", regression_lambda=0.22, max_depth=2)"
    )
  } else if (iai:::iai_version_less_than("2.1.0")) {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\", regression_lambda=0.22, max_depth=2, random_seed=1)"
    )
  } else if (iai:::iai_version_less_than("2.2.0")) {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\", regression_weighted_betas=true, regression_lambda=1.9, max_depth=2, random_seed=1)"
    )
  } else if (iai:::iai_version_less_than("3.1.0")) {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\", regression_weighted_betas=true, regression_lambda=1.9, max_depth=2, random_seed=2)"
    )
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_prescription_tree(
        regression_features=Set([\"All\"]),
        regression_weighted_betas=true,
        regression_lambda=1.9,
        max_depth=2,
        random_seed=2,
      )"
    )
  }
  lnr <- iai:::set_obj_class(lnr)

  # The leaf under test and its expected values differ between IAI versions
  # because each version grows a different tree.
  if (iai:::iai_version_less_than("2.0.0")) {
    weights <- iai::get_regression_weights(lnr, 5, 1)
    expect_mapequal(weights$numeric, list(Disp = -0.007198454096246))
    expect_equal(iai::get_prescription_treatment_rank(lnr, 2), c(1, 0))
    expect_equal(iai::get_regression_constant(lnr, 2, 0), 30.5)
  } else if (iai:::iai_version_less_than("2.1.0")) {
    weights <- iai::get_regression_weights(lnr, 5, 0)
    expect_mapequal(weights$numeric, list(Disp = -0.00853409230131,
                                          AM = 1.316408317777783))
    expect_equal(iai::get_prescription_treatment_rank(lnr, 5), c(1, 0))
    expect_equal(iai::get_regression_constant(lnr, 5, 0), 18.507454507299066)
  } else if (iai:::iai_version_less_than("2.2.0")) {
    weights <- iai::get_regression_weights(lnr, 4, 0)
    expect_mapequal(weights$numeric, list(Cyl = -0.189847291283807))
    expect_equal(iai::get_prescription_treatment_rank(lnr, 4), c(0, 1))
    expect_equal(iai::get_regression_constant(lnr, 4, 0), 20.7970532059596)
  } else {
    weights <- iai::get_regression_weights(lnr, 2, 0)
    expect_mapequal(weights$numeric, list(Cyl = -1.377692110219233))
    expect_equal(iai::get_prescription_treatment_rank(lnr, 2), c(0, 1))
    expect_equal(iai::get_regression_constant(lnr, 2, 0), 28.682819327982067)
  }
  expect_true(is.list(weights$categoric) && length(weights$categoric) == 0)

  expect_error(iai::get_prescription_treatment_rank(lnr, 1))
  expect_error(iai::get_regression_constant(lnr, 1, 0))
  expect_error(iai::get_regression_weights(lnr, 1, 0))
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::get_prescription_treatment_rank(lnr, 1, check_leaf = FALSE)
    iai::get_regression_constant(lnr, 1, 0, check_leaf = FALSE)
    iai::get_regression_weights(lnr, 1, 0, check_leaf = FALSE)
  }
})

test_that("policy structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.0.0")) {
    expect_error(iai::get_policy_treatment_rank(),
                 "requires IAI version 2.0.0")
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_policy_tree(max_depth=2, random_seed=1)"
    )
    lnr <- iai:::set_obj_class(lnr)
    expect_equal(iai::get_policy_treatment_rank(lnr, 3), c("A", "C", "B"))
  }

  if (iai:::iai_version_less_than("2.1.0")) {
    expect_error(iai::get_policy_treatment_outcome(),
                 "requires IAI version 2.1.0")
  } else {
    outcomes <- iai::get_policy_treatment_outcome(lnr, 3)
    if (iai:::iai_version_less_than("2.2.0")) {
      expect_equal(outcomes$A, 0.8276032, tolerance = 1e-6)
      expect_equal(outcomes$B, 1.698339, tolerance = 1e-6)
      expect_equal(outcomes$C, 1.096775, tolerance = 1e-6)
    } else {
      expect_equal(outcomes$A, 0.827778, tolerance = 1e-6)
      expect_equal(outcomes$B, 1.70248, tolerance = 1e-5)
      expect_equal(outcomes$C, 1.09849, tolerance = 1e-5)
    }
  }

  if (iai:::iai_version_less_than("3.2.0")) {
    expect_error(iai::get_policy_treatment_outcome_standard_error(),
                 "requires IAI version 3.2.0")
  } else {
    errors <- iai::get_policy_treatment_outcome_standard_error(lnr, 3)
    expect_equal(errors$A, 0.0777876, tolerance = 1e-5)
    expect_equal(errors$B, 0.083841, tolerance = 1e-5)
    expect_equal(errors$C, 0.10806, tolerance = 1e-5)
  }

  # `lnr` only exists for IAI >= 2.0.0; without this guard the checks below
  # would "pass" on older versions merely because `lnr` is undefined.
  if (!iai:::iai_version_less_than("2.0.0")) {
    expect_error(iai::get_policy_treatment_rank(lnr, 1))
    expect_error(iai::get_policy_treatment_outcome(lnr, 1))
    expect_error(iai::get_policy_treatment_outcome_standard_error(lnr, 1))
  }
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::get_policy_treatment_rank(lnr, 1, check_leaf = FALSE)
    iai::get_policy_treatment_outcome(lnr, 1, check_leaf = FALSE)
  }
  if (!iai:::iai_version_less_than("3.2.0")) {
    iai::get_policy_treatment_outcome_standard_error(lnr, 1,
                                                     check_leaf = FALSE)
  }
})

test_that("visualization", {
  skip_on_cran()

  # Write every artifact under the R session temp directory (removed when the
  # session ends) rather than the current working directory.
  out_path <- function(name) file.path(tempdir(), name)
  # Assert that `path` was produced, then delete it.
  expect_created_then_remove <- function(path) {
    expect_true(file.exists(path), info = path)
    invisible(file.remove(path))
  }

  lnr <- JuliaCall::julia_eval("IAI.OptimalTrees.load_iris_tree()")
  lnr <- iai:::set_obj_class(lnr)
  # One `node_color` entry per node of the tree.
  extra_content <- replicate(iai::get_num_nodes(lnr),
                             list("node_color" = "#FFFFFF"),
                             simplify = FALSE)

  if (!iai:::iai_version_less_than("3.1.0")) {
    if (!JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()")) {
      iai::load_graphviz()
    }
    expect_true(JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()"))
  }

  # Image output (png/pdf/svg) is only exercised when graphviz is available.
  if (JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()")) {
    png_file <- out_path("test.png")
    iai::write_png(png_file, lnr)
    expect_created_then_remove(png_file)
    if (!iai:::iai_version_less_than("2.1.0")) {
      iai::write_png(png_file, lnr, extra_content = extra_content)
      expect_created_then_remove(png_file)
    }

    pdf_file <- out_path("test.pdf")
    svg_file <- out_path("test.svg")
    if (iai:::iai_version_less_than("2.1.0")) {
      error_message <- "requires IAI version 2.1.0"
      expect_error(iai::write_pdf(pdf_file, lnr), error_message)
      expect_error(iai::write_svg(svg_file, lnr), error_message)
    } else {
      iai::write_pdf(pdf_file, lnr)
      expect_created_then_remove(pdf_file)
      iai::write_pdf(pdf_file, lnr, extra_content = extra_content)
      expect_created_then_remove(pdf_file)

      iai::write_svg(svg_file, lnr)
      expect_created_then_remove(svg_file)
      iai::write_svg(svg_file, lnr, extra_content = extra_content)
      expect_created_then_remove(svg_file)
    }
  }

  dot_file <- out_path("test.dot")
  iai::write_dot(dot_file, lnr)
  expect_created_then_remove(dot_file)
  if (!iai:::iai_version_less_than("2.1.0")) {
    iai::write_dot(dot_file, lnr, extra_content = extra_content)
    expect_created_then_remove(dot_file)
  }

  # Without `data`, the HTML must not mention "Target" or "Results" (compare
  # the `data = list(X, y)` case further down, where both must appear).
  tree_file <- out_path("tree.html")
  iai::write_html(tree_file, lnr)
  expect_true(file.exists(tree_file))
  lines <- readLines(tree_file)
  expect_false(any(grepl("\"Target\"", lines)))
  expect_false(any(grepl("\"Results\"", lines)))
  file.remove(tree_file)

  iai::write_html(tree_file, lnr, extra_content = extra_content)
  expect_created_then_remove(tree_file)

  question_file <- out_path("question.html")
  iai::write_questionnaire(question_file, lnr)
  expect_created_then_remove(question_file)

  multitree_file <- out_path("multitree.html")
  multiquestion_file <- out_path("multiquestion.html")

  if (iai:::iai_version_less_than("1.1.0")) {
    expect_error(iai::tree_plot(lnr), "requires IAI version 1.1.0")
    expect_error(iai::questionnaire(lnr), "requires IAI version 1.1.0")
    expect_error(iai::multi_tree_plot(list()), "requires IAI version 1.1.0")
    expect_error(iai::multi_questionnaire(list()),
                 "requires IAI version 1.1.0")
  } else {
    feature_renames <- list(
      "PetalLength" = "A",
      "PetalWidth" = "B",
      "SepalWidth" = "C"
    )

    vis <- iai::tree_plot(lnr, feature_renames = feature_renames)
    expect_equal(class(vis), c(
      "tree_plot",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    rename_tree_file <- out_path("tree_rename.html")
    iai::write_html(rename_tree_file, vis)
    expect_created_then_remove(rename_tree_file)

    vis <- iai::questionnaire(lnr, feature_renames = feature_renames)
    expect_equal(class(vis), c(
      "questionnaire",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    rename_question_file <- out_path("questionnaire_rename.html")
    iai::write_html(rename_question_file, vis)
    expect_created_then_remove(rename_question_file)

    questions <- list("Use learner with" = list(
      "renamed features" = lnr,
      "extra text output" = lnr
    ))

    vis <- iai::multi_tree_plot(questions)
    expect_equal(class(vis), c(
      "multi_tree_plot",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    iai::write_html(multitree_file, vis)
    expect_created_then_remove(multitree_file)

    vis <- iai::multi_questionnaire(questions)
    expect_equal(class(vis), c(
      "multi_questionnaire",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    iai::write_html(multiquestion_file, vis)
    expect_created_then_remove(multiquestion_file)
  }

  X <- iris[, 1:4]
  y <- iris$Species
  grid <- iai::grid_search(
    iai::optimal_tree_classifier(
      random_seed = 1,
      max_depth = 1
    )
  )
  iai::fit(grid, X, y)

  if (iai:::iai_version_less_than("2.0.0")) {
    grid_file <- out_path("grid.html")
    expect_error(iai::write_html(grid_file, grid),
                 "requires IAI version 2.0.0")
    expect_error(iai::write_questionnaire(grid_file, grid),
                 "requires IAI version 2.0.0")
    expect_error(iai::show_in_browser(grid), "requires IAI version 2.0.0")
    expect_error(iai::show_questionnaire(grid), "requires IAI version 2.0.0")
  } else {
    vis <- iai::multi_tree_plot(grid)
    expect_equal(class(vis), c(
      "multi_tree_plot",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    iai::write_html(multitree_file, vis)
    expect_created_then_remove(multitree_file)

    vis <- iai::multi_questionnaire(grid)
    expect_equal(class(vis), c(
      "multi_questionnaire",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
    iai::write_html(multiquestion_file, vis)
    expect_created_then_remove(multiquestion_file)
  }

  # Data visualization: reuse the grid fitted above on the same X and y.
  if (!iai:::iai_version_less_than("2.1.0")) {
    lnr <- iai::get_learner(grid)
    data_file <- out_path("tree_with_data.html")

    iai::write_html(data_file, lnr, data = list(X, y))
    lines <- readLines(data_file)
    expect_true(any(grepl("\"Target\"", lines)))
    expect_true(any(grepl("\"Results\"", lines)))
    file.remove(data_file)

    # Passing features only (no target) is supported from IAI 2.2.0.
    if (iai:::iai_version_less_than("2.2.0")) {
      expect_error(iai::write_html(data_file, lnr, data = X))
    } else {
      iai::write_html(data_file, lnr, data = X)
      lines <- readLines(data_file)
      expect_false(any(grepl("\"Target\"", lines)))
      expect_true(any(grepl("\"Results\"", lines)))
      file.remove(data_file)
    }
  }
})

test_that("tree API", {
  skip_on_cran()

  X <- iris[, 1:4]
  y <- iris$Species
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, X, y)

  expect_equal(length(iai::apply(lnr, X)), length(y))
  expect_equal(length(iai::apply_nodes(lnr, X)), iai::get_num_nodes(lnr))

  path <- iai::decision_path(lnr, X)
  expect_equal(nrow(path), length(y))
  expect_equal(ncol(path), iai::get_num_nodes(lnr))

  iai::print_path(lnr, X, 1)

  expect_true(is.data.frame(iai::variable_importance(lnr)))

  if (iai:::iai_version_less_than("2.2.0")) {
    expect_error(iai::get_features_used(lnr), "requires IAI version 2.2.0")
  } else {
    expect_true(is.vector(iai::get_features_used(lnr)))
  }
})

test_that("classification tree API", {
  skip_on_cran()

  X <- iris[, 1:4]
  y <- iris$Species == "setosa"
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, X, y)

  expect_true(is.data.frame(iai::predict_proba(lnr, X)))

  expect_equal(iai::get_num_nodes(lnr), 3)
  iai::set_threshold(lnr, TRUE, 0, simplify = TRUE)
  expect_equal(iai::get_num_nodes(lnr), 1)

  # "true)" is matched literally, not as a regular expression.
  iai::set_display_label(lnr, TRUE)
  expect_true(grepl("true)", print(lnr), fixed = TRUE))
  iai::reset_display_label(lnr)
  expect_false(grepl("true)", print(lnr), fixed = TRUE))
})

test_that("survival tree API", {
  skip_on_cran()

  lnr <- iai::optimal_tree_survival_learner(max_depth = 1, cp = 0)
  n <- 100
  X <- matrix(rnorm(2 * n), n, 2)
  died <- rbinom(n, 1, 0.5) == 1
  times <- runif(n)
  iai::fit(lnr, X, died, times)

  if (iai:::iai_version_less_than("1.2.0")) {
    expect_error(iai::predict_hazard(lnr, X), "requires IAI version 1.2.0")
  } else {
    expect_equal(length(iai::predict_hazard(lnr, X)), n)
  }

  if (iai:::iai_version_less_than("2.0.0")) {
    expect_error(iai::predict_expected_survival_time(lnr, X),
                 "requires IAI version 2.0.0")
  } else {
    expect_equal(length(iai::predict_expected_survival_time(lnr, X)), n)
  }
})

test_that("prescription tree API", {
  skip_on_cran()

  learners <- list(iai::optimal_tree_prescription_minimizer,
                   iai::optimal_tree_prescription_maximizer)
  for (f in learners) {
    lnr <- f(max_depth = 1, cp = 0)
    X <- matrix(rnorm(200), 100, 2)
    treatments <- rbinom(100, 1, 0.5)
    outcomes <- runif(100)
    iai::fit(lnr, X, treatments, outcomes)

    expect_true(is.data.frame(iai::predict_outcomes(lnr, X)))

    pred <- iai::predict(lnr, X)
    expect_true(is.list(pred))
    expect_equal(names(pred), c("treatments", "outcomes"))
  }
})

test_that("policy tree API", {
  skip_on_cran()

  if (iai:::iai_version_less_than("2.0.0")) {
    expect_error(iai::optimal_tree_policy_minimizer(),
                 "requires IAI version 2.0.0")
    expect_error(iai::optimal_tree_policy_maximizer(),
                 "requires IAI version 2.0.0")
  } else {
    learners <- list(iai::optimal_tree_policy_minimizer,
                     iai::optimal_tree_policy_maximizer)
    for (f in learners) {
      lnr <- f(max_depth = 1, cp = 0)
      X <- matrix(rnorm(200), 100, 2)
      rewards <- matrix(rnorm(200), 100, 2)
      iai::fit(lnr, X, rewards)

      expect_true(is.vector(iai::predict(lnr, X)))
      expect_true(is.vector(iai::predict_outcomes(lnr, X, rewards)))
    }
  }
})

test_that("stability", {
  skip_on_cran()

  X <- iris[, 1:4]
  y <- iris$Species == "setosa"
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, X, y)

  if (iai:::iai_version_less_than("2.2.0")) {
    expect_error(iai::stability_analysis(lnr, X, y),
                 "requires IAI version 2.2.0")
  } else {
    stability <- iai::stability_analysis(lnr, X, y)
    expect_equal(class(stability), c(
      "stability_analysis",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
  }

  deviations <- runif(4)
  if (iai:::iai_version_less_than("2.2.0")) {
    expect_error(iai::similarity_comparison(lnr, lnr, deviations),
                 "requires IAI version 2.2.0")
  } else {
    similarity <- iai::similarity_comparison(lnr, lnr, deviations)
    expect_equal(class(similarity), c(
      "similarity_comparison",
      "abstract_visualization",
      "IAIObject",
      "JuliaObject"
    ))
  }
})

test_that("multi classification structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("3.2.0")) {
    expect_error(iai::optimal_tree_multi_classifier(),
                 "requires IAI version 3.2.0")
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_iris_tree_multi(random_seed=1, max_depth=1)"
    )
    lnr <- iai:::set_obj_class(lnr)

    # Without a task name the getters return a list keyed by task; with a
    # task name ("y1") they return that single task's value.
    label_all <- iai::get_classification_label(lnr, 3)
    expect_true(is.list(label_all))
    label_single <- iai::get_classification_label(lnr, 3, "y1")
    expect_true(is.character(label_single))
    expect_equal(label_all$y1, label_single)

    proba_all <- iai::get_classification_proba(lnr, 3)
    expect_true(is.list(proba_all))
    proba_single <- iai::get_classification_proba(lnr, 3, "y1")
    expect_true(is.list(proba_single))
    expect_equal(proba_all$y1, proba_single)

    const_all <- iai::get_regression_constant(lnr, 3)
    expect_true(is.list(const_all))
    const_single <- iai::get_regression_constant(lnr, 3, "y1")
    expect_true(is.nan(const_single))
    expect_true(is.nan(const_all$y1))

    weights_all <- iai::get_regression_weights(lnr, 3)
    expect_true(is.list(weights_all))
    weights_single <- iai::get_regression_weights(lnr, 3, "y1")
    expect_true(is.list(weights_single))
    expect_equal(length(weights_single), 2)
    expect_true(is.list(weights_single$numeric))
    expect_true(is.list(weights_single$categoric))
    expect_equal(length(weights_single$numeric), 0)
    expect_equal(length(weights_single$categoric), 0)
    expect_equal(weights_all$y1, weights_single)
  }
})

test_that("multi regression structure", {
  skip_on_cran()

  if (iai:::iai_version_less_than("3.2.0")) {
    expect_error(iai::optimal_tree_multi_regressor(),
                 "requires IAI version 3.2.0")
  } else {
    lnr <- JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_mtcars_tree_multi(
        random_seed=1,
        regression_lambda=0.01,
        regression_features=Set([\"All\"]),
        max_depth=1,
      )"
    )
    lnr <- iai:::set_obj_class(lnr)

    const_all <- iai::get_regression_constant(lnr, 1)
    expect_true(is.list(const_all))
    const_single <- iai::get_regression_constant(lnr, 1, "MPG")
    expect_true(is.numeric(const_single))
    expect_equal(const_all$MPG, const_single)

    weights_all <- iai::get_regression_weights(lnr, 1)
    expect_true(is.list(weights_all))
    weights_single <- iai::get_regression_weights(lnr, 1, "MPG")
    expect_true(is.list(weights_single))
    expect_equal(length(weights_single), 2)
    expect_true(is.list(weights_single$numeric))
    expect_true(is.list(weights_single$categoric))
    expect_equal(weights_all$MPG, weights_single)
  }
})