test_that("formula method", { set.seed(23598723) split <- rsample::initial_split(mtcars) f <- mpg ~ cyl + poly(disp, 2) + hp + drat + wt + qsec + vs + am + gear + carb lm_fit <- lm(f, data = rsample::training(split)) test_pred <- predict(lm_fit, rsample::testing(split)) rmse_test <- yardstick::rsq_vec( rsample::testing(split) |> pull(mpg), test_pred ) res <- parsnip::linear_reg() |> parsnip::set_engine("lm") |> last_fit(f, split) expect_equal(res, .Last.tune.result) expect_equal( coef(extract_fit_engine(res$.workflow[[1]])), coef(lm_fit), ignore_attr = TRUE ) expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test) expect_equal(res$.predictions[[1]]$.pred, unname(test_pred)) expect_true(res$.workflow[[1]]$trained) expect_equal( nrow(predict(res$.workflow[[1]], rsample::testing(split))), nrow(rsample::testing(split)) ) expect_null(.get_tune_eval_times(res)) expect_null(.get_tune_eval_time_target(res)) }) test_that("recipe method", { set.seed(23598723) split <- rsample::initial_split(mtcars) f <- mpg ~ cyl + poly(disp, 2) + hp + drat + wt + qsec + vs + am + gear + carb lm_fit <- lm(f, data = rsample::training(split)) test_pred <- predict(lm_fit, rsample::testing(split)) rmse_test <- yardstick::rsq_vec( rsample::testing(split) |> pull(mpg), test_pred ) rec <- recipes::recipe(mpg ~ ., data = mtcars) |> recipes::step_poly(disp) res <- parsnip::linear_reg() |> parsnip::set_engine("lm") |> last_fit(rec, split) expect_equal( sort(coef(extract_fit_engine(res$.workflow[[1]]))), sort(coef(lm_fit)), ignore_attr = TRUE ) expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test) expect_equal(res$.predictions[[1]]$.pred, unname(test_pred)) expect_true(res$.workflow[[1]]$trained) expect_equal( nrow(predict(res$.workflow[[1]], rsample::testing(split))), nrow(rsample::testing(split)) ) }) test_that("model_fit method", { library(parsnip) lm_fit <- linear_reg() |> fit(mpg ~ ., data = mtcars) expect_snapshot(last_fit(lm_fit), error = TRUE) }) test_that("workflow method", { library(parsnip) lm_fit <- workflows::workflow(mpg ~ ., linear_reg()) |> fit(data = mtcars) expect_snapshot(last_fit(lm_fit), error = TRUE) }) test_that("collect metrics of last fit", { set.seed(23598723) split <- rsample::initial_split(mtcars) f <- mpg ~ cyl + poly(disp, 2) + hp + drat + wt + qsec + vs + am + gear + carb res <- parsnip::linear_reg() |> parsnip::set_engine("lm") |> last_fit(f, split) met <- collect_metrics(res) expect_true(inherits(met, "tbl_df")) expect_equal(names(met), c(".metric", ".estimator", ".estimate", ".config")) }) test_that("ellipses with last_fit", { options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf) set.seed(23598723) split <- rsample::initial_split(mtcars) f <- mpg ~ cyl + poly(disp, 2) + hp + drat + wt + qsec + vs + am + gear + carb expect_snapshot( linear_reg() |> set_engine("lm") |> last_fit(f, split, something = "wrong") ) }) test_that("argument order gives errors for recipe/formula", { options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf) set.seed(23598723) split <- rsample::initial_split(mtcars) f <- mpg ~ cyl + poly(disp, 2) + hp + drat + wt + qsec + vs + am + gear + carb rec <- recipes::recipe(mpg ~ ., data = mtcars) |> recipes::step_poly(disp) lin_mod <- parsnip::linear_reg() |> parsnip::set_engine("lm") expect_snapshot(error = TRUE, { last_fit(rec, lin_mod, split) }) expect_snapshot(error = TRUE, { last_fit(f, lin_mod, split) }) }) test_that("same results of last_fit() and fit() (#300)", { skip_if_not_installed("randomForest") rf <- parsnip::rand_forest(mtry = 2, trees = 5) |> parsnip::set_engine("randomForest") |> parsnip::set_mode("regression") wflow <- workflows::workflow() |> workflows::add_model(rf) |> workflows::add_formula(mpg ~ .) set.seed(23598723) split <- rsample::initial_split(mtcars) set.seed(1) lf_obj <- last_fit(wflow, split = split) set.seed(1) r_obj <- fit(wflow, data = rsample::analysis(split)) r_pred <- predict(r_obj, rsample::assessment(split)) expect_equal( lf_obj$.predictions[[1]]$.pred, r_pred$.pred ) }) test_that("`last_fit()` when objects need tuning", { skip_if_not_installed("splines2") options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf) rec <- recipe(mpg ~ ., data = mtcars) |> step_spline_natural(disp, deg_free = tune()) spec_1 <- linear_reg(penalty = tune()) |> set_engine("glmnet") spec_2 <- linear_reg() wflow_1 <- workflow(rec, spec_1) wflow_2 <- workflow(mpg ~ ., spec_1) wflow_3 <- workflow(rec, spec_2) split <- rsample::initial_split(mtcars) expect_snapshot_error(last_fit(wflow_1, split)) expect_snapshot_error(last_fit(wflow_2, split)) expect_snapshot_error(last_fit(wflow_3, split)) }) test_that("last_fit() excludes validation set for initial_validation_split objects", { skip_if_not_installed("modeldata") data(ames, package = "modeldata", envir = rlang::current_env()) set.seed(23598723) split <- rsample::initial_validation_split(ames) f <- Sale_Price ~ Gr_Liv_Area + Year_Built lm_fit <- lm(f, data = rsample::training(split)) test_pred <- predict(lm_fit, rsample::testing(split)) rmse_test <- yardstick::rsq_vec( rsample::testing(split) |> pull(Sale_Price), test_pred ) res <- parsnip::linear_reg() |> parsnip::set_engine("lm") |> last_fit(f, split) expect_equal(res, .Last.tune.result) expect_equal( coef(extract_fit_engine(res$.workflow[[1]])), coef(lm_fit), ignore_attr = TRUE ) expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test) expect_equal(res$.predictions[[1]]$.pred, unname(test_pred)) expect_true(res$.workflow[[1]]$trained) expect_equal( nrow(predict(res$.workflow[[1]], rsample::testing(split))), nrow(rsample::testing(split)) ) }) test_that("last_fit() can include validation set for initial_validation_split objects", { skip_if_not_installed("modeldata") data(ames, package = "modeldata", envir = rlang::current_env()) set.seed(23598723) split <- rsample::initial_validation_split(ames) f <- Sale_Price ~ Gr_Liv_Area + Year_Built train_val <- rbind(rsample::training(split), rsample::validation(split)) lm_fit <- lm(f, data = train_val) test_pred <- predict(lm_fit, rsample::testing(split)) rmse_test <- yardstick::rsq_vec( rsample::testing(split) |> pull(Sale_Price), test_pred ) res <- parsnip::linear_reg() |> parsnip::set_engine("lm") |> last_fit(f, split, add_validation_set = TRUE) expect_equal(res, .Last.tune.result) expect_equal( coef(extract_fit_engine(res$.workflow[[1]])), coef(lm_fit), ignore_attr = TRUE ) expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test) expect_equal(res$.predictions[[1]]$.pred, unname(test_pred)) expect_true(res$.workflow[[1]]$trained) expect_equal( nrow(predict(res$.workflow[[1]], rsample::testing(split))), nrow(rsample::testing(split)) ) }) test_that("can use `last_fit()` with a workflow - postprocessor (requires training)", { skip_if_not_installed("mgcv") skip_if_not_installed("tailor", minimum_version = "0.0.0.9002") skip_if_not_installed("probably") y <- seq(0, 7, .001) dat <- data.frame(y = y, x = y + (y - 3)^2) dat set.seed(1) split <- rsample::initial_split(dat) set.seed(1) internal_calibration_split <- rsample::internal_calibration_split( split, split_args = list() ) # ---------------------------------------------------------------------------- wflow <- workflows::workflow( y ~ x, parsnip::linear_reg() ) |> workflows::add_tailor( tailor::tailor() |> tailor::adjust_numeric_calibration("linear") ) # ---------------------------------------------------------------------------- set.seed(1) last_fit_res <- last_fit( wflow, split ) last_fit_preds <- collect_predictions(last_fit_res) last_fit_wflow <- extract_workflow(last_fit_res) expect_true(is_trained_workflow(last_fit_wflow)) last_fit_cal <- last_fit_wflow |> extract_postprocessor() |> purrr::pluck("adjustments") |> purrr::pluck(1) |> purrr::pluck("results") |> purrr::pluck("fit") |> purrr::pluck("estimates") |> purrr::pluck(1) |> purrr::pluck("estimate") # ---------------------------------------------------------------------------- # TODO inner split is getting different random numbers; check against no cal instead set.seed(1) wflow_res <- generics::fit( wflow, rsample::analysis(internal_calibration_split), data_calibration = rsample::calibration(internal_calibration_split) ) wflow_cal <- wflow_res |> extract_postprocessor() |> purrr::pluck("adjustments") |> purrr::pluck(1) |> purrr::pluck("results") |> purrr::pluck("fit") |> purrr::pluck("estimates") |> purrr::pluck(1) |> purrr::pluck("estimate") wflow_preds <- predict(wflow_res, rsample::assessment(split)) # ---------------------------------------------------------------------------- expect_equal(last_fit_cal, wflow_cal) expect_equal(last_fit_preds[".pred"], wflow_preds) }) test_that("can use `last_fit()` with a workflow - postprocessor (does not require training)", { skip_if_not_installed("tailor", minimum_version = "0.0.0.9002") skip_if_not_installed("probably") y <- seq(0, 7, .001) dat <- data.frame(y = y, x = y + (y - 3)^2) dat set.seed(1) split <- rsample::initial_split(dat) wflow <- workflows::workflow( y ~ x, parsnip::linear_reg() ) |> workflows::add_tailor( tailor::tailor() |> tailor::adjust_numeric_range(lower_limit = 1) ) set.seed(1) last_fit_res <- last_fit( wflow, split ) last_fit_preds <- collect_predictions(last_fit_res) set.seed(1) wflow_res <- generics::fit(wflow, rsample::analysis(split)) wflow_preds <- predict(wflow_res, rsample::assessment(split)) expect_equal(last_fit_preds[".pred"], wflow_preds) })