# test-ALEPlot.R
# Tests to ensure that the ale package gives the same results
# as the gold standard reference ALEPlot package.

# test_file('tests/testthat/test-ALEPlot.R')

# To minimize test time, the reference output should be serialized with
# expect_snapshot_value.

# Do not run these on CRAN so that the required packages are not included as
# dependencies.
# https://community.rstudio.com/t/skip-an-entire-test-file-on-cran-only/162842
if (!identical(Sys.getenv("NOT_CRAN"), "true")) return()


# nnet -----------------

# Simulated regression data: y depends on x1, x2, x3 and an x1:x2 interaction;
# x4 does not enter the formula for y.
set.seed(0)
n <- 1000  # smaller dataset for more rapid execution
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
y <- 4*x1 + 3.87*x2^2 + 2.97*exp(-5+10*x3)/(1+exp(-5+10*x3)) +
  13.86*(x1-0.5)*(x2-0.5) + rnorm(n, 0, 1)
# Objects shared by the tests below are assigned with `<<-` so that they are
# visible from inside every test_that() block.
DAT <<- data.frame(y, x1, x2, x3, x4)

set.seed(0)
nnet.DAT <<- nnet::nnet(
  y ~ .,
  data = DAT,
  linout = TRUE,
  skip = FALSE,
  size = 6,
  decay = 0.1,
  maxit = 1000,
  trace = FALSE
)

# Define the predict functions.
# ALEPlot calls pred.fun(X.model, newdata); ale calls pred_fun(object, newdata, type).
nnet_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(predict(X.model, newdata, type = "raw"))
}
# NOTE: the default `type = pred_type` is only resolved if `type` is not
# supplied by the caller; `pred_type` is not defined in this file.
nnet_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(predict(object, newdata, type = type))
}


# gbm ----------------

adult_data <<- census |>
  as.data.frame() |>  # ALEPlot is not compatible with the tibble format
  # Rearrange columns to match ALEPlot order
  select(age:native_country, higher_income) |>
  stats::na.omit()

# gbm automatically generates plots; open a null PDF device (no file is
# created) so that they are discarded instead of being printed.
pdf(file = NULL)

set.seed(0)
gbm.data <<- gbm::gbm(
  higher_income ~ .,
  data = adult_data[,-c(3,4)] |>
    # gbm::gbm() requires binary response outcomes to be numeric 0 or 1
    mutate(higher_income = as.integer(higher_income)),
  distribution = "bernoulli",
  n.trees = 100,  # smaller model than ALEPlot example for rapid execution
  shrinkage = 0.02,
  interaction.depth = 3
)

# Return to regular printing of plots
dev.off() |>
  invisible()

# Predict functions for the gbm model; n.trees matches the fitted model above.
gbm_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(gbm::predict.gbm(X.model, newdata, n.trees = 100, type = "link"))
}
# NOTE: as with nnet_pred_fun_ale, the `pred_type` default is only resolved
# when the caller does not pass `type`.
gbm_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(gbm::predict.gbm(object, newdata, n.trees = 100, type = type))
}


# Tests --------------------

test_that('ale function matches output of ALEPlot with nnet', {
  # Open a null PDF device so the plots drawn by ALEPlot are discarded
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one tibble (x.values, f.values) per predictor column of DAT[,2:5].
  nnet_ALEPlot <- map(1:4, \(it.col_idx) {
    ALEPlot::ALEPlot(
      DAT[,2:5],
      nnet.DAT,
      pred.fun = nnet_pred_fun_ALEPlot,
      J = it.col_idx,
      K = 10
    ) |>
      as_tibble() |>
      select(-K)
  }) |>
    set_names(names(DAT[,2:5]))

  # Return to regular printing of plots
  dev.off() |>
    invisible()

  # Create ale results with data only
  nnet_ale <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_type = "raw",
    pred_fun = nnet_pred_fun_ale,
    max_num_bins = 10 + 1,
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  nnet_ale_to_ALEPlot <- get(nnet_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = it.x$.y
      )
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(nnet_ALEPlot, nnet_ale_to_ALEPlot, tolerance = 0.01)
  )
})


test_that('ale function matches output of ALEPlot with gbm', {
  # Open a null PDF device so the plots drawn by ALEPlot are discarded
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy
  # For this test, get only four variables:
  # c('age', 'workclass', 'education_num', 'sex')
  # These are column indexes c(1, 2, 3, 8)
  gbm_ALEPlot <- map(c(1, 2, 3, 8), \(it.col_idx) {
    ALEPlot::ALEPlot(
      adult_data[,-c(3,4,15)],
      gbm.data,
      pred.fun = gbm_pred_fun_ALEPlot,
      J = it.col_idx,
      K = 10,
      NA.plot = TRUE
    ) |>
      as_tibble() |>
      select(-K)
  }) |>
    set_names(names(adult_data[,-c(3,4,15)])[c(1, 2, 3, 8)])

  # Return to regular printing of plots
  dev.off() |>
    invisible()

  # Create ale results with data only
  gbm_ale <- ALE(
    model = gbm.data,
    x_cols = c('age', 'workclass', 'education_num', 'sex'),
    data = adult_data[,-c(3,4)],  # unlike ALEPlot, include the y column (15)
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_fun = gbm_pred_fun_ale,
    pred_type = 'link',
    max_num_bins = 10 + 1,
    silent = TRUE
  ) |>
    suppressMessages()

  # Convert ale results to version that can be readily compared with ALEPlot
  gbm_ale_to_ALEPlot <- get(gbm_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = unname(it.x$.y)
      ) |>
        # Factor x values are converted to character before comparison
        mutate(across(where(is.factor), as.character))
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(gbm_ALEPlot, gbm_ale_to_ALEPlot)
  )
})


test_that('2D ALE matches output of ALEPlot interactions with nnet', {
  # Open a null PDF device so the plots drawn by ALEPlot are discarded
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one long-format tibble (.x1, .x2, .y) per unordered pair of predictors.
  nnet_ALEPlot_ixn <- list()
  for (it.x1 in 1:4) {
    for (it.x2 in 1:4) {
      if (it.x1 < it.x2) {
        ap_data <- ALEPlot::ALEPlot(
          DAT[,2:5],
          nnet.DAT,
          pred.fun = nnet_pred_fun_ALEPlot,
          J = c(it.x1, it.x2),
          K = 10
        )

        .x1 <- ap_data$x.values[[1]]
        .x2 <- ap_data$x.values[[2]]
        .y <- ap_data$f.values

        # Flatten the f.values matrix: one row per (row, col) cell, looked up
        # with matrix indexing on cbind(row, col).
        ixn_tbl <- expand.grid(
          row = seq_along(.x1),
          col = seq_along(.x2)
        ) |>
          as_tibble() |>
          mutate(
            .x1 = .x1[row],
            .x2 = as.numeric(.x2[col]),
            .y = as.numeric(.y[cbind(row, col)])
          ) |>
          select(-row, -col) |>
          arrange(.x1, .x2, .y)

        # Remove extraneous attributes, otherwise comparison will not match
        attributes(ixn_tbl)$out.attrs <- NULL

        nnet_ALEPlot_ixn[[str_glue('x{it.x1}:x{it.x2}')]] <- ixn_tbl
      }
    }
  }

  # Return to regular printing of plots
  dev.off() |>
    invisible()

  nnet_2D <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    x_cols = list(d2 = TRUE),
    parallel = 0,
    output_stats = FALSE,
    pred_fun = nnet_pred_fun_ale,
    pred_type = "raw",
    max_num_bins = 10 + 1,
    # specific options requested
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  nnet_2D_to_ALEPlot <- get(nnet_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(nnet_ALEPlot_ixn, nnet_2D_to_ALEPlot, tolerance = 0.01)
  )
})


test_that('2D ALE matches output of ALEPlot interactions with gbm', {
  # Open a null PDF device so that plots are discarded
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy
  gbm_ALEPlot_ixn <- list()
  adult_data_subset <- adult_data[,-c(3,4,15)]
  # Only pairs with it.x1 < it.x2 are computed: (1,3), (1,11), (2,3), (2,11),
  # (3,11) and (8,11). These are expected to line up with the six `x_cols`
  # interaction names passed to ALE() below.
  for (it.x1 in c(1, 2, 3, 8)) {
    for (it.x2 in c(1, 3, 11)) {
      if (it.x1 < it.x2) {
        ap_data <- ALEPlot::ALEPlot(
          adult_data_subset,
          gbm.data,
          pred.fun = gbm_pred_fun_ALEPlot,
          J = c(it.x1, it.x2),
          K = 10,
          NA.plot = TRUE
        )

        .x1 <- ap_data$x.values[[1]]
        .x2 <- ap_data$x.values[[2]]
        .y <- ap_data$f.values

        # Flatten the f.values matrix: one row per (row, col) cell, looked up
        # with matrix indexing on cbind(row, col).
        ixn_tbl <- expand.grid(
          row = seq_along(.x1),
          col = seq_along(.x2)
        ) |>
          as_tibble() |>
          mutate(
            .x1 = .x1[row],
            .x2 = as.numeric(.x2[col]),
            .y = as.numeric(.y[cbind(row, col)])
          ) |>
          select(-row, -col) |>
          arrange(.x1, .x2, .y)

        # Remove extraneous attributes, otherwise comparison will not match
        attributes(ixn_tbl)$out.attrs <- NULL

        gbm_ALEPlot_ixn[[str_glue(
          '{names(adult_data_subset)[it.x1]}:{names(adult_data_subset)[it.x2]}'
        )]] <- ixn_tbl
      }
    }
  }

  gbm_2D <- ALE(
    model = gbm.data,
    data = adult_data,
    x_cols = c(
      'age:education_num',
      'age:hours_per_week',
      'workclass:education_num',
      'workclass:hours_per_week',
      'education_num:hours_per_week',
      'sex:hours_per_week'
    ),
    parallel = 0,
    output_stats = FALSE,
    pred_fun = gbm_pred_fun_ale,
    pred_type = 'link',
    max_num_bins = 10 + 1,
    # specific options requested
    silent = TRUE
  )

  # Return to regular printing of plots.
  # For some reason, calling ALE() on gbm.data also prints some plots, so the
  # null device is only closed after ALE() has run.
  dev.off() |>
    invisible()

  # Convert ale results to version that can be readily compared with ALEPlot
  gbm_2D_to_ALEPlot <- get(gbm_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        # Convert [ordered] factor columns to character for comparability
        # with ALEPlot
        mutate(across(
          '.x1',
          \(it.col) if (is.factor(it.col)) as.character(it.col) else it.col
        )) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(gbm_ALEPlot_ixn, gbm_2D_to_ALEPlot)
  )
})