context("autograd")

# Smallest possible backward pass: y = 2x, so dy/dx = 2.
test_that("can autograd", {
  leaf <- torch_tensor(c(1), requires_grad = TRUE)
  doubled <- 2 * leaf
  expect_invisible(doubled$backward())
  expect_equal_to_r(leaf$grad, 2)
})

# An enable-grad scope nested inside a no-grad scope must still record
# the graph; x and y are created in the test environment by the blocks.
test_that("can autograd with contexts", {
  with_no_grad({
    with_enable_grad({
      x <- torch_tensor(c(1), requires_grad = TRUE)
      y <- 2 * x
    })
  })
  expect_invisible(y$backward())
  expect_equal_to_r(x$grad, 2)
})

# The flag can be set at construction time and toggled in place afterwards.
test_that("requires_grad works", {
  tracked <- torch_tensor(c(1), requires_grad = TRUE)
  expect_true(tracked$requires_grad)

  untracked <- torch_tensor(c(1), requires_grad = FALSE)
  expect_false(untracked$requires_grad)

  toggled <- torch_tensor(c(1), requires_grad = FALSE)
  toggled$requires_grad_(TRUE)
  expect_true(toggled$requires_grad)
  toggled$requires_grad_(FALSE)
  expect_false(toggled$requires_grad)
})

test_that("register_hook", {
  # A hook that only prints: it runs during backward and the gradient
  # checked afterwards is still 2.
  x <- torch_tensor(c(2), requires_grad = TRUE)
  x$register_hook(function(grad) {
    print("hello")
  })
  y <- 2 * x
  expect_output(y$backward(), "hello")
  expect_equal_to_r(x$grad, 2)

  # The hook receives the gradient itself, which prints as a torch_tensor.
  x <- torch_tensor(c(2), requires_grad = TRUE)
  x$register_hook(function(grad) {
    print(grad)
  })
  y <- 2 * x
  expect_output(y$backward(), "torch_tensor")

  # Hooks on the output and on the leaf: the output's hook prints first.
  x <- torch_tensor(c(2), requires_grad = TRUE)
  x$register_hook(function(grad) {
    print("ABABA")
  })
  y <- 2 * x
  y$register_hook(function(grad) {
    print("EBEBE")
  })
  expect_output(y$backward(), "EBEBE.*ABABA")
})

test_that("register hook: can throw exceptions in the lantern thread", {
  x <- torch_tensor(c(2), requires_grad = TRUE)
  # The hook returns a tensor (twice the incoming gradient); the stored
  # gradient is then 4 rather than 2.
  x$register_hook(function(grad) {
    2 * grad
  })
  y <- 2 * x
  y$backward()
  expect_equal_to_r(x$grad, 4)
  # A second backward over the same graph must surface as an R error.
  expect_error(y$backward())
})

test_that("register hook: can throw exceptions in the hook", {
  x <- torch_tensor(c(2), requires_grad = TRUE)
  x$register_hook(function(grad) {
    stop()
  })
  y <- 2 * x
  # The error raised inside the R hook has to propagate out of backward().
  expect_error(y$backward())
})

test_that("register_hook: grad non leaf", {
  # see https://github.com/pytorch/pytorch/blob/e0ee8000ac68ae58580ca62a59d5f40a9dd8710c/test/test_autograd.py#L400
  # Edge case for register_hook: capture the gradient of one non-leaf
  # tensor without crashing during the backward pass of the other
  # non-leaf tensors produced by the same unbind().
  x <- torch_randn(5, requires_grad = TRUE)
  pieces <- x$unbind()
  first_piece <- pieces[[1]]

  captured_grad <- NULL
  capture <- function(grad) {
    captured_grad <<- grad
  }
  first_piece$register_hook(capture)

  pieces[[1]]$backward()
  expect_equal_to_r(captured_grad, 1)
  expect_equal_to_r(x$grad, c(1, 0, 0, 0, 0))
})

test_that("register_hook: can call a the hook inside a hook", {
  # `calls` records the order in which hooks fire.
  calls <- NULL

  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x
  x$register_hook(function(grad) calls <<- c(calls, "x"))

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 2 * a
  # The hook on `a` starts a backward pass on the unrelated graph of `y`.
  a$register_hook(function(grad) {
    calls <<- c(calls, "a")
    y$backward()
  })

  a$backward()
  expect_equal(calls, c("a", "x"))

  # add one more level of nesting
  calls <- NULL

  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x
  x$register_hook(function(grad) calls <<- c(calls, "x"))

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 2 * a
  a$register_hook(function(grad) {
    calls <<- c(calls, "a")
    y$backward()
  })

  k <- torch_tensor(1, requires_grad = TRUE)
  l <- 2 * k
  k$register_hook(function(grad) {
    calls <<- c(calls, "k")
    a$backward()
  })

  l$backward()
  expect_equal(calls, c("k", "a", "x"))
})

test_that("register_hook: can have 2 hooks that call backwards on different graphs", {
  # `calls` records the order in which hooks fire.
  calls <- NULL

  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x
  x$register_hook(function(grad) calls <<- c(calls, "x"))

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 2 * a
  a$register_hook(function(grad) {
    calls <<- c(calls, "a")
  })

  k <- torch_tensor(1, requires_grad = TRUE)
  l <- 2 * k
  # Both hooks on the k -> l graph trigger backward on a different graph.
  k$register_hook(function(grad) {
    calls <<- c(calls, "k")
    a$backward()
  })
  l$register_hook(function(grad) {
    calls <<- c(calls, "l")
    y$backward()
  })

  l$backward()
  expect_equal(calls, c("l", "x", "k", "a"))
})

test_that("remove_hook", {
  calls <- NULL
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x
  handle <- x$register_hook(function(grad) calls <<- c(calls, "x"))

  y$backward()
  expect_equal(calls, "x") # the hook fired once

  handle$remove()
  y <- 2 * x
  y$backward()
  expect_equal(calls, "x") # nothing appended: the hook is gone
})
test_that("Simple autograd extension", {
  # Squares its input. The backward deliberately ignores grad_output and
  # always returns 2 * input, which is what the expectations encode.
  custom_pow <- autograd_function(
    forward = function(ctx, var1) {
      ctx$save_for_backward(var1)
      var1^2
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables[[1]]
      expect_tensor(saved)
      list(var1 = 2 * saved)
    }
  )

  x <- torch_tensor(c(3), requires_grad = TRUE)
  out <- custom_pow(x)
  out$backward()
  expect_equal_to_r(out, 9)
  expect_equal_to_r(x$grad, 6)

  # Two applications on the same leaf: each contributes 2 * 3 to the grad.
  x <- torch_tensor(c(3), requires_grad = TRUE)
  out <- custom_pow(x) * custom_pow(x)
  out$backward()
  x$grad
  expect_equal_to_r(out, 81)
  expect_equal_to_r(x$grad, 12)
})

test_that("Autograd extension envolving 2 variables", {
  custom <- autograd_function(
    forward = function(ctx, a, b) {
      ctx$save_for_backward(a, b)
      a^2 + b^3
    },
    backward = function(ctx, grad_output) {
      saved_a <- ctx$saved_variables[[1]]
      saved_b <- ctx$saved_variables[[2]]
      expect_tensor(saved_a)
      expect_tensor(saved_b)
      # Gradients are matched to inputs by name, not by position, so the
      # list is intentionally returned in the "wrong" order.
      list(b = 3 * (saved_b^2), a = 2 * saved_a)
    }
  )

  x <- torch_tensor(c(3), requires_grad = TRUE)
  y <- torch_tensor(c(4), requires_grad = TRUE)
  out <- custom(x, y)
  out$backward()
  expect_equal_to_r(out, (3^2) + (4^3))
  expect_equal_to_r(x$grad, 2 * 3)
  expect_equal_to_r(y$grad, 3 * (4^2))

  x <- torch_tensor(c(3), requires_grad = TRUE)
  y <- torch_tensor(c(4), requires_grad = TRUE)
  # The argument/gradient pairing must also hold when the arguments are
  # passed by name in a different order in the function call.
  out <- custom(b = y, a = x)
  out$backward()
  expect_equal_to_r(out, (3^2) + (4^3))
  expect_equal_to_r(x$grad, 2 * 3)
  expect_equal_to_r(y$grad, 3 * (4^2))
})

test_that("Named values in saved variables", {
  custom_pow <- autograd_function(
    forward = function(ctx, var1) {
      # Saved values are named here and read back by name in backward.
      ctx$save_for_backward(var1 = var1, var2 = 2 * var1)
      var1^2
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables
      expect_tensor(saved$var1)
      expect_tensor(saved$var2)
      list(var1 = saved$var2)
    }
  )

  x <- torch_tensor(c(3), requires_grad = TRUE)
  out <- custom_pow(x)
  out$backward()
  expect_equal_to_r(out, 9)
  expect_equal_to_r(x$grad, 6)
})

test_that("Can have optional arguments in forward", {
  # Linear layer whose bias argument may be left at its NULL default.
  linear <- autograd_function(
    forward = function(ctx, input, weight, bias = NULL) {
      ctx$save_for_backward(input = input, weight = weight, bias = bias)
      output <- input$mm(weight$t())
      if (!is.null(bias)) {
        output <- output + bias$unsqueeze(1)$expand_as(output)
      }
      output
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables

      # Every input gets an entry; it stays NULL unless a grad is computed.
      grads <- list(
        input = NULL,
        weight = NULL,
        bias = NULL
      )

      if (ctx$needs_input_grad$input) {
        grads$input <- grad_output$mm(saved$weight)
      }

      if (ctx$needs_input_grad$weight) {
        grads$weight <- grad_output$t()$mm(saved$input)
      }

      if (!is.null(saved$bias) && ctx$needs_input_grad$bias) {
        grads$bias <- grad_output$sum(dim = 1)
      }

      grads
    }
  )

  # Without a bias.
  x <- torch_tensor(matrix(c(1, 1, 1, 1), ncol = 2))
  w <- torch_tensor(matrix(c(2, 3), ncol = 2), requires_grad = TRUE)
  y <- torch_tensor(matrix(c(2, 2), ncol = 1))
  o <- linear(x, w)
  l <- torch_mean((y - o)^2)
  l$backward()
  expect_equal_to_r(w$grad, matrix(c(6, 6), ncol = 2))

  # With a bias that requires grad.
  x <- torch_tensor(matrix(c(1, 1, 1, 1), ncol = 2))
  w <- torch_tensor(matrix(c(2, 3), ncol = 2), requires_grad = TRUE)
  b <- torch_tensor(0, requires_grad = TRUE)
  y <- torch_tensor(matrix(c(2, 2), ncol = 1))
  o <- linear(x, w, b)
  l <- torch_mean((y - o)^2)
  l$backward()
  expect_equal_to_r(w$grad, matrix(c(6, 6), ncol = 2))
  expect_equal_to_r(b$grad, 6)
})

test_that("Catch errors in forward and backward R functions", {
  # An error raised in forward must reach the caller with its message.
  custom_pow <- autograd_function(
    forward = function(ctx, var1) {
      stop("stop forward")
      ctx$save_for_backward(var1)
      var1^2
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables[[1]]
      expect_tensor(saved)
      list(var1 = 2 * saved)
    }
  )

  x <- torch_tensor(1, requires_grad = TRUE)
  expect_error(custom_pow(x), "stop forward")

  # Same for an error raised in backward.
  custom_pow <- autograd_function(
    forward = function(ctx, var1) {
      ctx$save_for_backward(var1)
      var1^2
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables[[1]]
      expect_tensor(saved)
      stop("stop backward")
      list(var1 = 2 * saved)
    }
  )

  r <- custom_pow(x)
  expect_error(r$backward(), "stop backward")
})

test_that("Can pass constants to save_for_backward", {
  # The exponent `i` is a plain R number saved next to the tensor.
  custom_pow <- autograd_function(
    forward = function(ctx, var1, i) {
      ctx$save_for_backward(var1 = var1, i = i)
      var1^i
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables
      expect_tensor(saved$var1)
      expect_is(saved$i, "numeric")
      expect_named(saved, c("var1", "i"))
      list(var1 = saved$i * (saved$var1^(saved$i - 1)))
    }
  )

  x <- torch_tensor(2, requires_grad = TRUE)
  r <- custom_pow(x, 2)
  r$backward()
  expect_equal_to_r(r, 4)
  expect_equal_to_r(x$grad, 2 * 2)

  x <- torch_tensor(2, requires_grad = TRUE)
  r <- custom_pow(x, 3)
  r$backward()
  expect_equal_to_r(r, 8)
  expect_equal_to_r(x$grad, 3 * 2^(3 - 1))
})

test_that("Forward can return a list", {
  # Variant 1: backward takes one argument and indexes it as a list of
  # per-output gradients.
  custom_pow <- autograd_function(
    forward = function(ctx, var1, i) {
      ctx$save_for_backward(var1 = var1, i = i)
      list(var1^i, var1^(i + 1))
    },
    backward = function(ctx, grad_output) {
      saved <- ctx$saved_variables
      expect_tensor(grad_output[[1]])
      expect_tensor(grad_output[[2]])
      list(var1 = saved$i * (saved$var1^(saved$i - 1)))
    }
  )

  x <- torch_tensor(2, requires_grad = TRUE)
  r <- custom_pow(x, 2)
  r[[1]]$backward()
  expect_equal_to_r(x$grad, 4)

  # Variant 2: backward declares one argument per forward output.
  custom_pow <- autograd_function(
    forward = function(ctx, var1, i) {
      ctx$save_for_backward(var1 = var1, i = i)
      list(var1^i, var1^(i + 1))
    },
    backward = function(ctx, out1, out2) {
      saved <- ctx$saved_variables
      expect_tensor(out1)
      expect_tensor(out2)
      list(var1 = saved$i * (saved$var1^(saved$i - 1)))
    }
  )

  x <- torch_tensor(2, requires_grad = TRUE)
  r <- custom_pow(x, 2)
  r[[1]]$backward()
  expect_equal_to_r(x$grad, 4)
})

test_that("can use mark_dirty", {
  # https://github.com/pytorch/pytorch/blob/master/test/test_autograd.py#L1936
  inplace <- autograd_function(
    forward = function(ctx, a, b) {
      ctx$mark_dirty(a)
      list(a$add_(b), b + 2)
    },
    backward = function(ctx, ga, gb) {
      list(a = ga, b = ga + gb)
    }
  )

  x <- torch_tensor(2)
  y <- torch_tensor(3, requires_grad = TRUE)
  r <- inplace(x, y)
  # The first output equals the input that was modified in place, and it
  # now requires grad.
  expect_equal_to_tensor(r[[1]], x)
  expect_true(r[[1]]$requires_grad)
  r[[1]]$backward()
  expect_equal_to_r(y$grad, 1)

  # https://github.com/pytorch/pytorch/blob/master/test/test_autograd.py#L1388
  double_in_place <- autograd_function(
    forward = function(ctx, x) {
      x$mul_(2)
      ctx$mark_dirty(x)
      list(x, x)
    },
    backward = function(ctx, g1, g2) {
      list(x = g1 * 2 + g2 * 2)
    }
  )

  # Calling it on a leaf that requires grad is expected to fail.
  x <- torch_tensor(5, requires_grad = TRUE)
  expect_error(double_in_place(x), "leaf Variable that requires grad")
})

test_that("mark_non_differentiable", {
  # https://github.com/pytorch/pytorch/blob/master/test/test_autograd.py#L1309
  myfun <- autograd_function(
    forward = function(ctx, input) {
      output <- input > 0
      ctx$mark_non_differentiable(output)
      output
    },
    backward = function(ctx, g) {
      list(input = g * 0)
    }
  )

  x <- torch_tensor(c(-1, 2), requires_grad = TRUE)
  mask <- myfun(x)
  expect_false(mask$requires_grad)
  y <- x$masked_fill(mask, 0)
  expect_no_error(y$sum()$backward())

  # Two outputs, only the first marked non-differentiable.
  myfun <- autograd_function(
    forward = function(ctx, input) {
      a <- input + 1
      b <- input + 2
      ctx$mark_non_differentiable(a)
      list(a, b)
    },
    backward = function(ctx, ga, gb) {
      expect_equal_to_r(ga, 0)
      expect_equal_to_r(gb, 1)
      list(input = gb)
    }
  )

  x <- torch_tensor(1, requires_grad = TRUE)
  r <- myfun(x)
  expect_false(r[[1]]$requires_grad)
  expect_true(r[[2]]$requires_grad)
  r[[2]]$backward()
  expect_equal_to_r(x$grad, 1)
})

test_that("retain_grad is invisible", {
  leaf <- torch_tensor(1, requires_grad = TRUE)
  expect_invisible(leaf$retain_grad())
})
test_that("grad_fn works", {
  x <- torch_ones(c(2, 2), requires_grad = TRUE)
  y <- x$mean()
  expect_output(print(y$grad_fn), "MeanBackward0")

  # The mean node has a single incoming edge (from the leaf `x`).
  k <- y$grad_fn$next_functions
  expect_length(k, 1)

  # `k` is a plain list of nodes, so the node has to be extracted before
  # asking for its own edges; `k$next_functions` would always be NULL and
  # make the length check pass vacuously. The leaf's node is expected to
  # have no further edges.
  l <- k[[1]]$next_functions
  expect_length(l, 0)

  # A tensor that was not produced by an operation has no grad_fn.
  x <- torch_tensor(1)
  expect_true(is.null(x$grad_fn))
})

test_that("autograd_backward", {
  # Explicit gradient for every root tensor.
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 3 * a

  on <- torch_ones(c(1))

  autograd_backward(list(y, b), list(on, on))

  expect_equal_to_r(x$grad, 2)
  expect_equal_to_r(a$grad, 3)

  # No gradients supplied at all.
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 3 * a

  autograd_backward(list(y, b))

  expect_equal_to_r(x$grad, 2)
  expect_equal_to_r(a$grad, 3)

  # Mixed: NULL for the first root, an explicit gradient for the second.
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x

  a <- torch_tensor(1, requires_grad = TRUE)
  b <- 3 * a

  on <- torch_ones(c(1))

  autograd_backward(list(y, b), list(NULL, on))

  expect_equal_to_r(x$grad, 2)
  expect_equal_to_r(a$grad, 3)
})

test_that("autograd_backward works for single tensors", {
  # A bare tensor (not wrapped in a list) is accepted as the root.
  x <- torch_tensor(1, requires_grad = TRUE)
  y <- 2 * x

  autograd_backward(y)

  expect_equal_to_r(x$grad, 2)

  # Same with a bare tensor as the gradient argument.
  x <- torch_tensor(1, requires_grad = TRUE)
  on <- torch_tensor(1)
  y <- 2 * x

  autograd_backward(y, on)

  expect_equal_to_r(x$grad, 2)
})

test_that("autograd_grad works", {
  x <- torch_randn(c(2, 2), requires_grad = TRUE)
  y <- torch_randn(c(2, 2), requires_grad = TRUE)
  z <- x^2 + y * x + y^2
  # create_graph = TRUE so the gradients can be differentiated again.
  z$backward(torch_ones(c(2, 2)), create_graph = TRUE)

  x_grad <- 2 * x + y
  y_grad <- x + 2 * y
  expect_equal_to_tensor(x$grad, x_grad)
  expect_equal_to_tensor(y$grad, y_grad)

  # grad_sum = 2 * (2x + y) + (x + 2y) = 5x + 4y, so d(grad_sum)/dx = 5.
  grad_sum <- 2 * x$grad + y$grad
  x_hv <- autograd_grad(
    outputs = grad_sum,
    grad_outputs = torch_ones(c(2, 2)),
    inputs = x,
    create_graph = TRUE
  )
  expected_x_hv <- torch_ones(c(2, 2)) * 5

  expect_equal_to_tensor(x_hv[[1]], expected_x_hv)
  # The accumulated grads are checked again after the autograd_grad()
  # call and must still match.
  expect_equal_to_tensor(x$grad, x_grad)
  expect_equal_to_tensor(y$grad, y_grad)
})
test_that("autograd_grad with non-leafs", {
  x_init <- torch_randn(c(2, 2), requires_grad = TRUE)
  x <- x_init
  y <- torch_randn(c(2, 2), requires_grad = TRUE)
  grad_output <- torch_ones(c(2, 2))

  # `y` is taken from the enclosing test environment.
  fn <- function(x) {
    x^2 + y * x + y^2
  }

  # Repeated calls; only the result of the last iteration is inspected.
  for (i in 1:5) {
    o <- autograd_grad(
      fn(x), x,
      grad_outputs = grad_output,
      create_graph = TRUE
    )
  }

  grad_x_expected <- 2 * x + y
  # autograd_grad() returns gradients instead of accumulating them.
  expect_undefined_tensor(x$grad)
  expect_undefined_tensor(y$grad)
  expect_equal_to_tensor(o[[1]], grad_x_expected)

  # One step along the gradient must increase the function value.
  x <- x + 0.05 * o[[1]]
  val_init <- fn(x_init)$sum()
  val_final <- fn(x)$sum()
  expect_true(as_array(val_final > val_init))

  # `x` is now a non-leaf built from x_init and the gradient (2x + y).
  x$backward(grad_output)
  expect_equal_to_tensor(y$grad, (5 / 100) * torch_ones(c(2, 2)))
  expect_equal_to_tensor(x_init$grad, 1.1 * torch_ones(c(2, 2)))
})

test_that("autograd_grad non leaf many outputs", {
  # This checks an edge case for function callbacks:
  # we want to capture two grads of a function, but can only
  # register a single callback.
  x <- torch_randn(c(4, 2), requires_grad = TRUE)
  chunks <- x$chunk(2)
  a <- chunks[[1]]
  b <- chunks[[2]]

  hook_called <- FALSE
  hook <- function(...) {
    hook_called <<- TRUE
  }
  x$register_hook(hook)

  go <- torch_randn(c(2, 2))
  o <- autograd_grad(
    a + 2 * b, list(a, b),
    grad_outputs = go,
    create_graph = TRUE
  )

  expect_equal_to_tensor(o[[1]], go)
  expect_equal_to_tensor(o[[2]], go * 2)

  # The hook on the leaf must not have fired and the leaf has no grad.
  expect_false(hook_called)
  expect_undefined_tensor(x$grad)
})

test_that("autograd_grad retain grad", {
  # Without retain_graph a second call on the same loss must fail.
  w <- torch_tensor(0.5, requires_grad = TRUE)
  b <- torch_tensor(0.9, requires_grad = TRUE)
  x <- torch_tensor(runif(100))
  y <- 2 * x + 1
  loss <- (y - (w * x + b))^2
  loss <- loss$mean()

  out <- autograd_grad(loss, list(w, b))
  expect_length(out, 2)

  expect_error(autograd_grad(loss, list(w, b)), regexp = "graph a second time")

  # With retain_graph = TRUE on the first call the second one succeeds.
  w <- torch_tensor(0.5, requires_grad = TRUE)
  b <- torch_tensor(0.9, requires_grad = TRUE)
  x <- torch_tensor(runif(100))
  y <- 2 * x + 1
  loss <- (y - (w * x + b))^2
  loss <- loss$mean()

  out <- autograd_grad(loss, list(w, b), retain_graph = TRUE)
  expect_length(out, 2)

  out <- autograd_grad(loss, list(w, b))
  expect_length(out, 2)
})

test_that("autograd_grad allow unused", {
  # `b` does not take part in the loss, so asking for its gradient fails
  # by default.
  w <- torch_tensor(0.5, requires_grad = TRUE)
  b <- torch_tensor(0.9, requires_grad = TRUE)
  x <- torch_tensor(runif(100))
  y <- 2 * x + 1
  loss <- (y - (w * x))^2
  loss <- loss$mean()

  expect_error(
    autograd_grad(loss, list(w, b)),
    regexp = "not have been used in the graph"
  )

  # With allow_unused = TRUE the unused input gets an undefined tensor.
  w <- torch_tensor(0.5, requires_grad = TRUE)
  b <- torch_tensor(0.9, requires_grad = TRUE)
  x <- torch_tensor(runif(100))
  y <- 2 * x + 1
  loss <- (y - (w * x))^2
  loss <- loss$mean()

  out <- autograd_grad(loss, list(w, b), allow_unused = TRUE)
  expect_undefined_tensor(out[[2]])
  expect_length(out, 2)
})

test_that("can set requires_grad", {
  flagged <- torch_tensor(1, requires_grad = FALSE)

  flagged$requires_grad_(TRUE)
  expect_true(flagged$requires_grad)

  # Assignment through `$<-`.
  flagged$requires_grad <- FALSE
  expect_false(flagged$requires_grad)

  flagged$requires_grad_(TRUE)
  expect_true(flagged$requires_grad)

  # Assignment through `[[<-`.
  flagged[["requires_grad"]] <- FALSE
  expect_false(flagged$requires_grad)

  flagged$requires_grad_(TRUE)
  expect_true(flagged$requires_grad)
})

test_that("with detect anomaly", {
  x <- torch_randn(2, requires_grad = TRUE)
  y <- torch_randn(1)
  b <- (x^y)$sum()
  # In-place change to `y` after it was used to build `b`.
  y$add_(1)

  # Outside anomaly mode the error message mentions detect_anomaly.
  expect_error(
    {
      b$backward()
    },
    regexp = "detect_anomaly"
  )

  # Inside anomaly mode there is a warning about its use plus the error.
  expect_warning(regexp = "should be enabled only for debugging", {
    expect_error(regexp = "one of the variables needed", {
      with_detect_anomaly({
        b$backward()
      })
    })
  })
})

test_that("try and with_detect_anomaly", {
  x <- torch_randn(2, requires_grad = TRUE)
  y <- torch_randn(1)
  b <- (x^y)$sum()
  y$add_(1)

  # The failure has to come back as an ordinary try-error.
  k <- try(silent = TRUE, {
    b$backward()
    with_detect_anomaly({
      b$backward()
    })
  })
  expect_true(inherits(k, "try-error"))
})

test_that("autograd_grad works with custom autograd fucntions", {
  # Reference gradient from the built-in exp().
  torch_manual_seed(1)
  w1 <- torch_randn(4, 3, 5, 5)$requires_grad_()
  w1$retain_grad()
  m1 <- exp(w1)$mean()
  grad1 <- autograd_grad(m1, w1, torch_ones_like(w1))

  # Hand-written exp with an explicit backward.
  exp2 <- autograd_function(
    forward = function(ctx, i) {
      result <- i$exp()
      ctx$save_for_backward(result = result)
      result
    },
    backward = function(ctx, grad_output) {
      list(i = grad_output * ctx$saved_variables$result)
    }
  )

  # Same seed, so w2 starts from the same values as w1.
  torch_manual_seed(1)
  w2 <- torch_randn(4, 3, 5, 5)$requires_grad_()
  w2$retain_grad()
  m2 <- exp2(w2)$mean()
  grad2 <- autograd_grad(m2, w2, torch_ones_like(w2))

  expect_equal_to_tensor(grad1[[1]], grad2[[1]])
})

test_that("we autograd_grad doesn't segfault when backward fails", {
  exp2 <- autograd_function(
    forward = function(ctx, i) {
      result <- i$exp()
      ctx$save_for_backward(result = result)
      result
    },
    backward = function(ctx, grad_output) {
      stop("error in GRADFOO")
      list(i = grad_output * ctx$saved_variables$result)
    }
  )

  w2 <- torch_randn(4, 3, 5, 5)$requires_grad_()
  w2$retain_grad()
  m2 <- exp2(w2)$mean()

  # The R error from backward must surface with its message intact.
  expect_error(
    grad2 <- autograd_grad(m2, w2, torch_ones_like(w2)),
    regexp = "GRADFOO"
  )
})

test_that("can modify the gradient of a tensor", {
  x <- torch_tensor(2, requires_grad = TRUE)
  y <- x^3
  y$backward()
  expect_equal_to_r(x$grad, 12)

  # The grad field is assignable.
  x$grad <- x$grad/2
  expect_equal_to_r(x$grad, 6)
})

test_that("local grad functions", {
  # Builds a small graph and runs backward on it.
  f <- function() {
    x <- torch_tensor(2, requires_grad = TRUE)
    y <- x^3
    y$backward()
  }

  # Switches the grad mode from inside this function, then runs `f`.
  fun <- function(f, grad_mode) {
    if (grad_mode) {
      local_enable_grad()
    } else {
      local_no_grad()
    }
    f()
  }

  with_no_grad({
    # Locally enabling grad makes backward work inside a no-grad scope.
    expect_error(fun(f, TRUE), regexp = NA)
  })
  # Plain call outside any scope works.
  expect_error(f(), regexp = NA)

  with_enable_grad({
    # Locally disabling grad makes backward fail ...
    expect_error(fun(f, FALSE))
    # ... and a plain call after `fun` has returned works again.
    expect_error(f(), regexp = NA)
  })
})