library(bayesplot)
context("PPC: loo")

options(useFancyQuotes = FALSE)

# Shared fixtures ---------------------------------------------------------
# Built once, and only when both optional packages can be loaded. The tests
# that read these objects each start with the matching pair of skips.
if (requireNamespace("rstanarm", quietly = TRUE) &&
    requireNamespace("loo", quietly = TRUE)) {
  suppressPackageStartupMessages(library(rstanarm))
  suppressPackageStartupMessages(library(loo))

  ITER <- 1000
  CHAINS <- 3
  fit <- stan_glm(
    mpg ~ wt + am,
    data = mtcars,
    iter = ITER,
    chains = CHAINS,
    refresh = 0
  )
  y <- fit$y
  yrep <- posterior_predict(fit)

  # PSIS is run on the negated pointwise log-likelihood; any warnings it
  # raises are silenced here, as are those from the PIT computation.
  psis1 <- suppressWarnings(psis(-log_lik(fit), cores = 2))
  lw <- weights(psis1)
  pits <- suppressWarnings(rstantools::loo_pit(yrep, y, lw))
}

# PIT plots ---------------------------------------------------------------

test_that("ppc_loo_pit gives deprecation warning but still works", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  expect_warning(
    pit_plot <- ppc_loo_pit(y, yrep, lw),
    regexp = "deprecated"
  )
  expect_gg(pit_plot)
})

test_that("ppc_loo_pit_overlay returns ggplot object", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  # Weights supplied directly ...
  overlay_lw <- ppc_loo_pit_overlay(y, yrep, lw, samples = 25)
  expect_gg(overlay_lw)

  # ... and taken from a psis object instead.
  overlay_psis <- ppc_loo_pit_overlay(
    y,
    yrep,
    psis_object = psis1,
    samples = 25
  )
  expect_gg(overlay_psis)
})

test_that("ppc_loo_pit_overlay warns about binary data", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  # An all-ones outcome vector of the right length triggers the warning.
  constant_y <- rep(1, length(y))
  expect_warning(
    ppc_loo_pit_overlay(constant_y, yrep, lw),
    regexp = "not recommended for binary data"
  )
})

test_that("ppc_loo_pit_overlay works with boundary_correction=TRUE", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  expect_message(
    corrected <- ppc_loo_pit_overlay(
      y,
      yrep,
      lw,
      boundary_correction = TRUE
    ),
    regexp = "continuous observations"
  )
  expect_gg(corrected)
})

test_that("ppc_loo_pit_overlay works with boundary_correction=FALSE", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  expect_gg(ppc_loo_pit_overlay(y, yrep, lw, boundary_correction = FALSE))
})

test_that("ppc_loo_pit_qq returns ggplot object", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  qq_lw <- ppc_loo_pit_qq(y, yrep, lw)
  expect_gg(qq_lw)
  qq_psis <- ppc_loo_pit_qq(y, yrep, psis_object = psis1)
  expect_gg(qq_psis)

  # Default comparison is labelled "Uniform"; both input routes must agree.
  expect_equal(qq_lw$labels$x, "Uniform")
  expect_equal(qq_lw$data, qq_psis$data)

  qq_normal <- ppc_loo_pit_qq(y, yrep, lw, compare = "normal")
  expect_gg(qq_normal)
  expect_equal(qq_normal$labels$x, "Normal")
})

test_that("ppc_loo_pit functions work when pit specified instead of y,yrep,lw", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  ignoring_msg <- "'pit' specified so ignoring 'y','yrep','lw' if specified"

  expect_gg(ppc_loo_pit_qq(pit = pits))
  expect_message(
    qq_all_args <- ppc_loo_pit_qq(y = y, yrep = yrep, lw = lw, pit = pits),
    regexp = ignoring_msg
  )
  expect_message(
    qq_pit_only <- ppc_loo_pit_qq(pit = pits)
  )
  expect_equal(qq_all_args$data, qq_pit_only$data)

  overlay_pit_only <- ppc_loo_pit_overlay(pit = pits)
  expect_gg(overlay_pit_only)
  expect_message(
    ppc_loo_pit_overlay(y = y, yrep = yrep, lw = lw, pit = pits),
    regexp = ignoring_msg
  )
})

# Interval and ribbon plots -----------------------------------------------

test_that("ppc_loo_intervals returns ggplot object", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  expect_gg(ppc_loo_intervals(y, yrep, psis_object = psis1))

  # With order = "median" the x column is a factor with one level per row.
  by_median <- ppc_loo_intervals(
    y,
    yrep,
    psis_object = psis1,
    order = "median"
  )
  expect_gg(by_median)
  expect_s3_class(by_median$data$x, "factor")
  expect_equal(nlevels(by_median$data$x), length(by_median$data$x))

  # subset argument
  first_25 <- ppc_loo_intervals(
    y,
    yrep,
    psis_object = psis1,
    subset = seq_len(25)
  )
  expect_gg(first_25)
  expect_equal(nrow(first_25$data), 25)
})

test_that("ppc_loo_ribbon returns ggplot object", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  expect_gg(
    ppc_loo_ribbon(y, yrep, psis_object = psis1, prob = 0.7, alpha = 0.1)
  )

  first_25 <- ppc_loo_ribbon(
    y,
    yrep,
    psis_object = psis1,
    subset = seq_len(25)
  )
  expect_gg(first_25)
  expect_equal(nrow(first_25$data), 25)
})

test_that("ppc_loo_intervals/ribbon work when 'intervals' specified", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  ignoring_msg <-
    "'intervals' specified so ignoring 'yrep', 'psis_object', 'subset', if specified"

  # One row per column of yrep, five quantiles per row.
  quantile_mat <- t(apply(
    yrep,
    MARGIN = 2,
    FUN = quantile,
    probs = c(0.1, 0.25, 0.5, 0.75, 0.9)
  ))

  expect_gg(ppc_loo_intervals(y, intervals = quantile_mat))
  expect_gg(ppc_loo_ribbon(y, intervals = quantile_mat))
  expect_message(
    ppc_loo_ribbon(y, intervals = quantile_mat),
    regexp = ignoring_msg
  )
  expect_message(
    ppc_loo_intervals(
      y,
      yrep,
      psis_object = psis1,
      intervals = quantile_mat
    ),
    regexp = ignoring_msg
  )
})

test_that("ppc_loo_intervals/ribbon work when 'intervals' has 3 columns", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  quantile_mat <- t(apply(
    yrep,
    MARGIN = 2,
    FUN = quantile,
    probs = c(0.1, 0.5, 0.9)
  ))

  expect_gg(ppc_loo_intervals(y, intervals = quantile_mat))
  expect_gg(ppc_loo_ribbon(y, intervals = quantile_mat))
})

# Input validation --------------------------------------------------------

test_that("errors if dimensions of yrep and lw don't match", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  # Keep only the first five columns of lw so its dimensions differ from yrep.
  truncated_lw <- lw[, seq_len(5)]
  expect_error(
    ppc_loo_pit_overlay(y, yrep, truncated_lw),
    regexp = "identical(dim(yrep), dim(lw)) is not TRUE",
    fixed = TRUE
  )
})

test_that("error if subset is bigger than num obs", {
  skip_if_not_installed("rstanarm")
  skip_if_not_installed("loo")

  oversized <- seq_len(1000)
  expect_error(
    .psis_subset(psis1, oversized),
    regexp = "too many elements"
  )
  expect_error(
    ppc_loo_intervals(y, yrep, psis_object = psis1, subset = oversized),
    regexp = "length(y) >= length(subset) is not TRUE",
    fixed = TRUE
  )
})


# Visual tests ------------------------------------------------------------

# The vdiff_loo_* objects used below are expected to come from this file.
source(test_path("data-for-ppc-tests.R"))
set.seed(123)

test_that("ppc_loo_pit_overlay renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("loo")

  default_plot <- suppressMessages(
    ppc_loo_pit_overlay(vdiff_loo_y, vdiff_loo_yrep, vdiff_loo_lw)
  )
  vdiffr::expect_doppelganger("ppc_loo_pit_overlay (default)", default_plot)

  uncorrected_plot <- suppressMessages(
    ppc_loo_pit_overlay(
      vdiff_loo_y,
      vdiff_loo_yrep,
      vdiff_loo_lw,
      boundary_correction = FALSE
    )
  )
  vdiffr::expect_doppelganger(
    "ppc_loo_pit_overlay (boundary)",
    uncorrected_plot
  )
})

test_that("ppc_loo_pit_qq renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("loo")

  default_plot <- ppc_loo_pit_qq(vdiff_loo_y, vdiff_loo_yrep, vdiff_loo_lw)
  vdiffr::expect_doppelganger("ppc_loo_pit_qq (default)", default_plot)
})

test_that("ppc_loo_intervals renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("loo")

  psis_fit <- suppressWarnings(loo::psis(-vdiff_loo_lw))

  default_plot <- ppc_loo_intervals(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit
  )
  vdiffr::expect_doppelganger("ppc_loo_intervals (default)", default_plot)

  prob_plot <- ppc_loo_intervals(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit,
    prob = 0.6,
    prob_outer = 0.7
  )
  vdiffr::expect_doppelganger("ppc_loo_intervals (prob)", prob_plot)

  ordered_plot <- ppc_loo_intervals(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit,
    order = "median"
  )
  vdiffr::expect_doppelganger("ppc_loo_intervals (order)", ordered_plot)
})

test_that("ppc_loo_ribbon renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("loo")

  psis_fit <- suppressWarnings(loo::psis(-vdiff_loo_lw))

  default_plot <- ppc_loo_ribbon(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit
  )
  vdiffr::expect_doppelganger("ppc_loo_ribbon (default)", default_plot)

  prob_plot <- ppc_loo_ribbon(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit,
    prob = 0.6,
    prob_outer = 0.7
  )
  vdiffr::expect_doppelganger("ppc_loo_ribbon (prob)", prob_plot)

  subset_plot <- ppc_loo_ribbon(
    vdiff_loo_y,
    vdiff_loo_yrep,
    psis_object = psis_fit,
    subset = seq_len(10)
  )
  vdiffr::expect_doppelganger("ppc_loo_ribbon (subset)", subset_plot)
})