# Tests for the data samplers: SequentialSampler, RandomSampler and
# BatchSampler (reported length and iteration behaviour).
# NOTE(review): context() is believed to be deprecated in testthat 3rd
# edition; kept here so reporting is unchanged -- confirm before removing.
context("utils-data-sampler")

test_that("sampler's length", {
  x <- torch_randn(1000, 10)
  y <- torch_randn(1000)
  data <- tensor_dataset(x, y)

  # A sequential sampler reports one index per dataset element.
  sampler <- SequentialSampler(data)
  expect_length(sampler, 1000)

  # An explicit num_samples overrides the dataset size.
  sampler <- RandomSampler(data, num_samples = 10)
  expect_length(sampler, 10)

  sampler <- RandomSampler(data)
  expect_length(sampler, 1000)

  # drop_last = TRUE counts only complete batches.
  batch <- BatchSampler(sampler = sampler, batch_size = 32, drop_last = TRUE)
  expect_length(batch, 1000 %/% 32)

  # drop_last = FALSE also counts the trailing partial batch.
  batch <- BatchSampler(sampler = sampler, batch_size = 32, drop_last = FALSE)
  expect_length(batch, 1000 %/% 32 + 1)

  # Batch size divides the sample count exactly: no extra partial batch.
  batch <- BatchSampler(sampler = sampler, batch_size = 100, drop_last = FALSE)
  expect_length(batch, 10)

  batch <- BatchSampler(sampler = sampler, batch_size = 1000, drop_last = FALSE)
  expect_length(batch, 1)

  # Batch size larger than the sample count: a single partial batch is
  # kept when drop_last = FALSE and nothing remains when drop_last = TRUE.
  batch <- BatchSampler(sampler = sampler, batch_size = 1001, drop_last = FALSE)
  expect_length(batch, 1)

  batch <- BatchSampler(sampler = sampler, batch_size = 1001, drop_last = TRUE)
  expect_length(batch, 0)
})

test_that("Random sampler, replacement = TRUE", {
  x <- torch_randn(2, 10)
  y <- torch_randn(2)
  data <- tensor_dataset(x, y)

  sampler <- RandomSampler(data, replacement = TRUE)
  it <- sampler$.iter()

  # seq_len() (rather than 1:n) keeps the loop empty if the length is 0.
  for (i in seq_len(length(sampler))) {
    k <- it()
    # Every drawn index must fall inside the two-element dataset.
    expect_true(k >= 1 && k <= 2)
  }

  # After length(sampler) draws the iterator must signal exhaustion.
  expect_equal(it(), coro::exhausted())
})

test_that("Batch sampler", {
  x <- torch_randn(100, 10)
  # y matches x in its first dimension so the dataset is consistent.
  y <- torch_randn(100)
  data <- tensor_dataset(x, y)

  sampler <- RandomSampler(data, replacement = FALSE)
  batches <- BatchSampler(sampler, 32, TRUE)
  it <- batches$.iter()

  # Three batches of 32 are produced; the iterator is then exhausted
  # because drop_last = TRUE discards the incomplete remainder.
  expect_length(it(), 32)
  expect_length(it(), 32)
  expect_length(it(), 32)
  expect_equal(it(), coro::exhausted())
})