# Tests for the OT R6 class: object construction and the Sinkhorn routines.
# Tests labelled "tensor" build the object with tensorized = "auto"; tests
# labelled "online" build it with tensorized = "online" and are skipped when
# rkeops is not installed, on CRAN and on CI.

# Building an OT object should be silent, and leaving out the masses should
# give the same object as passing uniform masses explicitly.
testthat::test_that("OT object forms tensor", {
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)

  # giving masses
  testthat::expect_silent(
    ot1 <- causalOT:::OT$new(
      x = x, y = y, a = a, b = b,
      penalty = penalty, cost = NULL, p = 2, debias = TRUE,
      tensorized = "auto", diameter = NULL
    )
  )

  # no masses
  testthat::expect_silent(
    ot2 <- causalOT:::OT$new(
      x = x, y = y,
      penalty = penalty, cost = NULL, p = 2, debias = TRUE,
      tensorized = "auto", diameter = NULL
    )
  )

  testthat::expect_equal(ot1, ot2)
})

# Same construction check for the online backend, plus a pinned value for the
# diameter field of the resulting object.
testthat::test_that("OT object forms online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)

  # giving masses
  causalOT:::rkeops_check()
  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )

  # no masses
  testthat::expect_silent(
    ot2 <- causalOT:::OT$new(
      x = x, y = y,
      penalty = penalty, cost = NULL, p = 2, debias = TRUE,
      tensorized = "online", diameter = NULL
    )
  )

  testthat::expect_equal(ot1, ot2)
  testthat::expect_equal(ot1$diameter, 21.12901, tolerance = 1e-4)
})

# The private sinkhorn_loop should return potentials f_xy and g_yx whose
# mass-weighted sums match the pinned reference values.
testthat::test_that("sinkhorn_loop runs, tensor", {
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "auto", diameter = NULL
  )
  output <- ot1$.__enclos_env__$private$sinkhorn_loop(niter, tol)

  # Masses as tensors matching the dtype/device of the potentials they weight.
  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  loss <- sum(output$f_xy * at) + sum(output$g_yx * bt)
  testthat::expect_equal(loss$item(), 2.786224, tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(), 2.786224, tolerance = 1e-3
  )
  testthat::expect_equal(sum(output$g_yx * bt)$item(), 0, tolerance = 1e-3)

  # compare
  # pot <- causalOT::sinkhorn(x = x*sqrt(0.5), y = y*sqrt(0.5), a = a, b = b, power = 2, blur = 10, scaling = 0.99, debias = FALSE )
  # sum(pot$f * a) + sum(pot$g * b) #2.786224 total, f = 1.390419, g = 1.395806
})

# The online sinkhorn_loop should agree with the tensorized one on the total
# dual loss and on each mass-weighted potential sum.
testthat::test_that("sinkhorn_loop runs, online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-8

  # giving masses
  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  output <- ot$.__enclos_env__$private$sinkhorn_loop(niter, tol)

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "auto", diameter = NULL
  )
  output1 <- ot1$.__enclos_env__$private$sinkhorn_loop(niter, tol)

  # Each weight tensor takes the dtype/device of the potential it multiplies.
  # The original read `output1$g_xy`, a field that does not exist (the
  # potential is `g_yx`), so `dtype` was NULL and the default dtype was used.
  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )
  at1 <- torch::torch_tensor(
    a, dtype = output1$f_xy$dtype, device = output1$f_xy$device
  )
  bt1 <- torch::torch_tensor(
    b, dtype = output1$g_yx$dtype, device = output1$g_yx$device
  )

  loss <- sum(output$f_xy * at) + sum(output$g_yx * bt)
  loss1 <- sum(output1$f_xy * at1) + sum(output1$g_yx * bt1)

  testthat::expect_equal(loss$item(), loss1$item(), tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(),
    sum(output1$f_xy * at1)$item(),
    tolerance = 1e-3
  )
  testthat::expect_equal(
    sum(output$g_yx * bt)$item(),
    sum(output1$g_yx * bt1)$item(),
    tolerance = 1e-3
  )
})

# The private sinkhorn_self_loop should return, for each margin, a potential
# whose doubled mass-weighted sum matches the pinned reference value.
testthat::test_that("sinkhorn_self runs, tensor", {
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "auto", diameter = NULL
  )
  output_x <- ot1$.__enclos_env__$private$sinkhorn_self_loop("x", niter, tol)
  output_y <- ot1$.__enclos_env__$private$sinkhorn_self_loop("y", niter, tol)

  at <- torch::torch_tensor(
    a, dtype = output_x$dtype, device = output_x$device
  )
  loss_x <- sum(output_x * at) * 2
  testthat::expect_equal(loss_x$item(), 2.17266, tolerance = 1e-4)

  bt <- torch::torch_tensor(
    b, dtype = output_y$dtype, device = output_y$device
  )
  loss_y <- sum(output_y * bt) * 2
  testthat::expect_equal(loss_y$item(), 2.99741, tolerance = 1e-4)
})

# The online sinkhorn_self_loop should agree with the tensorized one on both
# margins.
testthat::test_that("sinkhorn_self runs, online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-8

  # giving masses
  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  output_x <- ot$.__enclos_env__$private$sinkhorn_self_loop("x", niter, tol)
  output_y <- ot$.__enclos_env__$private$sinkhorn_self_loop("y", niter, tol)

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "auto", diameter = NULL
  )
  output_x1 <- ot1$.__enclos_env__$private$sinkhorn_self_loop("x", niter, tol)
  output_y1 <- ot1$.__enclos_env__$private$sinkhorn_self_loop("y", niter, tol)

  loss_x <- sum(
    output_x * torch::torch_tensor(
      a, dtype = output_x$dtype, device = output_x$device
    )
  ) * 2
  loss_x1 <- sum(
    output_x1 * torch::torch_tensor(
      a, dtype = output_x1$dtype, device = output_x1$device
    )
  ) * 2
  testthat::expect_equal(loss_x$item(), loss_x1$item(), tolerance = 1e-4)

  loss_y <- sum(
    output_y * torch::torch_tensor(
      b, dtype = output_y$dtype, device = output_y$device
    )
  ) * 2
  loss_y1 <- sum(
    output_y1 * torch::torch_tensor(
      b, dtype = output_y1$dtype, device = output_y1$device
    )
  ) * 2
  testthat::expect_equal(loss_y$item(), loss_y1$item(), tolerance = 1e-4)
})

# sinkhorn_cot should fill the potentials field. Checked without debiasing
# first, then with debiasing, where the cross potentials are expected to give
# the same loss and f_xx is expected to be non-trivial.
testthat::test_that("sinkhorn_cot runs, tensor", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "auto", diameter = NULL
  )
  ot$sinkhorn_cot(niter, tol)
  output <- ot$potentials

  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  loss <- sum(output$f_xy * at) + sum(output$g_yx * bt)
  testthat::expect_equal(loss$item(), 2.786224, tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(), 2.786224, tolerance = 1e-3
  )
  testthat::expect_equal(sum(output$g_yx * bt)$item(), 0, tolerance = 1e-3)
  # as.logical() added so f_xx is checked the same way as g_yy below.
  testthat::expect_true(all(as.logical(output$f_xx == 0)))
  testthat::expect_true(all(as.logical(output$g_yy == 0)))

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "auto", diameter = NULL
  )
  output1 <- ot1$sinkhorn_cot(niter, tol)$potentials

  # Uses the debiased potentials (output1). The original recomputed the loss
  # from `output`, so the comparison below could never fail.
  loss1 <- sum(output1$f_xy * at) + sum(output1$g_yx * bt)
  testthat::expect_equal(loss1$item(), loss$item(), tolerance = 1e-4)
  testthat::expect_equal(
    sum(output1$f_xx * at)$item() * 2, 2.17266, tolerance = 1e-4
  )
  testthat::expect_true(
    all(as.logical((output1$g_yy == 0)$to(device = "cpu")))
  )
})

# Same sinkhorn_cot checks as the tensor test, on the online backend.
testthat::test_that("sinkhorn_cot runs, online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "online", diameter = NULL
  )
  output <- ot$sinkhorn_cot(niter, tol)$potentials

  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  loss <- sum(output$f_xy * at) + sum(output$g_yx * bt)
  testthat::expect_equal(loss$item(), 2.786224, tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(), 2.786224, tolerance = 1e-3
  )
  testthat::expect_equal(sum(output$g_yx * bt)$item(), 0, tolerance = 1e-3)
  testthat::expect_true(all(as.logical(output$f_xx == 0)))
  testthat::expect_true(all(as.logical(output$g_yy == 0)))

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  output1 <- ot1$sinkhorn_cot(niter, tol)$potentials

  # Uses the debiased potentials (output1) and compares scalars via $item(),
  # as the tensor variant of this test does. The original recomputed the loss
  # from `output` and compared the tensor objects themselves.
  loss1 <- sum(output1$f_xy * at) + sum(output1$g_yx * bt)
  testthat::expect_equal(loss1$item(), loss$item(), tolerance = 1e-4)
  testthat::expect_equal(
    sum(output1$f_xx * at)$item() * 2, 2.17266, tolerance = 1e-4
  )
  testthat::expect_true(
    all(as.logical((output1$g_yy == 0)$to(device = "cpu")))
  )
})

# sinkhorn_dist should error before the potentials are optimised; after
# sinkhorn_opt it should equal the dual objective built from the potentials.
testthat::test_that("sinkhorn_dist runs, tensor", {
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "auto", diameter = NULL
  )

  # No potentials have been computed yet.
  testthat::expect_error(causalOT:::sinkhorn_dist(ot1))

  ot1$sinkhorn_opt(niter, tol)
  loss <- causalOT:::sinkhorn_dist(ot1)
  output <- ot1$potentials

  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  testthat::expect_equal(loss$item(), 2.786224, tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(), 2.786224, tolerance = 1e-3
  )
  testthat::expect_equal(sum(output$g_yx * bt)$item(), 0, tolerance = 1e-3)
  # `tolerance` is spelled out in full: it sits after `...` in expect_equal(),
  # so the abbreviated `tol = 1e-4` was never matched to it.
  testthat::expect_equal(
    loss$item(),
    sum(output$f_xy * at)$item() + sum(output$g_yx * bt)$item(),
    tolerance = 1e-4
  )

  # check primal
  log_plan <- (
    output$f_xy$view(c(n, 1)) - ot1$C_xy$data + output$g_yx$view(c(1, ot1$m))
  ) / ot1$penalty +
    ot1$.__enclos_env__$private$a_log$view(c(n, 1)) +
    ot1$.__enclos_env__$private$b_log$view(c(1, m))
  primal_loss <- (exp(log_plan) * ot1$C_xy$data)$sum()
  testthat::expect_true((primal_loss < loss)$item())

  # compare
  # pot <- causalOT::sinkhorn(x = x*sqrt(0.5), y = y*sqrt(0.5), a = a, b = b, power = 2, blur = 10, scaling = 0.99, debias = FALSE )
  # sum(pot$f * a) + sum(pot$g * b) #2.786224 total, f = 1.390419, g = 1.395806
})

# The online sinkhorn_dist should agree with the tensorized one on the
# distance and on each mass-weighted potential sum.
testthat::test_that("sinkhorn_dist runs, online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-8

  # giving masses
  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  output <- ot$sinkhorn_opt(niter, tol)$potentials
  loss <- causalOT:::sinkhorn_dist(ot)

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "auto", diameter = NULL
  )
  output1 <- ot1$sinkhorn_opt(niter, tol)$potentials
  loss1 <- causalOT:::sinkhorn_dist(ot1)

  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  testthat::expect_equal(loss$item(), loss1$item(), tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(),
    sum(output1$f_xy * at)$item(),
    tolerance = 1e-3
  )
  testthat::expect_equal(
    sum(output$g_yx * bt)$item(),
    sum(output1$g_yx * bt)$item(),
    tolerance = 1e-3
  )
})

# NOTE(review): this test repeats the earlier "sinkhorn_loop runs, tensor"
# test verbatim (same name, same inputs, same expectations). It is kept so the
# suite's coverage is unchanged; consider removing or repurposing it.
testthat::test_that("sinkhorn_loop runs, tensor", {
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-10

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "auto", diameter = NULL
  )
  output <- ot1$.__enclos_env__$private$sinkhorn_loop(niter, tol)

  at <- torch::torch_tensor(
    a, dtype = output$f_xy$dtype, device = output$f_xy$device
  )
  bt <- torch::torch_tensor(
    b, dtype = output$g_yx$dtype, device = output$g_yx$device
  )

  loss <- sum(output$f_xy * at) + sum(output$g_yx * bt)
  testthat::expect_equal(loss$item(), 2.786224, tolerance = 1e-4)
  testthat::expect_equal(
    sum(output$f_xy * at)$item(), 2.786224, tolerance = 1e-3
  )
  testthat::expect_equal(sum(output$g_yx * bt)$item(), 0, tolerance = 1e-3)

  # compare
  # pot <- causalOT::sinkhorn(x = x*sqrt(0.5), y = y*sqrt(0.5), a = a, b = b, power = 2, blur = 10, scaling = 0.99, debias = FALSE )
  # sum(pot$f * a) + sum(pot$g * b) #2.786224 total, f = 1.390419, g = 1.395806
})
# Gradient-tracking tests: when x and y are torch tensors created with
# requires_grad = TRUE, the potentials returned by the private Sinkhorn
# routines should themselves require gradients on both backends.

# Cross potentials from sinkhorn_loop.
testthat::test_that("sinkhorn_loop gradient", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- torch::torch_tensor(
    matrix(stats::rnorm(n * d), n, d),
    dtype = torch::torch_double(), requires_grad = TRUE
  )
  y <- torch::torch_tensor(
    matrix(stats::rnorm(m * d), m, d),
    dtype = torch::torch_double(), requires_grad = TRUE
  )
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-8

  # giving masses
  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  output <- ot$.__enclos_env__$private$sinkhorn_loop(niter, tol)
  testthat::expect_true(output$f_xy$requires_grad)
  testthat::expect_true(output$g_yx$requires_grad)

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "tensorized", diameter = NULL
  )
  output1 <- ot1$.__enclos_env__$private$sinkhorn_loop(niter, tol)
  testthat::expect_true(output1$f_xy$requires_grad)
  testthat::expect_true(output1$g_yx$requires_grad)
})

# Self potentials from sinkhorn_self_loop. This test was previously also
# labelled "sinkhorn_loop gradient", which made failures indistinguishable
# from the test above; the description now names the routine it exercises.
testthat::test_that("sinkhorn_self_loop gradient", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- torch::torch_tensor(
    matrix(stats::rnorm(n * d), n, d),
    dtype = torch::torch_double(), requires_grad = TRUE
  )
  y <- torch::torch_tensor(
    matrix(stats::rnorm(m * d), m, d),
    dtype = torch::torch_double(), requires_grad = TRUE
  )
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  niter <- 1000
  tol <- 1e-8

  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "tensorized", diameter = NULL
  )
  output1x <- ot1$.__enclos_env__$private$sinkhorn_self_loop(
    which.margin = "x", niter, tol
  )
  output1y <- ot1$.__enclos_env__$private$sinkhorn_self_loop(
    which.margin = "y", niter, tol
  )
  testthat::expect_true(output1x$requires_grad)
  testthat::expect_true(output1y$requires_grad)

  # giving masses
  causalOT:::rkeops_check()
  ot <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )
  outputx <- ot$.__enclos_env__$private$sinkhorn_self_loop(
    which.margin = "x", niter, tol
  )
  outputy <- ot$.__enclos_env__$private$sinkhorn_self_loop(
    which.margin = "y", niter, tol
  )
  testthat::expect_true(outputx$requires_grad)
  testthat::expect_true(outputy$requires_grad)
})

# With penalty = Inf on the online backend: the debiased infinite-penalty
# Sinkhorn distance should equal the energy distance, the undebiased result
# should share the dtype of the masses, and backward() should run silently
# once the masses require gradients.
testthat::test_that("OT infinite penalty distances online", {
  testthat::skip_on_cran()
  testthat::skip_if_not_installed(pkg = "rkeops")
  testthat::skip_on_ci()
  causalOT:::torch_check()

  set.seed(1231)
  n <- 15
  m <- 13
  d <- 3
  penalty <- Inf
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)

  # giving masses
  causalOT:::rkeops_check()
  ot1 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = TRUE,
    tensorized = "online", diameter = NULL
  )

  testthat::expect_silent(loss1 <- causalOT:::energy_dist(ot1))
  testthat::expect_silent(loss2 <- causalOT:::inf_sinkhorn_dist(ot1))
  testthat::expect_equal(loss1, loss2)

  causalOT:::rkeops_check()
  ot2 <- causalOT:::OT$new(
    x = x, y = y, a = a, b = b,
    penalty = penalty, cost = NULL, p = 2, debias = FALSE,
    tensorized = "online", diameter = NULL
  )
  testthat::expect_silent(loss3 <- causalOT:::inf_sinkhorn_dist(ot2))
  testthat::expect_true(ot2$a$dtype == loss3$dtype)

  # debugonce(causalOT:::inf_sinkhorn_dist)
  # Replace the masses with a gradient-tracking copy so backward() has a leaf
  # to differentiate with respect to.
  ot2$a <- torch::torch_tensor(ot2$a, requires_grad = TRUE)
  loss <- causalOT:::inf_sinkhorn_dist(ot2)
  testthat::expect_silent(loss$backward())
})