context("Multiplicities") test_that("Multiplicity class and methods", { mp = Multiplicity(0) expect_multiplicity(mp) expect_equal(mp, as.Multiplicity(0)) expect_error(assert_multiplicity(0, .var.name = "x"), regexp = "inherit from class 'Multiplicity'") expect_multiplicity(assert_multiplicity(mp)) nmp = Multiplicity(Multiplicity(0)) expect_multiplicity(assert_multiplicity(nmp)) expect_multiplicity(assert_multiplicity(nmp, check_nesting = TRUE)) expect_error(assert_multiplicity(as.Multiplicity(list(0, Multiplicity(0))), .var.name = "y", check_nesting = TRUE), regexp = "Inconsistent multiplicity nesting level") expect_output(print(mp), regexp = "Multiplicity:") expect_output(print(Multiplicity()), regexp = "Empty Multiplicity.") }) test_that("multiplicity_nests_deeper_than", { expect_true(multiplicity_nests_deeper_than(Multiplicity(Multiplicity(1)), 1)) expect_false(multiplicity_nests_deeper_than(Multiplicity(Multiplicity(1)), 2)) # expect_true(is.na(multiplicity_nests_deeper_than(Multiplicity(Multiplicity()), 2))) # TODO: adapt this once #596 is fixed. expect_false(multiplicity_nests_deeper_than(Multiplicity(Multiplicity(), Multiplicity(1)), 2)) # expect_true(multiplicity_nests_deeper_than(Multiplicity(Multiplicity(), Multiplicity(Multiplicity())), 2)) # TODO: adapt this once #596 is fixed. }) test_that("PipeOp - assert_connection_table", { # toy PipeOp that should fail to be constructed PipeOpTestMultiplicitesTable = R6Class("PipeOpTestMultiplicitesTable", inherit = PipeOp, public = list( initialize = function(id = "multiplicitiestable", param_vals = list()) { super$initialize(id, param_vals = param_vals, input = data.table(name = "input", train = "[*]", predict = "*"), output = data.table(name = "output", train = "*", predict = "*"), tags = "multiplicity" ) } ) ) expect_error(PipeOpTestMultiplicitesTable$new(), regexp = "Multiplicity during train and predict conflicts") }) # FIXME: check_types in PipeOp test_that("PipeOp - multiplicity_type_nesting_level", { expect_equal(multiplicity_type_nesting_level(c("Task", "[Prediction]", "[[*]]")), c(0L, 1L, 2L)) expect_error(multiplicity_type_nesting_level("[wrong", varname = "test"), regexp = "square bracket mismatch") }) test_that("PipeOp - unpack_multiplicities", { expect_equal(unpack_multiplicities(list(a = Multiplicity(1, 2), b = 4), c(0, 0), c("a", "b"), "test"), list(list(a = 1, b = 4), list(a = 2, b = 4))) expect_equal(unpack_multiplicities(list(a = Multiplicity(x = 1, y = 2), b = 4), c(0, 0), c("a", "b"), "test"), list(x = list(a = 1, b = 4), y = list(a = 2, b = 4))) expect_equal(unpack_multiplicities(list(a = Multiplicity(x = 1, y = 2), b = Multiplicity(x = 10, y = 20)), c(0, 0), c("a", "b"), "test"), list(x = list(a = 1, b = 10), y = list(a = 2, b = 20))) expect_error(unpack_multiplicities(list(a = Multiplicity(x = 1, z = 2), b = Multiplicity(x = 10, y = 20)), c(0, 0), c("a", "b"), "test"), regexp = "bad multiplicities") expect_equal(unpack_multiplicities(list(a = Multiplicity(x = 1, z = 2), b = Multiplicity(x = 10, y = 20)), c(0, 1), c("a", "b"), "test"), list(x = list(a = 1, b = Multiplicity(x = 10, y = 20)), z = list(a = 2, b = Multiplicity(x = 10, y = 20)))) expect_equal(unpack_multiplicities(list(0), 0, "a", "test"), NULL) }) test_that("PipeOp - evaluate_multiplicities", { # toy PipeOp only for testing PipeOpTestMultiplicites = R6Class("PipeOpTestMultiplicites", inherit = PipeOp, public = list( initialize = function(num, id = "multiplicities", param_vals = list()) { assert_int(num, lower = 1L) ps = ParamSet$new(params = list( ParamUty$new("state", tags = "train") )) super$initialize(id, param_set = ps, param_vals = param_vals, input = data.table(name = rep_suffix("input", num), train = "*", predict = "*"), output = data.table(name = rep_suffix("output", num), train = "*", predict = "*"), tags = "multiplicity" ) } ), private = list( # allows to stop with an error on purpose .train = function(inputs) { if (self$param_set$values$state == "error") stop("Error.") self$state = self$param_set$values$state inputs }, .predict = function(inputs) { if (self$param_set$values$state == "error") stop("Error.") inputs } ) ) task = mlr_tasks$get("iris") po = PipeOpTestMultiplicites$new(2) expect_null(po$state) po$param_set$values$state = "trained" train_out1 = po$train(as.Multiplicity(list(0, as.Multiplicity(0)))) expect_multiplicity(train_out1[[1]]) expect_equal(po$state, as.Multiplicity(list("trained"))) predict_out1 = po$predict(as.Multiplicity(list(0, as.Multiplicity(0)))) expect_equal(po$state, as.Multiplicity(list("trained"))) expect_multiplicity(predict_out1[[1]]) po$state = list("no_multiplicties") expect_error(po$predict(as.Multiplicity(list(0, as.Multiplicity(0)))), regexp = "state was not a multiplicity") expect_equal(po$state, list("no_multiplicties")) po$state = as.Multiplicity(NULL) expect_error(po$predict(as.Multiplicity(list(0, as.Multiplicity(0)))), regexp = "state had different length / names than input") expect_equal(po$state, as.Multiplicity(NULL)) po$param_set$values$state = "trained" train_out2 = po$train(as.Multiplicity(list(0, as.Multiplicity(0)))) expect_multiplicity(train_out2[[1]]) old_state = po$state po$param_set$values$state = "error" expect_error(po$train(as.Multiplicity(list(0, as.Multiplicity(0)))), regexp = "Error") expect_equal(po$state, NULL) # state is completely reset to NULL }) test_that("Graph - add_edge", { learner = lrn("classif.rpart") g1 = PipeOpOVRSplit$new() %>>% learner %>>% PipeOpOVRUnite$new() g2 = Graph$new() g2$add_pipeop(PipeOpOVRSplit$new()) g2$add_pipeop(learner) g2$add_pipeop(PipeOpOVRUnite$new()) g2$add_edge("ovrsplit", "classif.rpart") g2$add_edge("classif.rpart", "ovrunite") expect_identical(g1$edges, g2$edges) }) test_that("Multiplicity checking", { p = po("pca") expect_error(p$train(list(x = 1)), "Assertion on 'input 1 \\(\"input\"\\) of PipeOp pca's \\$train\\(\\)") expect_task(p$train(list(x = tsk("iris")))[[1]]) expect_task(p$predict(list(x = tsk("iris")))[[1]]) expect_error(p$predict(list(x = 1)), "Assertion on 'input 1 \\(\"input\"\\) of PipeOp pca's \\$predict\\(\\)") p$output$predict = "numeric" expect_task(p$train(list(x = tsk("iris")))[[1]]) expect_error(p$predict(list(x = tsk("iris")))[[1]], "Assertion on 'output 1 \\(\"output\"\\) of PipeOp pca's \\$predict\\(\\)") p$output$train = "numeric" expect_error(p$train(list(x = tsk("iris")))[[1]], "Assertion on 'output 1 \\(\"output\"\\) of PipeOp pca's \\$train\\(\\)") })