context("utils-data-dataloader") test_that("dataloader works", { x <- torch_randn(1000, 100) y <- torch_randn(1000, 1) dataset <- tensor_dataset(x, y) dl <- dataloader(dataset = dataset, batch_size = 32) expect_length(dl, 1000 %/% 32 + 1) expect_true(is_dataloader(dl)) iter <- dl$.iter() b <- iter$.next() expect_tensor_shape(b[[1]], c(32, 100)) expect_tensor_shape(b[[2]], c(32, 1)) iter <- dl$.iter() for (i in 1:32) { k <- iter$.next() } expect_equal(iter$.next(), coro::exhausted()) }) test_that("dataloader iteration", { x <- torch_randn(100, 100) y <- torch_randn(100, 1) dataset <- tensor_dataset(x, y) dl <- dataloader(dataset = dataset, batch_size = 32) # iterating with a while loop iter <- dataloader_make_iter(dl) while (!is.null(batch <- dataloader_next(iter))) { expect_tensor(batch[[1]]) expect_tensor(batch[[2]]) } expect_warning(class = "deprecated", { # iterating with an enum for (batch in enumerate(dl)) { expect_tensor(batch[[1]]) expect_tensor(batch[[2]]) } }) }) test_that("can have datasets that don't return tensors", { ds <- dataset( initialize = function() {}, .getitem = function(index) { list( matrix(runif(10), ncol = 10), index, 1:10 ) }, .length = function() { 100 } ) d <- ds() dl <- dataloader(d, batch_size = 32, drop_last = TRUE) # iterating with an enum expect_warning(class = "deprecated", { for (batch in enumerate(dl)) { expect_tensor_shape(batch[[1]], c(32, 1, 10)) expect_true(batch[[1]]$dtype == torch_float()) expect_tensor_shape(batch[[2]], c(32)) expect_tensor_shape(batch[[3]], c(32, 10)) expect_true(batch[[3]]$dtype == torch_long()) } }) expect_true(batch[[1]]$dtype == torch_float32()) expect_true(batch[[2]]$dtype == torch_int64()) expect_true(batch[[3]]$dtype == torch_int64()) }) test_that("dataloader that shuffles", { x <- torch_randn(100, 100) y <- torch_randn(100, 1) d <- tensor_dataset(x, y) dl <- dataloader(dataset = d, batch_size = 50, shuffle = TRUE) expect_warning(class = "deprecated", { for (i in enumerate(dl)) { expect_tensor_shape(i[[1]], c(50, 100)) } }) dl <- dataloader(dataset = d, batch_size = 30, shuffle = TRUE) j <- 0 expect_warning(class = "deprecated", { for (i in enumerate(dl)) { j <- j + 1 if (j == 4) { expect_tensor_shape(i[[1]], c(10, 100)) } else { expect_tensor_shape(i[[1]], c(30, 100)) } } }) }) test_that("named outputs", { ds <- dataset( initialize = function() { }, .getitem = function(i) { list(x = i, y = 2 * i) }, .length = function() { 1000 } )() expect_named(ds[1], c("x", "y")) dl <- dataloader(ds, batch_size = 4) iter <- dataloader_make_iter(dl) expect_named(dataloader_next(iter), c("x", "y")) }) test_that("can use a dataloader with coro", { ds <- dataset( initialize = function() { }, .getitem = function(i) { list(x = i, y = 2 * i) }, .length = function() { 10 } )() expect_named(ds[1], c("x", "y")) dl <- dataloader(ds, batch_size = 5) j <- 1 loop(for (batch in dl) { expect_named(batch, c("x", "y")) expect_tensor_shape(batch$x, 5) expect_tensor_shape(batch$y, 5) }) }) test_that("dataloader works with num_workers", { if (cuda_is_available()) { skip_on_os("windows") } ds <- dataset( .length = function() { 20 }, initialize = function() {}, .getitem = function(id) { list(x = .worker_info$id) } ) dl <- dataloader(ds(), batch_size = 10, num_workers = 2) it <- dataloader_make_iter(dl) i <- 1 expect_warning(class = "deprecated", { for (batch in enumerate(dl)) { expect_equal_to_tensor(batch$x, i * torch_ones(10)) i <- i + 1 } }) }) test_that("dataloader catches errors on workers", { if (cuda_is_available()) { skip_on_os("windows") } ds <- 
  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      stop("the error id is 5567")
      list(x = .worker_info$id)
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)
  iter <- dataloader_make_iter(dl)
  expect_error(
    dataloader_next(iter),
    class = "runtime_error",
    regexp = "5567"
  )
})

test_that("worker init function is respected", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      list(x = theid)
    }
  )

  worker_init_fn <- function(id) {
    theid <<- id * 2
  }

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2,
    worker_init_fn = worker_init_fn
  )

  i <- 1
  expect_warning(class = "deprecated", {
    for (batch in enumerate(dl)) {
      expect_equal_to_tensor(batch$x, i * 2 * torch_ones(10))
      i <- i + 1
    }
  })
})

test_that("dataloader timeout is respected", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      Sys.sleep(10)
      list(x = 1)
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2,
    timeout = 5
  ) # (timeout is in milliseconds)

  iter <- dataloader_make_iter(dl)
  expect_error(
    dataloader_next(iter),
    class = "runtime_error",
    regexp = "timed out"
  )
})

test_that("can return tensors in multiworker dataloaders", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      list(x = torch_scalar_tensor(1))
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)
  expect_warning(class = "deprecated", {
    for (batch in enumerate(dl)) {
      expect_equal_to_tensor(batch$x, torch_ones(10))
    }
  })
})

test_that("can make reproducible runs", {
  if (cuda_is_available()) {
    skip_on_os("windows")
  }

  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      list(x = runif(1), y = torch_randn(1))
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)

  set.seed(1)
  iter <- dataloader_make_iter(dl)
  b1 <- dataloader_next(iter)

  set.seed(1)
  iter <- dataloader_make_iter(dl)
  b2 <- dataloader_next(iter)

  expect_equal(b1$x, b2$x)
  expect_equal_to_tensor(b1$y, b2$y)
})

test_that("load packages in dataloader", {
  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      torch_tensor("coro" %in% (.packages()))
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)
  iter <- dataloader_make_iter(dl)
  b1 <- dataloader_next(iter)
  expect_equal(torch_any(b1)$item(), FALSE)

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2, worker_packages = "coro")
  iter <- dataloader_make_iter(dl)
  b1 <- dataloader_next(iter)
  expect_equal(torch_all(b1)$item(), TRUE)
})

test_that("globals can be found", {
  ds <- dataset(
    .length = function() {
      20
    },
    initialize = function() {},
    .getitem = function(id) {
      hello_fn()
    }
  )

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2)
  iter <- dataloader_make_iter(dl)
  expect_error(
    b1 <- dataloader_next(iter)
  )

  expect_error(
    dl <- dataloader(ds(), batch_size = 10, num_workers = 2,
      worker_globals = c("hello", "world")
    ),
    class = "runtime_error"
  )

  hello_fn <- function() {
    torch_randn(5, 5)
  }

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2,
    worker_globals = list(hello_fn = hello_fn)
  )

  iter <- dataloader_make_iter(dl)
  expect_tensor_shape(dataloader_next(iter), c(10, 5, 5))

  dl <- dataloader(ds(), batch_size = 10, num_workers = 2,
    worker_globals = "hello_fn"
  )

  iter <- dataloader_make_iter(dl)
  expect_tensor_shape(dataloader_next(iter), c(10, 5, 5))
})

test_that("datasets can use an optional .getbatch method for speedups", {
  d <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      list(
        torch_randn(length(indexes), 10),
        torch_randn(length(indexes), 1)
      )
    },
    .length = function() {
      100
    }
  )

  dl <- dataloader(d(), batch_size = 10)
  coro::loop(for (x in dl) {
    expect_length(x, 2)
    expect_tensor_shape(x[[1]], c(10, 10))
    expect_tensor_shape(x[[2]], c(10, 1))
  })
})

test_that("dataloaders handle .getbatch methods that don't necessarily return a torch_tensor", {
  d <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      list(
        array(0, dim = c(length(indexes), 10)),
        array(0, dim = c(length(indexes), 1))
      )
    },
    .length = function() {
      100
    }
  )

  dl <- dataloader(d(), batch_size = 10)
  coro::loop(for (x in dl) {
    expect_length(x, 2)
    expect_tensor_shape(x[[1]], c(10, 10))
    expect_tensor_shape(x[[2]], c(10, 1))
  })
})

test_that("a value error is returned when it's not possible to convert", {
  d <- dataset(
    initialize = function() {},
    .getbatch = function(indexes) {
      "a"
    },
    .length = function() {
      100
    }
  )

  expect_error(
    dataloader_next(dataloader_make_iter(dataloader(d(), batch_size = 10))),
    regexp = "Can't convert data of class.*",
    class = "value_error"
  )
})

test_that("warning tensor", {
  dt <- dataset(
    initialize = function() {
      self$x <- torch_randn(100, 100)
      private$k <- torch_randn(10, 10)
      self$z <- list(
        k = torch_tensor(1),
        torch_tensor(2)
      )
    },
    .getitem = function(i) {
      torch_randn(1, 1)
    },
    .length = function() {
      100
    },
    active = list(
      y = function() {
        torch_randn(1)
      }
    ),
    private = list(
      k = 1
    )
  )

  dt <- dt()
  expect_warning(
    x <- dataloader(dt, batch_size = 2, num_workers = 10),
    regexp = "parallel dataloader"
  )
})

test_that("collate works with bool data", {
  data <- replicate(10, torch_randn(5))

  example_ds <- dataset(
    "example_dataset",
    initialize = function(numbers) {
      self$numbers <- numbers
    },
    .getitem = function(i) {
      x <- self$numbers[[i]]
      list(x = x, n = as.array(torch_mean(x) > 0))
    },
    .length = function() length(self$numbers)
  )

  example_ds_inst <- example_ds(numbers = data)
  expect_true(is.logical(example_ds_inst[1]$n))

  example_dl <- dataloader(
    example_ds_inst,
    batch_size = 2
  )

  out <- coro::collect(example_dl, 1)[[1]]$n
  expect_true(out$dtype == torch_bool())
})

test_that("can use dataloaders on iterable datasets", {
  ids <- iterable_dataset(
    "ids",
    initialize = function(n = 320) {
      self$n <- n
    },
    .iter = function() {
      i <- 0
      function() {
        i <<- i + 1
        if (i <= self$n) {
          i
        } else {
          coro::exhausted()
        }
      }
    }
  )

  dl <- dataloader(ids(), batch_size = 32)
  data <- coro::collect(dl)
  expect_equal(length(data), 10)
  expect_equal(data[[10]]$shape, 32)

  dl <- dataloader(ids(33), batch_size = 32)
  data <- coro::collect(dl)
  expect_equal(length(data), 2)
  expect_equal(data[[2]]$shape, 1)

  dl <- dataloader(ids(33), batch_size = 32, drop_last = TRUE)
  data <- coro::collect(dl)
  expect_equal(length(data), 1)

  # length can be NA for iterable datasets
  expect_true(is.na(length(dl)))
})

test_that("correctly reports length for iterable datasets that provide length", {
  ids <- iterable_dataset(
    "ids",
    initialize = function(n = 320) {
      self$n <- n
    },
    .iter = function() {
      i <- 0
      function() {
        i <<- i + 1
        if (i <= self$n) {
          i
        } else {
          coro::exhausted()
        }
      }
    },
    .length = function() {
      self$n
    }
  )

  expect_equal(length(ids()), 320)

  dl <- dataloader(ids(), batch_size = 32)
  expect_equal(length(dl), 10)

  dl <- dataloader(ids(33), batch_size = 32)
  expect_equal(length(dl), 2)

  dl <- dataloader(ids(33), batch_size = 32, drop_last = TRUE)
  expect_equal(length(dl), 1)
})

test_that("a case that errors in luz", {
  get_iterable_ds <- iterable_dataset(
    "iterable_ds",
    initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) {
      self$len <- len
      self$x <- torch::torch_randn(size = c(len, x_size))
      self$y <- torch::torch_randn(size = c(len, y_size))
    },
    .iter = function() {
      i <- 0
      function() {
        i <<- i + 1
        if (i > self$len) {
          return(coro::exhausted())
        }
        list(
          x = self$x[i, ..],
          y = self$y[i, ..]
        )
      }
    }
  )

  ds <- get_iterable_ds()
  dl <- dataloader(ds, batch_size = 32)
  expect_equal(length(coro::collect(dl)), 4)
})