# Tests for tidypredict's support of earth (MARS) models: parsing,
# YAML round-tripping, prediction parity with predict(), GLM link/family
# handling, and multiclass equation extraction.

test_that("returns the right output", {
  model <- earth::earth(mpg ~ ., data = mtcars)
  # NOTE(review): presumably rounded so the snapshot below is stable across
  # platforms — confirm.
  model$coefficients <- round(model$coefficients, 12)
  tf <- tidypredict_fit(model)
  pm <- parse_model(model)

  expect_type(tf, "language")
  expect_s3_class(pm, "list")
  expect_length(pm, 2)
  expect_equal(pm$general$model, "earth")
  expect_equal(pm$general$version, 2)

  expect_snapshot(
    rlang::expr_text(tf)
  )
})

test_that("Model can be saved and re-loaded", {
  model <- earth::earth(mpg ~ ., data = mtcars)
  model$coefficients <- round(model$coefficients, 7)
  pm <- parse_model(model)

  # Round-trip the parsed model through a YAML file on disk.
  mp <- tempfile(fileext = ".yml")
  yaml::write_yaml(pm, mp)
  l <- yaml::read_yaml(mp)
  # Remove the temporary file once it has been read back.
  unlink(mp)
  pm <- as_parsed_model(l)

  expect_identical(
    tidypredict_fit(model),
    tidypredict_fit(pm)
  )
})

test_that("formulas produces correct predictions", {
  # Regression - numeric predictors
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ sibsp + parch, data = earth::etitanic),
      earth::etitanic
    )
  )

  # Regression - numeric predictors, degree = 2
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ sibsp + parch, data = earth::etitanic, degree = 2),
      earth::etitanic
    )
  )

  # Regression - numeric predictors, degree = 3
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ sibsp + parch, data = earth::etitanic, degree = 3),
      earth::etitanic
    )
  )

  # Regression - numeric predictors and categorical predictors
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic),
      earth::etitanic
    )
  )

  # Regression - pmethod = "backward"
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic, pmethod = "backward"),
      earth::etitanic
    )
  )

  # Regression - pmethod = "none"
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic, pmethod = "none"),
      earth::etitanic
    )
  )

  # Regression - pmethod = "exhaustive"
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic, pmethod = "exhaustive"),
      earth::etitanic
    )
  )

  # Regression - pmethod = "forward"
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic, pmethod = "forward"),
      earth::etitanic
    )
  )

  # Regression - pmethod = "seqrep"
  expect_snapshot(
    tidypredict_test(
      earth::earth(age ~ ., data = earth::etitanic, pmethod = "seqrep"),
      earth::etitanic
    )
  )

  # binomial
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ age + sibsp,
        data = earth::etitanic,
        glm = list(family = binomial)
      ),
      earth::etitanic
    )
  )

  # binomial - w/ degree
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ age + sibsp,
        data = earth::etitanic,
        glm = list(family = binomial),
        degree = 2
      ),
      earth::etitanic
    )
  )

  # binomial - pmethod = "backward"
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ .,
        data = earth::etitanic,
        glm = list(family = binomial),
        pmethod = "backward"
      ),
      earth::etitanic
    )
  )

  # binomial - pmethod = "none"
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ .,
        data = earth::etitanic,
        glm = list(family = binomial),
        pmethod = "none"
      ),
      earth::etitanic
    )
  )

  # binomial - pmethod = "exhaustive"
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ .,
        data = earth::etitanic,
        glm = list(family = binomial),
        pmethod = "exhaustive"
      ),
      earth::etitanic
    )
  )

  # binomial - pmethod = "forward"
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ .,
        data = earth::etitanic,
        glm = list(family = binomial),
        pmethod = "forward"
      ),
      earth::etitanic
    )
  )

  # binomial - pmethod = "seqrep"
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        survived ~ .,
        data = earth::etitanic,
        glm = list(family = binomial),
        pmethod = "seqrep"
      ),
      earth::etitanic
    )
  )

  # formula interface
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        Sepal.Length ~ .,
        data = iris
      ),
      iris
    )
  )

  # XY interface
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        x = iris[, -1],
        y = iris$Sepal.Length
      ),
      iris
    )
  )

  # formula interface - degree = 2
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        Sepal.Length ~ .,
        data = iris,
        degree = 2,
        pmethod = "none"
      ),
      iris
    )
  )

  # XY interface - degree = 2
  expect_snapshot(
    tidypredict_test(
      earth::earth(
        x = iris[, -1],
        y = iris$Sepal.Length,
        degree = 2,
        pmethod = "none"
      ),
      iris
    )
  )
})

test_that("probit link works (#194)", {
  model <- earth::earth(
    survived ~ age + sibsp,
    data = earth::etitanic,
    glm = list(family = binomial(link = "probit"))
  )
  fit <- tidypredict_fit(model)

  native <- unname(predict(model, earth::etitanic, type = "response")[, 1])
  tidy <- rlang::eval_tidy(fit, earth::etitanic)

  # Uses Bowling et al. approximation to normal CDF, hence the loose tolerance
  expect_equal(tidy, native, tolerance = 0.001)
})

test_that("cloglog link works (#194)", {
  model <- earth::earth(
    survived ~ age + sibsp,
    data = earth::etitanic,
    glm = list(family = binomial(link = "cloglog"))
  )
  fit <- tidypredict_fit(model)

  native <- unname(predict(model, earth::etitanic, type = "response")[, 1])
  tidy <- rlang::eval_tidy(fit, earth::etitanic)

  expect_equal(tidy, native)
})

test_that("Gamma family works (#195)", {
  model <- earth::earth(
    mpg ~ cyl + disp + hp,
    data = mtcars,
    glm = list(family = Gamma)
  )
  fit <- tidypredict_fit(model)

  native <- unname(predict(model, mtcars, type = "response")[, 1])
  tidy <- rlang::eval_tidy(fit, mtcars)

  expect_equal(tidy, native)
})

test_that("inverse.gaussian family works (#195)", {
  model <- earth::earth(
    mpg ~ cyl + disp + hp,
    data = mtcars,
    glm = list(family = inverse.gaussian)
  )
  fit <- tidypredict_fit(model)

  native <- unname(predict(model, mtcars, type = "response")[, 1])
  tidy <- rlang::eval_tidy(fit, mtcars)

  expect_equal(tidy, native)
})

# Tests for .extract_earth_multiclass() ----

test_that(".extract_earth_multiclass errors on non-earth model", {
  model <- lm(mpg ~ ., data = mtcars)
  expect_snapshot(error = TRUE, .extract_earth_multiclass(model))
})

test_that(".extract_earth_multiclass errors on binary model", {
  # Only the fit's warnings are silenced; the assignment happens outside.
  model <- suppressWarnings(
    earth::earth(
      vs ~ disp + hp,
      data = mtcars,
      glm = list(family = binomial)
    )
  )
  expect_snapshot(error = TRUE, .extract_earth_multiclass(model))
})

test_that(".extract_earth_multiclass errors on regression model", {
  model <- earth::earth(mpg ~ ., data = mtcars)
  expect_snapshot(error = TRUE, .extract_earth_multiclass(model))
})

test_that(".extract_earth_multiclass returns correct structure", {
  skip_if_not(
    exists("contr.earth.response", where = asNamespace("earth")),
    "earth multiclass not available"
  )
  # NOTE(review): earth is attached (not just namespaced), presumably because
  # its factor-response expansion looks up contr.earth.response by name on
  # the search path — confirm before replacing with earth::earth().
  library(earth)

  model <- suppressWarnings(
    earth(
      Species ~ .,
      data = iris,
      glm = list(family = binomial)
    )
  )

  result <- .extract_earth_multiclass(model)

  expect_type(result, "list")
  expect_length(result, 3)
  expect_named(result, levels(iris$Species))
  expect_type(result[[1]], "character")
})

test_that(".extract_earth_multiclass produces correct predictions", {
  skip_if_not(
    exists("contr.earth.response", where = asNamespace("earth")),
    "earth multiclass not available"
  )
  # NOTE(review): see the structure test above for why earth is attached.
  library(earth)

  model <- suppressWarnings(
    earth(
      Species ~ .,
      data = iris,
      glm = list(family = binomial)
    )
  )

  eqs <- .extract_earth_multiclass(model)
  n_rows <- nrow(iris)

  # Evaluate each expression - earth GLM outputs are already on probability
  # scale (not logits), so we don't apply softmax. Constant expressions
  # evaluate to a scalar and are recycled to one value per row; vapply()
  # guarantees an n_rows x n_classes numeric matrix.
  probs <- vapply(
    eqs,
    function(eq) {
      val <- rlang::eval_tidy(rlang::parse_expr(eq), iris)
      if (length(val) == 1) rep(val, n_rows) else val
    },
    numeric(n_rows)
  )

  # Compare to native predictions
  native <- predict(model, iris, type = "response")
  expect_equal(unname(probs), unname(native), tolerance = 1e-6)
})

test_that(".extract_earth_multiclass works with degree > 1", {
  skip_if_not(
    exists("contr.earth.response", where = asNamespace("earth")),
    "earth multiclass not available"
  )
  # NOTE(review): see the structure test above for why earth is attached.
  library(earth)

  model <- suppressWarnings(
    earth(
      Species ~ .,
      data = iris,
      glm = list(family = binomial),
      degree = 2
    )
  )

  result <- .extract_earth_multiclass(model)

  expect_type(result, "list")
  expect_length(result, 3)
  expect_named(result, levels(iris$Species))
})