# Shared test driver: fits shrinkGPR() on a small simulated data set with the
# supplied argument list and checks the fitted object plus the prediction,
# LPDS, predictive-density, predictive-moment and posterior-sample helpers.
#
# Arguments:
#   args        List of arguments forwarded to shrinkGPR() through do.call().
#               The `data` and `formula` entries are overwritten here.
#   eval_points Numeric vector of points passed to eval_pred_dens().
#   log_pred    Value forwarded as the `log` argument of eval_pred_dens().
#
# Must be called from inside a test_that() block; it is skipped when the
# torch backend is not installed.
test_shrinkGPR <- function(args, eval_points = c(-2, 0, 2), log_pred = FALSE) {
  if (!torch::torch_is_installed()) {
    skip("Torch is not installed. Skipping test.")
  }

  # Seed both R's RNG and torch's RNG so every scenario sees the same data.
  set.seed(123)
  torch::torch_manual_seed(123)

  # Create mock data
  full_dat <- data.frame(
    y = sin(2 * pi * runif(20)) + rnorm(20, sd = 0.1),
    x1 = runif(20),
    x2 = rnorm(20)
  )
  train <- full_dat[1:15, ]
  test <- full_dat[16:20, ]

  args$data <- train
  args$formula <- y ~ x1 + x2

  # Fit model
  res <- do.call(shrinkGPR, args)

  # Test model object
  expect_s3_class(res, "shrinkGPR")

  # Test prediction methods
  preds <- predict(res, newdata = test)
  expect_type(preds, "double")
  # 100 rows: the default nsamp for prediction functions
  expect_equal(dim(preds), c(100, nrow(test)))

  # Test LPDS
  lpds <- LPDS(res, data_test = test[1, ])
  expect_type(lpds, "double")
  expect_length(lpds, 1)

  # Test predictive density evaluation
  pred_dens <- eval_pred_dens(
    eval_points, res,
    data_test = test[1, ], log = log_pred
  )
  expect_type(pred_dens, "double")
  expect_length(pred_dens, length(eval_points))

  # Test predictive moments
  moments <- calc_pred_moments(res, newdata = test, nsamp = 100)
  expect_type(moments, "list")
  expect_named(moments, c("means", "vars"))
  expect_equal(dim(moments$means), c(100, nrow(test)))
  expect_equal(dim(moments$vars), c(100, nrow(test), nrow(test)))

  # Test posterior samples
  posterior <- gen_posterior_samples(res, nsamp = 100)
  expect_type(posterior, "list")
  expect_named(posterior, c("thetas", "sigma2", "lambda"))
  expect_equal(nrow(posterior$thetas), 100)

  if (res$model_internals$x_mean) {
    expect_true(all(c("betas", "lambda_mean") %in% names(posterior)))
  }
}

# Define scenarios
# expand.grid() names the columns after its arguments; `kernel` and `flow`
# are list columns because c() applied to functions yields a list.
scenarios <- expand.grid(
  auto_stop = c(TRUE, FALSE),
  kernel = c(kernel_se, kernel_matern_32),
  flow = c(sylvester)
)

# Parameter toggles
params <- c("display_progress", "auto_stop")

for (i in seq_len(nrow(scenarios))) {
  for (j in params) {
    # Start from shrinkGPR()'s own defaults, dropping arguments whose default
    # is NULL.
    args <- formals(shrinkGPR)
    args <- args[vapply(args, function(x) !is.null(x), logical(1))]

    # Flip the logical default of the parameter under test.
    args[[j]] <- !args[[j]]

    # NOTE(review): the scenario value is assigned after the toggle, so for
    # j == "auto_stop" the flipped value is overwritten here - confirm that
    # this is intended.
    args$auto_stop <- scenarios$auto_stop[i]
    args$kernel_func <- scenarios$kernel[[i]]
    args$flow_func <- scenarios$flow[[i]]
    args$n_epochs <- 10 # Limit epochs for testing

    test_that(paste0(
      "scenario: ", i,
      ", auto_stop: ", scenarios$auto_stop[i],
      ", toggled: ", j
    ), {
      test_shrinkGPR(args)
    })
  }
}