# Tests for tidypredict_fit() / parse_model() on glmnet models ----

test_that("returns the right output", {
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)
  tf <- tidypredict_fit(model)
  pm <- parse_model(model)

  expect_type(tf, "language")
  expect_s3_class(pm, "list")
  expect_equal(length(pm), 2)
  expect_equal(pm$general$model, "glmnet")
  expect_equal(pm$general$version, 1)

  expect_snapshot(
    rlang::expr_text(tf)
  )
})

test_that("Model can be saved and re-loaded", {
  skip_if_not_installed("yaml")

  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)
  pm <- parse_model(model)

  # Round-trip the parsed model through a YAML file, then remove the file so
  # the test does not leave artefacts in the session temp directory.
  mp <- tempfile(fileext = ".yml")
  yaml::write_yaml(pm, mp)
  l <- yaml::read_yaml(mp)
  unlink(mp)
  pm <- as_parsed_model(l)

  expect_identical(
    round_print(tidypredict_fit(model)),
    round_print(tidypredict_fit(pm))
  )
})

test_that("formulas produces correct predictions", {
  # gaussian: trained and tested on mtcars[, -1] (response `mpg` dropped)
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -1], mtcars$mpg, family = "gaussian", lambda = 1),
      mtcars[, -1]
    )
  )

  # binomial: trained on mtcars[, -8] (response `vs` dropped), so the test
  # data must have the same columns -- `mpg` is a predictor here.
  # NOTE(review): at lambda = 1 this fit may be intercept-only, which would
  # make the comparison weak -- confirm.
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "binomial", lambda = 1),
      mtcars[, -8]
    )
  )

  # poisson: same predictor columns as the binomial case above
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "poisson", lambda = 1),
      mtcars[, -8]
    )
  )
})

test_that("family function syntax works (#197)", {
  x <- as.matrix(mtcars[, -1])

  # gaussian()
  model <- glmnet::glmnet(x, mtcars$mpg, family = gaussian(), lambda = 0.5)
  expect_no_error(tidypredict_fit(model))

  # binomial()
  model <- glmnet::glmnet(x, mtcars$am, family = binomial(), lambda = 0.5)
  expect_no_error(tidypredict_fit(model))

  # poisson()
  model <- glmnet::glmnet(x, mtcars$carb, family = poisson(), lambda = 0.5)
  expect_no_error(tidypredict_fit(model))
})

test_that("family string syntax works (#197)", {
  x <- as.matrix(mtcars[, -1])

  # "gaussian"
  model <- glmnet::glmnet(x, mtcars$mpg, family = "gaussian", lambda = 0.5)
  expect_no_error(tidypredict_fit(model))

  # "binomial"
  model <- glmnet::glmnet(x, mtcars$am, family = "binomial", lambda = 0.5)
  expect_no_error(tidypredict_fit(model))

  # "poisson"
  model <- glmnet::glmnet(x, mtcars$carb, family = "poisson", lambda = 0.5)
  expect_no_error(tidypredict_fit(model))
})

test_that("errors if more than 1 penalty is selected", {
  # Default lambda path: many penalties
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg)
  expect_snapshot(
    error = TRUE,
    tidypredict_fit(model)
  )

  # Explicit path with two penalties
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = c(1, 5))
  expect_snapshot(
    error = TRUE,
    tidypredict_fit(model)
  )
})

test_that("glmnet are handeld neatly with parsnip", {
  skip_if_not_installed("parsnip")

  spec <- parsnip::linear_reg(engine = "glmnet", penalty = 1)
  model <- parsnip::fit(spec, mpg ~ ., mtcars)
  tf <- tidypredict_fit(model)
  pm <- parse_model(model)

  expect_type(tf, "language")
  expect_s3_class(pm, "list")
  expect_equal(length(pm), 2)
  expect_equal(pm$general$model, "glmnet")
  expect_equal(pm$general$version, 1)

  expect_snapshot(
    rlang::expr_text(tf)
  )
})

test_that("Gamma family works (#200)", {
  x <- as.matrix(mtcars[, -1])
  model <- glmnet::glmnet(x, mtcars$mpg, family = Gamma(), lambda = 0.5)
  fit <- tidypredict_fit(model)

  # Compare the formula against glmnet's own response-scale predictions
  native <- unname(predict(model, x, type = "response")[, 1])
  tidy <- rlang::eval_tidy(fit, mtcars)
  expect_equal(tidy, native)
})

test_that("Cox family works (#201)", {
  skip_if_not_installed("survival")

  # mpg is the survival time and vs the event indicator, so both are dropped
  # from the predictors
  x <- as.matrix(mtcars[, -c(1, 8)])
  y <- survival::Surv(mtcars$mpg, mtcars$vs)
  model <- glmnet::glmnet(x, y, family = "cox", lambda = 0.1)
  fit <- tidypredict_fit(model)

  # Compare on the link (linear predictor) scale
  native <- unname(predict(model, x, type = "link")[, 1])
  tidy <- rlang::eval_tidy(fit, mtcars)
  expect_equal(tidy, native)
})

test_that("multinomial family errors with helpful message (#198)", {
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial",
    lambda = 0.5
  )
  expect_snapshot(error = TRUE, tidypredict_fit(model))
})

test_that("mgaussian family errors with helpful message (#199)", {
  # mpg and hp form the two-column response, so they are dropped from x
  x <- as.matrix(mtcars[, -c(1, 4)])
  y <- cbind(mtcars$mpg, mtcars$hp)
  model <- glmnet::glmnet(x, y, family = "mgaussian", lambda = 0.5)
  expect_snapshot(error = TRUE, tidypredict_fit(model))
})

# Tests for .extract_glmnet_multiclass()
test_that(".extract_glmnet_multiclass returns correct structure", {
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial",
    lambda = 0.5
  )
  result <- .extract_glmnet_multiclass(model)

  # One character equation per response class, named by the class levels
  expect_type(result, "list")
  expect_length(result, 3)
  expect_named(result, levels(iris$Species))
  expect_type(result[[1]], "character")
})

test_that(".extract_glmnet_multiclass errors on non-multnet model", {
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)
  expect_snapshot(error = TRUE, .extract_glmnet_multiclass(model))
})

test_that(".extract_glmnet_multiclass errors with multiple penalties", {
  # No `lambda` supplied, so glmnet fits a whole penalty path
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial"
  )
  expect_snapshot(error = TRUE, .extract_glmnet_multiclass(model))
})

test_that(".extract_glmnet_multiclass works with explicit penalty", {
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial"
  )
  # Selecting a single penalty resolves the multiple-penalty path
  result <- .extract_glmnet_multiclass(model, penalty = 0.01)

  expect_type(result, "list")
  expect_length(result, 3)
})

test_that(".extract_glmnet_multiclass handles sparse coefficients", {
  # High penalty should zero out many coefficients
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial",
    lambda = 10
  )
  result <- .extract_glmnet_multiclass(model)

  expect_type(result, "list")
  expect_length(result, 3)
})

test_that(".extract_glmnet_multiclass produces correct predictions", {
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial",
    lambda = 0.01
  )
  eqs <- .extract_glmnet_multiclass(model)
  n_rows <- nrow(iris)

  # Evaluate each class's linear predictor into one column of an
  # n_rows x n_classes matrix. vapply() guarantees that shape (sapply() would
  # silently return a list if a predictor had an unexpected length).
  # Intercept-only equations evaluate to a scalar and are recycled.
  logits <- vapply(
    eqs,
    function(eq) {
      val <- rlang::eval_tidy(rlang::parse_expr(eq), iris)
      if (length(val) == 1) rep(val, n_rows) else val
    },
    numeric(n_rows)
  )

  # Apply softmax across classes (row-wise)
  exp_logits <- exp(logits)
  probs <- exp_logits / rowSums(exp_logits)

  # Compare to native predictions for the single fitted penalty
  native <- predict(model, as.matrix(iris[, 1:4]), type = "response")[, , 1]
  expect_equal(unname(probs), unname(native), tolerance = 1e-10)
})

# Tests for .build_linear_pred()

test_that(".build_linear_pred handles intercept only", {
  result <- .build_linear_pred("(Intercept)", 5.5)
  expect_equal(result, "5.5")
})

test_that(".build_linear_pred handles single predictor", {
  result <- .build_linear_pred(c("(Intercept)", "x"), c(1.5, 2.0))
  expect_equal(result, "1.5 + (`x` * 2)")
})

test_that(".build_linear_pred handles multiple predictors", {
  result <- .build_linear_pred(
    c("(Intercept)", "x", "y"),
    c(1.0, 2.0, 3.0)
  )
  expect_equal(result, "1 + (`x` * 2) + (`y` * 3)")
})

test_that(".build_linear_pred skips zero coefficients", {
  result <- .build_linear_pred(
    c("(Intercept)", "x", "y", "z"),
    c(1.0, 0.0, 2.0, 0.0)
  )
  expect_identical(result, "1 + (`y` * 2)")
})

test_that(".build_linear_pred returns '0' when all coefficients are zero", {
  result <- .build_linear_pred(
    c("(Intercept)", "x", "y"),
    c(0, 0, 0)
  )
  expect_equal(result, "0")
})

test_that(".build_linear_pred handles negative coefficients", {
  result <- .build_linear_pred(
    c("(Intercept)", "x"),
    c(-1.5, -2.0)
  )
  expect_equal(result, "-1.5 + (`x` * -2)")
})

test_that(".build_linear_pred handles special characters in variable names", {
  # Non-syntactic names must be backtick-quoted in the generated expression
  result <- .build_linear_pred(
    c("(Intercept)", "var with space", "var.with.dots"),
    c(1.0, 2.0, 3.0)
  )
  expect_identical(
    result,
    "1 + (`var with space` * 2) + (`var.with.dots` * 3)"
  )
})

test_that(".build_linear_pred handles no intercept", {
  result <- .build_linear_pred(c("x", "y"), c(2.0, 3.0))
  expect_equal(result, "(`x` * 2) + (`y` * 3)")
})

test_that(".build_linear_pred handles zero intercept", {
  result <- .build_linear_pred(
    c("(Intercept)", "x"),
    c(0, 2.0)
  )
  expect_equal(result, "(`x` * 2)")
})