test_that("max with indices", {
  x <- torch_tensor(c(5, 6, 7, 8))
  # With `dim`, torch_max() returns list(values, indices); indices are
  # 1-based, so the maximum (8) is reported at position 4.
  m <- torch_max(x, dim = 1)
  expect_equal_to_r(m[[1]], 8)
  expect_equal_to_r(m[[2]]$to(dtype = torch_int()), 4)

  # With `other`, the result is the elementwise maximum of the two inputs
  # (plain R vectors are accepted here).
  expect_equal_to_r(
    torch_max(c(2, 1), other = c(1, 2)),
    c(2, 2)
  )
})

test_that("min with indices", {
  x <- torch_tensor(c(5, 6, 7, 8))
  # Same list(values, indices) contract as torch_max(); the minimum (5)
  # sits at 1-based position 1.
  m <- torch_min(x, dim = 1)
  expect_equal_to_r(m[[1]], 5)
  expect_equal_to_r(m[[2]]$to(dtype = torch_int()), 1)

  # With `other`, the result is the elementwise minimum of the two inputs.
  expect_equal_to_r(
    torch_min(c(2, 1), other = c(1, 2)),
    c(1, 1)
  )
})

test_that("argsort", {
  x <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argsort(x), c(3, 2, 1))
  expect_equal_to_r(x$argsort(), c(3, 2, 1))

  x <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(torch_argsort(x, descending = TRUE), c(3, 2, 1))
  expect_equal_to_r(x$argsort(descending = TRUE), c(3, 2, 1))

  # Every row and every column of this 5 x 2 matrix is ascending, so the
  # sort permutation along either dimension is the identity.
  x <- torch_tensor(1:10)$view(c(5, 2))
  expect_equal_to_r(torch_argsort(x, dim = 1)[, 1], 1:5)
  expect_equal_to_r(x$argsort(dim = 1)[, 1], 1:5)
  expect_equal_to_r(torch_argsort(x, dim = 2)[, 1], rep(1, 5))
  expect_equal_to_r(x$argsort(dim = 2)[, 1], rep(1, 5))
})

test_that("argmax", {
  x <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(torch_argmax(x), 3)
  expect_equal_to_r(x$argmax(), 3)

  x <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argmax(x), 1)
  expect_equal_to_r(x$argmax(), 1)

  # Values increase along dim 2 in every row, so each row's max is last.
  x <- torch_tensor(1:9)$reshape(c(3, 3))
  expect_equal_to_r(torch_argmax(x, dim = 2), c(3, 3, 3))
  # keepdim = TRUE keeps the reduced dimension with size 1.
  expect_equal(torch_argmax(x, dim = 2, keepdim = TRUE)$shape, c(3, 1))
})

test_that("argmin", {
  x <- torch_tensor(c(1, 2, 3))
  expect_equal_to_r(torch_argmin(x), 1)
  expect_equal_to_r(x$argmin(), 1)

  x <- torch_tensor(c(3, 2, 1))
  expect_equal_to_r(torch_argmin(x), 3)
  expect_equal_to_r(x$argmin(), 3)

  # Values increase along dim 2 in every row, so each row's min is first.
  x <- torch_tensor(1:9)$reshape(c(3, 3))
  expect_equal_to_r(torch_argmin(x, dim = 2), c(1, 1, 1))
  # keepdim = TRUE keeps the reduced dimension with size 1.
  expect_equal(torch_argmin(x, dim = 2, keepdim = TRUE)$shape, c(3, 1))
})

test_that("sort", {
  # sample(100) is a permutation of 1:100, so there are no ties and the
  # expected ordering is unique whichever permutation is drawn.
  x <- torch_tensor(sample(100))
  x_r <- as.integer(x)

  # torch_sort() returns list(values, indices), with 1-based indices.
  expect_equal_to_r(torch_sort(x)[[1]], sort(x_r))
  expect_equal_to_r(torch_sort(x)[[2]], order(x_r))
  expect_equal_to_r(
    torch_sort(x, descending = TRUE)[[2]],
    order(x_r, decreasing = TRUE)
  )
  expect_equal_to_r(x$sort()[[2]], order(x_r))
  expect_equal_to_r(
    x$sort(descending = TRUE)[[2]],
    order(x_r, decreasing = TRUE)
  )
})

test_that("bincount is 1 indexed", {
  # Inputs are 1-based bin labels: the values 1..3 produce exactly 3 bins,
  # with label 1 occurring twice.
  x <- torch_tensor(c(1, 2, 3, 1), dtype = torch_int64())
  out <- torch_bincount(x)
  expect_length(out, 3)
  expect_equal_to_r(out, c(2, 1, 1))

  out <- x$bincount()
  expect_length(out, 3)
  expect_equal_to_r(out, c(2, 1, 1))

  # A 0 has no bin under 1-based labelling and must be rejected. The
  # message is matched literally so its trailing "." is not a wildcard.
  x <- torch_tensor(c(1, 2, 3, 1, 0), dtype = torch_int64())
  expect_error(
    torch_bincount(x),
    regexp = "Indexing starts at 1 but found a 0.",
    fixed = TRUE
  )
})