library(testthat) library(data.table) test_that("duplicate names is an error", { expect_error({ aum::aum_diffs_binary(c(a=0,b=1,b=0)) }, "if label.vec has names they must be unique, problems: b") }) test_that("non-numeric labels is an error", { expect_error({ aum::aum_diffs_binary(factor(c(-1,1))) }, "label.vec must be numeric vector with length>0 and all elements either 0,1 or -1,1") }) test_that("non-finite label is an error", { expect_error({ aum::aum_diffs_binary(c(-1,1,NA)) }, "label.vec must be numeric vector with length>0 and all elements either 0,1 or -1,1") }) test_that("non-finite label is an error", { expect_error({ aum::aum_diffs_binary(numeric()) }, "label.vec must be numeric vector with length>0 and all elements either 0,1 or -1,1") }) test_that("non-binary label is an error", { expect_error({ aum::aum_diffs_binary(c(-1,1,2)) }, "label.vec must be numeric vector with length>0 and all elements either 0,1 or -1,1") }) exp.df <- data.frame( example=0:1, pred=0, fp_diff=c(1, 0), fn_diff=c(0, -1), row.names=NULL) test_that("binary diffs computed for two un-named labels", { (computed <- aum::aum_diffs_binary(c(0,1))) expect_equal(as.data.frame(computed), exp.df) (computed <- aum::aum_diffs_binary(c(-1,1))) expect_equal(as.data.frame(computed), exp.df) }) exp.df <- data.frame( example=0:1, pred=0, fp_diff=c(1, 0), fn_diff=c(0, -1), row.names=NULL, stringsAsFactors=FALSE) test_that("binary diffs computed for three named labels", { (computed <- aum::aum_diffs_binary(c(a=0,b=1,c=0), c("c","b"))) expect_equal(as.data.frame(computed[order(example)]), exp.df) }) test_that("error for numeric example", { expect_error({ aum::aum_diffs(1, 1, 1, 1) }, "example must be integer vector but has class: numeric") }) test_that("error for non-unique predicted example names", { expect_error({ aum::aum_diffs("ex1", 1, 1, 1, c("a","a","b","b","c")) }, "elements of pred.name.vec must be unique, problems: a, b") }) test_that("error for columns in penalty error", { expect_error({ aum::aum_diffs_penalty(data.frame(example=1L, min.lambda=0, fp=0)) }, "errors.df must have numeric column named fn") expect_error({ aum::aum_diffs_penalty(data.frame(example=1L, min.lambda=0, fn=0)) }, "errors.df must have numeric column named fp") expect_error({ aum::aum_diffs_penalty(data.frame(example=1L, fp=0, fn=0)) }, "errors.df must have numeric column named min.lambda") expect_error({ aum::aum_diffs_penalty(data.frame(example=1, min.lambda=0, fp=0, fn=0)) }, "errors.df must have integer or character column named example") }) test_that("error if min.lambda does not start at 0", { simple.df <- data.frame( example=1L, min.lambda=exp(1:4), fp=c(10,4,4,0), fn=c(0,2,2,10)) expect_error({ aum::aum_diffs_penalty(simple.df, denominator="count") }, "need min.lambda=0 for each example, problems: 1") }) test_that("error if min.lambda repeated", { simple.df <- data.frame( example=1L, min.lambda=c(0, 1, 1, 2), fp=c(10,4,4,0), fn=c(0,2,2,10)) expect_error({ aum::aum_diffs_penalty(simple.df, denominator="count") }, "need only one min.lambda per example, problems with more are (example:min.lambda) 1:1", fixed=TRUE) }) test_that("rate works for one ex", { simple.df <- data.frame( example=1L, min.lambda=c(0, exp(1:3)), fp=c(10,4,4,0), fn=c(0,2,2,10)) (simple.diffs <- aum::aum_diffs_penalty(simple.df, denominator="count")) expect_equal(simple.diffs$pred, c(-3, -1)) expect_equal(simple.diffs$fp_diff, c(4, 6)) expect_equal(simple.diffs$fn_diff, c(-8, -2)) (simple.rates <- aum::aum_diffs_penalty(simple.df, denominator="rate")) expect_equal(simple.rates$pred, c(-3, -1)) expect_equal(simple.rates$fp_diff, c(0.4, 0.6)) expect_equal(simple.rates$fn_diff, c(-0.8, -0.2)) }) test_that("rate works for three ex, one with no diffs", { four.dt <- rbind( data.table( example="one", min.lambda=c(0, exp(1:3)), fp=c(10,4,4,0), fn=c(0,2,2,8)), data.table( example="two", min.lambda=c(0, exp(4:6)), fp=c(1,0,0,0), fn=c(0,0,0,2)), data.table( example="three", min.lambda=c(0, exp(44:46)), fp=c(100,0,0,0), fn=c(0,0,0,100)), data.table( example="constantFN", min.lambda=c(0,exp(9)), fp=c(11,2), fn=c(1,1))) three.ids <- c("one","constantFN","two") (count.diffs <- aum::aum_diffs_penalty(four.dt, three.ids, denominator="count")) expected.counts <- data.frame( example=c(0,0,1,2,2), pred=c(-3,-1,-9,-6,-4), fp_diff=c(4,6,9,0,1), fn_diff=c(-6,-2,0,-2,0)) expect_equal(data.frame(count.diffs), expected.counts) (rate.diffs <- aum::aum_diffs_penalty(four.dt, three.ids, denominator="rate")) (expected.rates <- with(expected.counts, data.frame( example, pred, fp_diff=fp_diff/sum(fp_diff), fn_diff=-fn_diff/sum(fn_diff)))) expect_equal(data.frame(rate.diffs), expected.rates) }) test_that("aum_errors works even if input not sorted", { diff.df <- data.frame( example=as.integer(c(0, 0, 1)), pred=c(2, 1, 3), fp_diff=c(1, 0, 1), fn_diff=c(0, -1, 0)) computed.dt <- aum::aum_errors(diff.df) exp.dt <- data.table( example=as.integer(c(0,0,0,1,1)), min.pred=c(-Inf,1,2,-Inf,3), max.pred=c(1,2,Inf,3,Inf), fp=c(0,0,1,0,1), fn=c(1,0,0,0,0)) expect_equal(computed.dt, exp.dt) })