# Metric sets built from a single metric type ----------------------------------
#
# Each test compares a metric set against row-binding the individual metric
# calls on the same data.

test_that("metric_set() - class", {
  exp <- dplyr::bind_rows(
    accuracy(two_class_example, truth, predicted),
    sensitivity(two_class_example, truth, predicted),
    specificity(two_class_example, truth, predicted)
  )

  set <- metric_set(accuracy, sensitivity, specificity)

  expect_equal(
    set(two_class_example, truth = truth, estimate = predicted),
    exp
  )
})

test_that("metric_set() - prob", {
  exp <- dplyr::bind_rows(
    mn_log_loss(two_class_example, truth, Class1),
    roc_auc(two_class_example, truth, Class1),
    brier_class(two_class_example, truth, Class1)
  )

  set <- metric_set(mn_log_loss, roc_auc, brier_class)

  expect_equal(
    set(two_class_example, truth = truth, Class1),
    exp
  )
})

test_that("metric_set() - ordered prob", {
  two_class_example$truth <- as.ordered(two_class_example$truth)

  exp <- dplyr::bind_rows(
    ranked_prob_score(two_class_example, truth, Class1:Class2),
    ranked_prob_score(two_class_example, truth, Class1:Class2)
  )

  set <- metric_set(ranked_prob_score, ranked_prob_score)

  expect_equal(
    set(two_class_example, truth = truth, Class1:Class2),
    exp
  )
})

test_that("metric_set() - numeric", {
  exp <- dplyr::bind_rows(
    rmse(solubility_test, solubility, prediction),
    rsq(solubility_test, solubility, prediction),
    mae(solubility_test, solubility, prediction)
  )

  set <- metric_set(rmse, rsq, mae)

  expect_equal(
    set(solubility_test, solubility, prediction),
    exp
  )
})

test_that("metric_set() - static survival", {
  exp <- dplyr::bind_rows(
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time),
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time)
  )

  set <- metric_set(concordance_survival, concordance_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, estimate = .pred_time),
    exp
  )
})

test_that("metric_set() - dynamic survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival(lung_surv, truth = surv_obj, .pred),
    roc_auc_survival(lung_surv, truth = surv_obj, .pred)
  )

  set <- metric_set(brier_survival, roc_auc_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred),
    exp
  )
})

test_that("metric_set() - integrated survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred),
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred)
  )

  set <- metric_set(brier_survival_integrated, brier_survival_integrated)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred),
    exp
  )
})

test_that("metric_set() - linear predictor survival", {
  exp <- dplyr::bind_rows(
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred),
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred)
  )

  set <- metric_set(royston_survival, royston_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, estimate = .pred_linear_pred),
    exp
  )
})

test_that("metric_set() - quantile", {
  quantile_levels <- c(0.2, 0.4, 0.6, 0.8)
  pred1 <- 1:4
  pred2 <- 8:11
  example <- dplyr::tibble(
    preds = hardhat::quantile_pred(rbind(pred1, pred2), quantile_levels),
    truth = c(3.3, 7.1)
  )

  exp <- dplyr::bind_rows(
    weighted_interval_score(example, truth = truth, preds),
    weighted_interval_score(example, truth = truth, preds)
  )

  set <- metric_set(weighted_interval_score, weighted_interval_score)

  expect_equal(
    set(example, truth = truth, estimate = preds),
    exp
  )
})

# Metric sets mixing compatible metric types -----------------------------------

test_that("metric_set() - can mix class and prob", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    accuracy(two_class_example, truth, predicted),
    roc_auc(two_class_example, truth, Class1)
  )

  set <- metric_set(accuracy, roc_auc)

  expect_equal(
    set(two_class_example, truth = truth, Class1, estimate = predicted),
    exp
  )
})

test_that("metric_set() - can mix class and orderedprob", {
  skip_if_not_installed("tidyr")

  two_class_example$truth <- as.ordered(two_class_example$truth)

  exp <- dplyr::bind_rows(
    accuracy(two_class_example, truth, predicted),
    ranked_prob_score(two_class_example, truth, Class1:Class2)
  )

  set <- metric_set(accuracy, ranked_prob_score)

  expect_equal(
    set(two_class_example, truth = truth, Class1:Class2, estimate = predicted),
    exp
  )
})

test_that("metric_set() - can mix prob and orderedprob", {
  skip_if_not_installed("tidyr")

  hpc_cv$obs <- as.ordered(hpc_cv$obs)

  exp <- dplyr::bind_rows(
    roc_auc(hpc_cv, obs, VF:L),
    ranked_prob_score(hpc_cv, obs, VF:L)
  )

  set <- metric_set(roc_auc, ranked_prob_score)

  expect_equal(
    set(hpc_cv, obs, VF:L),
    exp
  )
})

test_that("metric_set() - can mix class, prob, and orderedprob", {
  skip_if_not_installed("tidyr")

  hpc_cv$obs <- as.ordered(hpc_cv$obs)

  exp <- dplyr::bind_rows(
    accuracy(hpc_cv, obs, pred),
    roc_auc(hpc_cv, obs, VF:L),
    ranked_prob_score(hpc_cv, obs, VF:L)
  )

  set <- metric_set(accuracy, roc_auc, ranked_prob_score)

  expect_equal(
    set(hpc_cv, obs, VF:L, estimate = pred),
    exp
  )
})

# In the mixed survival tests below, the expected row order is not always the
# order in which metrics are passed to `metric_set()`. For example,
# `brier_survival` rows are expected before `concordance_survival` rows even
# when the set lists `concordance_survival` first.

test_that("metric_set() - can mix static and dynamic survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival(lung_surv, truth = surv_obj, .pred),
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time)
  )

  set <- metric_set(concordance_survival, brier_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred, estimate = .pred_time),
    exp
  )
})

test_that("metric_set() - can mix static and integrated survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred),
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time)
  )

  set <- metric_set(brier_survival_integrated, concordance_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred, estimate = .pred_time),
    exp
  )
})

test_that("metric_set() - can mix dynamic and integrated survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival(lung_surv, truth = surv_obj, .pred),
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred)
  )

  set <- metric_set(brier_survival, brier_survival_integrated)

  # `estimate` is supplied even though neither metric in the set uses it.
  expect_equal(
    set(lung_surv, truth = surv_obj, .pred, estimate = .pred_time),
    exp
  )
})

test_that("metric_set() - can mix static and linear predictor survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time),
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred)
  )

  set <- metric_set(concordance_survival, royston_survival)

  # Static and linear predictor metrics both read `estimate`, so the two
  # columns are supplied as a named selection.
  expect_equal(
    set(
      lung_surv,
      truth = surv_obj,
      estimate = c(static = .pred_time, linear_pred = .pred_linear_pred)
    ),
    exp
  )
})

test_that("metric_set() - can mix dynamic and linear predictor survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival(lung_surv, truth = surv_obj, .pred),
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred)
  )

  set <- metric_set(brier_survival, royston_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred, estimate = .pred_linear_pred),
    exp
  )
})

test_that("metric_set() - can mix integrated and linear predictor survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred),
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred)
  )

  set <- metric_set(brier_survival_integrated, royston_survival)

  expect_equal(
    set(lung_surv, truth = surv_obj, .pred, estimate = .pred_linear_pred),
    exp
  )
})

test_that("metric_set() - can mix static, dynamic, integrated, and linear predictor survival", {
  skip_if_not_installed("tidyr")

  exp <- dplyr::bind_rows(
    brier_survival(lung_surv, truth = surv_obj, .pred),
    brier_survival_integrated(lung_surv, truth = surv_obj, .pred),
    concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time),
    royston_survival(lung_surv, truth = surv_obj, .pred_linear_pred)
  )

  set <- metric_set(
    concordance_survival,
    brier_survival,
    brier_survival_integrated,
    royston_survival
  )

  expect_equal(
    set(
      lung_surv,
      truth = surv_obj,
      .pred,
      estimate = c(static = .pred_time, linear_pred = .pred_linear_pred)
    ),
    exp
  )
})

# Metric set classes and printing ----------------------------------------------

test_that("metric set functions are classed", {
  expect_s3_class(
    metric_set(accuracy),
    c("class_prob_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(roc_auc),
    c("class_prob_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(ranked_prob_score),
    c("class_prob_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(rmse),
    c("numeric_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(concordance_survival),
    c("survival_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(brier_survival),
    c("survival_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(brier_survival_integrated),
    c("survival_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(royston_survival),
    c("survival_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(weighted_interval_score),
    c("quantile_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(accuracy, roc_auc, ranked_prob_score),
    c("class_prob_metric_set", "metric_set")
  )
  expect_s3_class(
    metric_set(
      concordance_survival,
      brier_survival,
      brier_survival_integrated,
      royston_survival
    ),
    c("survival_metric_set", "metric_set")
  )
})

test_that("print metric_set works", {
  expect_snapshot(
    metric_set(accuracy)
  )
  expect_snapshot(
    metric_set(roc_auc)
  )
  expect_snapshot(
    metric_set(ranked_prob_score)
  )
  expect_snapshot(
    metric_set(rmse)
  )
  expect_snapshot(
    metric_set(concordance_survival)
  )
  expect_snapshot(
    metric_set(brier_survival)
  )
  expect_snapshot(
    metric_set(brier_survival_integrated)
  )
  expect_snapshot(
    metric_set(royston_survival)
  )
  expect_snapshot(
    metric_set(weighted_interval_score)
  )
  expect_snapshot(
    metric_set(accuracy, roc_auc, ranked_prob_score)
  )
  expect_snapshot(
    metric_set(
      concordance_survival,
      brier_survival,
      brier_survival_integrated,
      royston_survival
    )
  )
})

# metric_tweak() interaction ---------------------------------------------------

test_that("metric_tweak and metric_set plays nicely together (#351)", {
  skip_if_not_installed("tidyr")

  # Classification
  multi_ex <- data_three_by_three()

  ref <- dplyr::bind_rows(
    j_index(multi_ex, estimator = "macro"),
    j_index(multi_ex, estimator = "micro")
  )

  j_index_macro <- metric_tweak("j_index", j_index, estimator = "macro")
  j_index_micro <- metric_tweak("j_index", j_index, estimator = "micro")

  expect_identical(
    metric_set(j_index_macro, j_index_micro)(multi_ex),
    ref
  )

  # Probability
  ref <- dplyr::bind_rows(
    roc_auc(two_class_example, truth, Class1, event_level = "first"),
    roc_auc(two_class_example, truth, Class1, event_level = "second")
  )

  roc_auc_first <- metric_tweak("roc_auc", roc_auc, event_level = "first")
  roc_auc_second <- metric_tweak("roc_auc", roc_auc, event_level = "second")

  expect_identical(
    metric_set(roc_auc_first, roc_auc_second)(two_class_example, truth, Class1),
    ref
  )

  # Ordered Probability
  two_class_example_ordered <- two_class_example
  two_class_example_ordered$truth <- as.ordered(
    two_class_example_ordered$truth
  )

  ref <- dplyr::bind_rows(
    ranked_prob_score(
      two_class_example_ordered,
      truth,
      Class1:Class2,
      na_rm = TRUE
    ),
    ranked_prob_score(
      two_class_example_ordered,
      truth,
      Class1:Class2,
      na_rm = FALSE
    )
  )

  ranked_prob_score_true <- metric_tweak(
    "ranked_prob_score",
    ranked_prob_score,
    na_rm = TRUE
  )
  ranked_prob_score_false <- metric_tweak(
    "ranked_prob_score",
    ranked_prob_score,
    na_rm = FALSE
  )

  expect_identical(
    metric_set(ranked_prob_score_true, ranked_prob_score_false)(
      two_class_example_ordered,
      truth,
      Class1:Class2
    ),
    ref
  )

  # numeric
  ref <- dplyr::bind_rows(
    ccc(mtcars, truth = mpg, estimate = disp, bias = TRUE),
    ccc(mtcars, truth = mpg, estimate = disp, bias = FALSE)
  )

  ccc_bias <- metric_tweak("ccc", ccc, bias = TRUE)
  ccc_no_bias <- metric_tweak("ccc", ccc, bias = FALSE)

  expect_identical(
    metric_set(ccc_bias, ccc_no_bias)(mtcars, truth = mpg, estimate = disp),
    ref
  )

  # Static survival
  # A single missing prediction makes `na_rm = TRUE` and `na_rm = FALSE`
  # observable for metrics that read `.pred_time`.
  lung_surv_na <- lung_surv
  lung_surv_na$.pred_time[1] <- NA

  ref <- dplyr::bind_rows(
    concordance_survival(lung_surv_na, surv_obj, .pred_time, na_rm = TRUE),
    concordance_survival(lung_surv_na, surv_obj, .pred_time, na_rm = FALSE)
  )

  concordance_survival_na_rm <- metric_tweak(
    "concordance_survival",
    concordance_survival,
    na_rm = TRUE
  )
  concordance_survival_no_na_rm <- metric_tweak(
    "concordance_survival",
    concordance_survival,
    na_rm = FALSE
  )

  expect_identical(
    metric_set(concordance_survival_na_rm, concordance_survival_no_na_rm)(
      lung_surv_na,
      truth = surv_obj,
      estimate = .pred_time
    ),
    ref
  )

  # NOTE(review): `lung_surv_na` only has a missing value in `.pred_time`.
  # The dynamic, integrated, and linear predictor sections below read `.pred`
  # and `.pred_linear_pred`, so their `na_rm = TRUE`/`FALSE` rows may be
  # identical and would not prove the tweaked argument is forwarded.

  # dynamic survival
  ref <- dplyr::bind_rows(
    brier_survival(lung_surv_na, surv_obj, .pred, na_rm = TRUE),
    brier_survival(lung_surv_na, surv_obj, .pred, na_rm = FALSE)
  )

  brier_survival_na_rm <- metric_tweak(
    "brier_survival",
    brier_survival,
    na_rm = TRUE
  )
  brier_survival_no_na_rm <- metric_tweak(
    "brier_survival",
    brier_survival,
    na_rm = FALSE
  )

  expect_identical(
    metric_set(brier_survival_na_rm, brier_survival_no_na_rm)(
      lung_surv_na,
      truth = surv_obj,
      .pred
    ),
    ref
  )

  # integrated survival
  ref <- dplyr::bind_rows(
    brier_survival_integrated(lung_surv_na, surv_obj, .pred, na_rm = TRUE),
    brier_survival_integrated(lung_surv_na, surv_obj, .pred, na_rm = FALSE)
  )

  brier_survival_integrated_na_rm <- metric_tweak(
    "brier_survival_integrated",
    brier_survival_integrated,
    na_rm = TRUE
  )
  brier_survival_integrated_no_na_rm <- metric_tweak(
    "brier_survival_integrated",
    brier_survival_integrated,
    na_rm = FALSE
  )

  expect_identical(
    metric_set(
      brier_survival_integrated_na_rm,
      brier_survival_integrated_no_na_rm
    )(
      lung_surv_na,
      truth = surv_obj,
      .pred
    ),
    ref
  )

  # linear predictor survival
  ref <- dplyr::bind_rows(
    royston_survival(lung_surv_na, surv_obj, .pred_linear_pred, na_rm = TRUE),
    royston_survival(lung_surv_na, surv_obj, .pred_linear_pred, na_rm = FALSE)
  )

  royston_survival_na_rm <- metric_tweak(
    "royston_survival",
    royston_survival,
    na_rm = TRUE
  )
  royston_survival_no_na_rm <- metric_tweak(
    "royston_survival",
    royston_survival,
    na_rm = FALSE
  )

  expect_identical(
    metric_set(
      royston_survival_na_rm,
      royston_survival_no_na_rm
    )(
      lung_surv_na,
      truth = surv_obj,
      estimate = .pred_linear_pred
    ),
    ref
  )
})

# Input validation and error messages ------------------------------------------

test_that("metric_set() errors on bad input", {
  expect_snapshot(
    error = TRUE,
    metric_set("x")
  )

  expect_snapshot(
    error = TRUE,
    metric_set(rmse, "x")
  )
})

test_that("metric_set() errors on empty input", {
  expect_snapshot(
    error = TRUE,
    metric_set()
  )
})

# NOTE: "incombatible" in the next test name is a known typo. It is kept as-is
# because testthat keys snapshot entries by test name, so renaming the test
# would orphan its recorded snapshots.
test_that("metric_set() errors on mixing incombatible metrics", {
  expect_snapshot(
    error = TRUE,
    metric_set(rmse, accuracy)
  )

  expect_snapshot(
    error = TRUE,
    metric_set(rmse, accuracy, brier_survival)
  )

  expect_snapshot(
    error = TRUE,
    metric_set(rmse, accuracy, brier_survival, weighted_interval_score)
  )
})

test_that("can supply `event_level` even with metrics that don't use it", {
  df <- two_class_example

  # Reversing the factor levels and asking for the second event level should
  # reproduce the results computed on the original data.
  df_rev <- df
  df_rev$truth <- stats::relevel(df_rev$truth, "Class2")
  df_rev$predicted <- stats::relevel(df_rev$predicted, "Class2")

  # accuracy doesn't use it, and doesn't have it as an argument
  set <- metric_set(accuracy, recall, roc_auc)

  expect_equal(
    as.data.frame(set(df, truth, Class1, estimate = predicted)),
    as.data.frame(
      set(df_rev, truth, Class1, estimate = predicted, event_level = "second")
    )
  )
})

test_that("`metric_set()` labeling remove namespaces", {
  x <- metric_set(yardstick::mase, rmse)
  expect_identical(names(attr(x, "metrics")), c("mase", "rmse"))
})

test_that("metric_set can be coerced to a tibble", {
  x <- metric_set(roc_auc, pr_auc, accuracy)
  expect_s3_class(dplyr::as_tibble(x), "tbl_df")
})

test_that("`metric_set()` errors contain env name for unknown functions (#128)", {
  foobar <- function() {}

  # Store env name in `name` attribute for `environmentName()` to find it
  env <- rlang::new_environment(parent = globalenv())
  attr(env, "name") <- "test"
  rlang::fn_env(foobar) <- env

  # A single snapshot is enough here: the test previously ran this exact
  # expectation twice in a row.
  expect_snapshot(
    error = TRUE,
    metric_set(accuracy, foobar, sens, rlang::abort)
  )
})

test_that("`metric_set()` gives an informative error for a single non-metric function (#181)", {
  foobar <- function() {}

  # Store env name in `name` attribute for `environmentName()` to find it
  env <- rlang::new_environment(parent = globalenv())
  attr(env, "name") <- "test"
  rlang::fn_env(foobar) <- env

  expect_snapshot(
    error = TRUE,
    metric_set(foobar)
  )
})

test_that("errors informatively for unevaluated metric factories", {
  # one bad metric
  expect_snapshot(
    error = TRUE,
    metric_set(demographic_parity)
  )

  expect_snapshot(
    error = TRUE,
    metric_set(demographic_parity, roc_auc)
  )

  # two bad metrics
  expect_snapshot(
    error = TRUE,
    metric_set(demographic_parity, equal_opportunity)
  )

  expect_snapshot(
    error = TRUE,
    metric_set(demographic_parity, equal_opportunity, roc_auc)
  )
})

test_that("propagates 'caused by' error message when specifying the wrong column name", {
  set <- metric_set(accuracy, kap)

  # There is no `weight` column!
  expect_snapshot(error = TRUE, {
    set(
      two_class_example,
      truth,
      Class1,
      estimate = predicted,
      case_weights = weight
    )
  })
})

test_that("errors informatively when `estimate` is not named for class metrics", {
  set <- metric_set(accuracy, kap)

  expect_snapshot(error = TRUE, {
    set(two_class_example, truth, predicted)
  })
})

test_that("errors informatively when `estimate` is not named for survival metrics", {
  set <- metric_set(concordance_survival)

  expect_snapshot(error = TRUE, {
    set(lung_surv, surv_obj, .pred_time)
  })
})

# Stored metric functions ------------------------------------------------------

test_that("metric set functions retain class/prob metric functions", {
  fns <- attr(metric_set(accuracy, roc_auc), "metrics")

  expect_equal(
    names(fns),
    c("accuracy", "roc_auc")
  )

  expect_equal(
    class(fns[[1]]),
    c("class_metric", "metric", "function")
  )

  expect_equal(
    class(fns[[2]]),
    c("prob_metric", "metric", "function")
  )

  expect_equal(
    vapply(fns, function(fn) attr(fn, "direction"), character(1)),
    c(accuracy = "maximize", roc_auc = "maximize")
  )
})

test_that("metric set functions retain numeric metric functions", {
  fns <- attr(metric_set(mae, rmse), "metrics")

  expect_equal(
    names(fns),
    c("mae", "rmse")
  )

  expect_equal(
    class(fns[[1]]),
    c("numeric_metric", "metric", "function")
  )

  expect_equal(
    class(fns[[2]]),
    c("numeric_metric", "metric", "function")
  )

  expect_equal(
    vapply(fns, function(fn) attr(fn, "direction"), character(1)),
    c(mae = "minimize", rmse = "minimize")
  )
})

# case_weights handling --------------------------------------------------------
#
# Each test mixes a real metric with a mock that drops everything in `...`,
# so the mock's estimate is the unweighted one.

test_that("all class metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  accuracy_no_weights <- function(data, truth, estimate, na_rm = TRUE, ...) {
    # Eat the `...` silently
    accuracy(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm
    )
  }
  accuracy_no_weights <- new_class_metric(accuracy_no_weights, "maximize")

  set <- metric_set(accuracy, accuracy_no_weights)

  df <- data.frame(
    truth = factor(c("x", "x", "y"), levels = c("x", "y")),
    estimate = factor(c("x", "y", "x"), levels = c("x", "y")),
    case_weights = c(1L, 1L, 2L)
  )

  # Only the first row is predicted correctly. The weighted accuracy is
  # therefore 1 / (1 + 1 + 2) = 1/4, and the unweighted mock gives 1/3.
  expect_identical(
    set(df, truth, estimate = estimate, case_weights = case_weights)[[
      ".estimate"
    ]],
    c(1 / 4, 1 / 3)
  )
})

test_that("all numeric metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  rmse_no_weights <- function(data, truth, estimate, na_rm = TRUE, ...) {
    # Eat the `...` silently
    rmse(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm
    )
  }
  rmse_no_weights <- new_numeric_metric(rmse_no_weights, "minimize")

  set <- metric_set(rmse, rmse_no_weights)

  solubility_test$weight <- read_weights_solubility_test()

  expect <- c(
    rmse(solubility_test, solubility, prediction, case_weights = weight)[[
      ".estimate"
    ]],
    rmse(solubility_test, solubility, prediction)[[".estimate"]]
  )

  expect_identical(
    set(solubility_test, solubility, prediction, case_weights = weight)[[
      ".estimate"
    ]],
    expect
  )
})

test_that("class and prob metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  accuracy_no_weights <- function(data, truth, estimate, na_rm = TRUE, ...) {
    # Eat the `...` silently
    accuracy(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm
    )
  }
  accuracy_no_weights <- new_class_metric(accuracy_no_weights, "maximize")

  set <- metric_set(accuracy, accuracy_no_weights, roc_auc)

  two_class_example$weight <- read_weights_two_class_example()

  expect <- c(
    accuracy(two_class_example, truth, predicted, case_weights = weight)[[
      ".estimate"
    ]],
    accuracy(two_class_example, truth, predicted)[[".estimate"]],
    roc_auc(two_class_example, truth, Class1, case_weights = weight)[[
      ".estimate"
    ]]
  )

  expect_identical(
    set(
      two_class_example,
      truth,
      Class1,
      estimate = predicted,
      case_weights = weight
    )[[".estimate"]],
    expect
  )
})