# ---------------------------------------------------------------------------
# extract_model_id
# ---------------------------------------------------------------------------

test_that("extract_model_id accepts a character string", {
  # A bare id string should be handed back as-is.
  expect_equal(extract_model_id("mod-123"), "mod-123")
})

test_that("extract_model_id accepts a list with $model_id", {
  # A list carrying a `model_id` element alongside other fields.
  fitted <- list(model_id = "mod-456", n_rows = 10)
  expect_equal(extract_model_id(fitted), "mod-456")
})

test_that("extract_model_id errors on invalid input", {
  # A number, a list without `model_id`, and NULL must all produce the
  # same error message.
  expect_error(extract_model_id(42), "must be a character string")
  expect_error(extract_model_id(list(a = 1)), "must be a character string")
  expect_error(extract_model_id(NULL), "must be a character string")
})

# ---------------------------------------------------------------------------
# midas_fit
# ---------------------------------------------------------------------------

test_that("midas_fit sends correct request and parses response", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # Payload the mocked HTTP layer hands back for the fit request.
  stub <- mock_json_response(list(
    model_id = "mod-123",
    n_rows = 10,
    n_cols = 3,
    col_types = list("num", "num", "bin")
  ))

  # Two random numeric columns plus a 0/1 column; X1 gets two missing cells.
  incomplete <- data.frame(
    X1 = rnorm(10),
    X2 = rnorm(10),
    X3 = rep(c(0, 1), each = 5)
  )
  incomplete$X1[c(3, 7)] <- NA

  fit <- httr2::with_mocked_responses(
    stub,
    midas_fit(incomplete, epochs = 5L)
  )

  expect_type(fit, "list")
  expect_equal(fit$model_id, "mod-123")
  expect_equal(fit$n_rows, 10)
})

# ---------------------------------------------------------------------------
# midas_transform
# ---------------------------------------------------------------------------

test_that("midas_transform returns list of data frames", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # Number of requests the mock has answered so far; updated from inside
  # the responder so that the first call can be told apart from later ones.
  n_calls <- 0L

  responder <- function(req) {
    n_calls <<- n_calls + 1L

    payload <- if (n_calls == 1L) {
      # POST /model/{id}/transform response
      list(model_id = "mod-123", m = 2, n_rows = 3, n_cols = 2)
    } else {
      # GET /model/{id}/imputations/{idx} response; the second call overall
      # reports index 0, the third index 1, and so on.
      list(
        model_id = "mod-123",
        index = n_calls - 2L,
        columns = list("X1", "X2"),
        data = list(list(1.0, 2.0), list(3.0, 4.0), list(5.0, 6.0))
      )
    }

    httr2::response(
      status_code = 200L,
      headers = list("Content-Type" = "application/json"),
      body = charToRaw(jsonlite::toJSON(payload, auto_unbox = TRUE))
    )
  }

  draws <- httr2::with_mocked_responses(
    responder,
    midas_transform("mod-123", m = 2L)
  )

  expect_type(draws, "list")
  expect_length(draws, 2)
  expect_s3_class(draws[[1]], "data.frame")
  expect_equal(nrow(draws[[1]]), 3)
})

# ---------------------------------------------------------------------------
# midas
# ---------------------------------------------------------------------------

test_that("midas returns model_id and imputations", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # A single response carrying the model id together with two imputed
  # data sets of two rows and two columns each.
  stub <- mock_json_response(list(
    model_id = "mod-456",
    m = 2,
    columns = list("X1", "X2"),
    imputations = list(
      list(list(1.0, 2.0), list(3.0, 4.0)),
      list(list(5.0, 6.0), list(7.0, 8.0))
    )
  ))

  incomplete <- data.frame(X1 = c(1, NA), X2 = c(2, 4))

  out <- httr2::with_mocked_responses(
    stub,
    midas(incomplete, m = 2L, epochs = 1L)
  )

  expect_equal(out$model_id, "mod-456")
  expect_length(out$imputations, 2)
  expect_s3_class(out$imputations[[1]], "data.frame")
})

# ---------------------------------------------------------------------------
# imp_mean
# ---------------------------------------------------------------------------

test_that("imp_mean returns data frame", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # Two rows by two columns in the mocked payload.
  stub <- mock_json_response(list(
    model_id = "mod-123",
    columns = list("X1", "X2"),
    data = list(list(1.5, 2.5), list(3.5, 4.5))
  ))

  averaged <- httr2::with_mocked_responses(stub, imp_mean("mod-123"))

  expect_s3_class(averaged, "data.frame")
  expect_equal(nrow(averaged), 2)
  expect_equal(ncol(averaged), 2)
})

# ---------------------------------------------------------------------------
# combine
# ---------------------------------------------------------------------------

test_that("combine returns Rubin's rules data frame", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # Coefficient table with one row per term ("const" and "X1").
  stub <- mock_json_response(list(
    model_id = "mod-123",
    columns = list(
      "term", "estimate", "std.error", "statistic", "df", "p.value"
    ),
    data = list(
      list("const", 0.1, 0.05, 2.0, 50.0, 0.05),
      list("X1", 0.3, 0.1, 3.0, 48.0, 0.004)
    )
  ))

  pooled <- httr2::with_mocked_responses(stub, combine("mod-123", y = "Y"))

  expect_s3_class(pooled, "data.frame")
  expect_true("term" %in% names(pooled))
  expect_true("p.value" %in% names(pooled))
  expect_equal(nrow(pooled), 2)
})

# ---------------------------------------------------------------------------
# overimpute
# ---------------------------------------------------------------------------

test_that("overimpute returns RMSE", {
  skip_on_cran()
  fake_server_running()
  on.exit(reset_server_state())

  # Per-column RMSE values plus their mean.
  stub <- mock_json_response(list(
    model_id = "mod-123",
    rmse = list(X1 = 0.15, X2 = 0.22),
    mean_rmse = 0.185
  ))

  diagnostics <- httr2::with_mocked_responses(stub, overimpute("mod-123"))

  expect_type(diagnostics$mean_rmse, "double")
  expect_true("X1" %in% names(diagnostics$rmse))
})