# Tests for [add_]predicted_rvars
#
# Author: mjskay
###############################################################################

suppressWarnings(suppressMessages({
  library(dplyr)
  library(tidyr)
  library(magrittr)
  library(posterior)
}))


# shared fixture: mtcars as a tibble, with row names replaced by 1..nrow first
mtcars_tbl = as_tibble(set_rownames(mtcars, seq_len(nrow(mtcars))))

# restrict to a single core for reliable testing (with more cores, random
# numbers do not seem to always be reproducible on brms)
options(mc.cores = 1)


test_that("[add_]predicted_rvars and basic arguments works on a simple rstanarm model", {
  skip_if_not_installed("rstanarm")

  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # reference built directly from posterior_predict() with the same seed
  draws_subset = rstantools::posterior_predict(m_hp_wt, mtcars_tbl, draws = 10, seed = 123)
  expected = mutate(mtcars_tbl, .prediction = rvar(draws_subset))

  expect_equal(predicted_rvars(m_hp_wt, mtcars_tbl, ndraws = 10, seed = 123), expected)
  expect_equal(add_predicted_rvars(mtcars_tbl, m_hp_wt, ndraws = 10, seed = 123), expected)

  # predicted_rvars.default should work fine here so long as we don't subset
  draws_all = rstantools::posterior_predict(m_hp_wt, mtcars_tbl, seed = 123)
  expected_all = mutate(mtcars_tbl, .prediction = rvar(draws_all))

  expect_equal(predicted_rvars.default(m_hp_wt, mtcars_tbl, seed = 123), expected_all)
})


test_that("[add_]predicted_rvars and basic arguments works on an rstanarm model with random effects", {
  skip_if_not_installed("rstanarm")

  m_cyl = readRDS(test_path("../models/models.rstanarm.m_cyl.rds"))

  # reference built directly from posterior_predict() with the same seed
  draws_subset = rstantools::posterior_predict(m_cyl, mtcars_tbl, draws = 10, seed = 123)
  expected = mutate(mtcars_tbl, .prediction = rvar(draws_subset))

  expect_equal(predicted_rvars(m_cyl, mtcars_tbl, ndraws = 10, seed = 123), expected)
  expect_equal(add_predicted_rvars(mtcars_tbl, m_cyl, ndraws = 10, seed = 123), expected)
})


test_that("[add_]predicted_rvars works on a simple brms model", {
  skip_if_not_installed("brms")

  m_hp = readRDS(test_path("../models/models.brms.m_hp.rds"))

  # seed the RNG immediately before drawing the reference predictions; the
  # calls under test receive the same value through their `seed` argument
  set.seed(123)
  draws_subset = rstantools::posterior_predict(m_hp, mtcars_tbl, ndraws = 10)
  expected = mutate(mtcars_tbl, .prediction = rvar(draws_subset))

  expect_equal(predicted_rvars(m_hp, mtcars_tbl, ndraws = 10, seed = 123), expected)
  expect_equal(add_predicted_rvars(mtcars_tbl, m_hp, ndraws = 10, seed = 123), expected)
})


test_that("[add_]predicted_rvars works on brms models with categorical outcomes", {
  skip_if_not_installed("brms")

  m_cyl_mpg = readRDS(test_path("../models/models.brms.m_cyl_mpg.rds"))

  # seed the RNG immediately before drawing the reference predictions
  set.seed(1234)
  draws_subset = rstantools::posterior_predict(m_cyl_mpg, mtcars_tbl, ndraws = 10)
  expected = mutate(mtcars_tbl, .prediction = rvar(draws_subset))

  expect_equal(predicted_rvars(m_cyl_mpg, mtcars_tbl, seed = 1234, ndraws = 10), expected)
  expect_equal(add_predicted_rvars(mtcars_tbl, m_cyl_mpg, seed = 1234, ndraws = 10), expected)
})


test_that("[add_]predicted_rvars works on brms models with dirichlet responses", {
  skip_if_not_installed("brms")
  skip_if_not(getRversion() >= "4")

  m_dirich = readRDS(test_path("../models/models.brms.m_dirich.rds"))

  # seed the RNG immediately before drawing the reference predictions
  set.seed(1234)
  grid = tibble(x = c("A", "B"))
  expected = mutate(
    grid,
    .prediction = rvar(rstantools::posterior_predict(m_dirich, grid, ndraws = 10))
  )

  expect_equal(predicted_rvars(m_dirich, grid, seed = 1234, ndraws = 10), expected)
})


test_that("[add_]predicted_rvars works on brms models with multinomial responses", {
  skip_if_not_installed("brms")
  skip_if_not(getRversion() >= "4")

  m_multinom = readRDS(test_path("../models/models.brms.m_multinom.rds"))

  set.seed(1234)
  # use a low number for total so there are some 0s
  grid = tibble(total = c(10, 20))
  draws_subset = rstantools::posterior_predict(m_multinom, grid, ndraws = 10)
  expected = mutate(grid, .prediction = rvar(draws_subset))

  expect_equal(predicted_rvars(m_multinom, grid, seed = 1234, ndraws = 10), expected)

  # column transformation: expand each prediction column into its own row,
  # labelled by `col_pred`, and order by column name then original row
  expected_columns = expected %>%
    mutate(.row = seq_len(n())) %>%
    group_by(across(-.prediction)) %>%
    reframe(col_pred = colnames(.prediction), .prediction = t(.prediction)) %>%
    arrange(col_pred, .row)
  # collapse the rvar's dims to a single dimension and attach the levels
  # attribute the comparison below expects
  dim(expected_columns$.prediction) = length(expected_columns$.prediction)
  attr(draws_of(expected_columns$.prediction), "levels") = c("a","b","c")

  expect_equal(
    predicted_rvars(m_multinom, grid, seed = 1234, ndraws = 10, columns_to = "col_pred"),
    expected_columns
  )
})


test_that("[add_]predicted_rvars throws an error when draws is called instead of ndraws in rstanarm", {
  skip_if_not_installed("rstanarm")

  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # both entry points must produce an error matching this pattern
  draws_error = "`draws.*.`ndraws`.*.See the documentation for additional details."

  expect_error(
    m_hp_wt %>% predicted_rvars(mtcars_tbl, draws = 10),
    draws_error
  )
  expect_error(
    mtcars_tbl %>% add_predicted_rvars(m_hp_wt, draws = 10),
    draws_error
  )
})


test_that("[add_]predicted_rvars throws an error when re.form is called instead of re_formula in rstanarm", {
  skip_if_not_installed("rstanarm")

  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # both entry points must produce an error matching this pattern
  re_form_error = "`re.form.*.`re_formula`.*.See the documentation for additional details."

  expect_error(
    m_hp_wt %>% predicted_rvars(mtcars_tbl, re.form = NULL),
    re_form_error
  )
  expect_error(
    mtcars_tbl %>% add_predicted_rvars(m_hp_wt, re.form = NULL),
    re_form_error
  )
})