test_that('regression models', { skip_if(!run_tests()) skip_on_cran() pred_ptype <- tibble::tibble(.pred = numeric(0)) #----------------------------------------------------------------------------- set.seed(166) mod_df <- try(tab_pfn(predictors, outcome), silent = TRUE) expect_s3_class(mod_df, exp_cls) expect_snapshot(mod_df) pred_df <- predict(mod_df, mtcars[1:3, -1]) expect_equal(pred_df[0, ], pred_ptype) expect_equal(nrow(pred_df), 3L) aug_df <- augment(mod_df, mtcars[1:3, -1]) expect_s3_class(aug_df, c("tbl_df", "tbl", "data.frame")) expect_equal(nrow(aug_df), 3L) expect_equal(ncol(aug_df), 11L) #----------------------------------------------------------------------------- set.seed(166) mod_f <- try(tab_pfn(mpg ~ ., data = mtcars), silent = TRUE) expect_s3_class(mod_f, exp_cls) expect_snapshot(mod_f) pred_f <- predict(mod_f, mtcars[1:3, -1]) expect_equal(pred_f[0, ], pred_ptype) expect_equal(nrow(pred_f), 3L) aug_f <- augment(mod_f, mtcars[1:3, -1]) expect_s3_class(aug_f, c("tbl_df", "tbl", "data.frame")) expect_equal(nrow(aug_f), 3L) expect_equal(ncol(aug_f), 11L) #----------------------------------------------------------------------------- set.seed(166) mod_mat <- try(tab_pfn(as.matrix(predictors), outcome), silent = TRUE) expect_s3_class(mod_mat, exp_cls) expect_snapshot(mod_mat) pred_mat <- predict(mod_mat, mtcars[1:3, -1]) expect_equal(pred_mat[0, ], pred_ptype) expect_equal(nrow(pred_mat), 3L) aug_mat <- augment(mod_mat, mtcars[1:3, -1]) expect_s3_class(aug_mat, c("tbl_df", "tbl", "data.frame")) expect_equal(nrow(aug_mat), 3L) expect_equal(ncol(aug_mat), 11L) #----------------------------------------------------------------------------- expect_snapshot_error( tab_pfn(1, 2) ) }) test_that('reproducible results', { skip_if(!run_tests()) skip_on_cran() set.seed(166) mod_1 <- try(tab_pfn(predictors, outcome), silent = TRUE) pred_1 <- predict(mod_1, mtcars[1:3, -1]) set.seed(166) mod_2 <- try(tab_pfn(predictors, outcome), silent = TRUE) pred_2 <- predict(mod_2, mtcars[1:3, -1]) expect_equal(pred_1, pred_2) set.seed(774) mod_3 <- try(tab_pfn(predictors, outcome), silent = TRUE) pred_3 <- predict(mod_3, mtcars[1:3, -1]) expect_false(all(pred_1$.pred == pred_3$.pred)) set.seed(166) mod_4 <- try(tab_pfn(predictors, outcome, num_estimators = 1), silent = TRUE) pred_4 <- predict(mod_4, mtcars[1:3, -1]) expect_false(all(pred_1$.pred == pred_4$.pred)) }) test_that('regression models - recipes', { skip_if(!run_tests()) skip_on_cran() skip_if_not_installed("modeldata") skip_if_not_installed("recipes") reticulate::import("torch") library(tabpfn) library(recipes) data(Chicago, package = "modeldata") pred_ptype <- tibble::tibble(.pred = numeric(0)) #----------------------------------------------------------------------------- rec <- recipe(ridership ~ Austin + Quincy_Wells + date, data = Chicago) |> step_date(date) |> step_rm(date) set.seed(166) mod_rec <- try(tab_pfn(rec, Chicago[1:20, ]), silent = TRUE) expect_s3_class(mod_rec, exp_cls) expect_snapshot(mod_rec) pred_rec <- predict(mod_rec, Chicago[50:52, ]) expect_equal(pred_rec[0, ], pred_ptype) expect_equal(nrow(pred_rec), 3L) aug_rec <- augment(mod_rec, Chicago[50:52, ]) expect_s3_class(aug_rec, c("tbl_df", "tbl", "data.frame")) expect_equal(nrow(aug_rec), 3L) expect_equal(ncol(aug_rec), 51L) })