make_optimizer_maker = function(optimizer_fn) { function(..., steps = 2) { n <- nn_linear(1, 1) o <- optimizer_fn(n$parameters, ...) x <- torch_randn(10, 1) y <- torch_randn(10, 1) s <- function() { o$zero_grad() loss <- mean((n(x) - y)^2) loss$backward() o$step() } replicate(steps, s()) o } } make_ignite_adamw <- make_optimizer_maker(optim_ignite_adamw) make_ignite_sgd <- make_optimizer_maker(optim_ignite_sgd) make_ignite_adam <- make_optimizer_maker(optim_ignite_adam) make_ignite_rmsprop <- make_optimizer_maker(optim_ignite_rmsprop) make_ignite_adagrad <- make_optimizer_maker(optim_ignite_adagrad) sample_rmsprop_params <- function() { lr <- runif(1, 0.01, 0.02) alpha <- runif(1, 0.98, 0.99) eps <- runif(1, 0.0000001, 0.000002) weight_decay <- if (runif(1) < 0.5) runif(1, 0, 0.001) else 0 momentum <- if (runif(1) < 0.5) runif(1, 0.8, 0.9) else 0 centered <- sample(c(TRUE, FALSE), 1) list(lr = lr, alpha = alpha, eps = eps, weight_decay = weight_decay, momentum = momentum, centered = centered) } sample_adagrad_params <- function() { lr <- runif(1, 0.1, 0.2) weight_decay <- if (runif(1) < 0.5) runif(1, 0, 0.001) else 0 lr_decay <- if (runif(1) < 0.5) runif(1, 0, 0.001) else 0 initial_accumulator_value <- if (runif(1) < 0.5) runif(1, 0, 0.000001) else 0 eps <- runif(1, 0.0000001, 0.000002) list(lr = lr, weight_decay = weight_decay, initial_accumulator_value = initial_accumulator_value, eps = eps) } sample_sgd_params <- function() { lr <- runif(1, 0.01, 0.02) nesterov <- sample(c(TRUE, FALSE), 1) dampening <- if (nesterov) 0 else runif(1, 0, 0.1) weight_decay <- if (runif(1) < 0.5) runif(1, 0, 0.001) else 0 momentum <- if (nesterov || runif(1) < 0.5) runif(1, 0.8, 0.9) else 0 list(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov) } sample_adam_params <- function() { lr <- runif(1, 0.01, 0.02) weight_decay <- if (runif(1) < 0.5) runif(1, 0, 0.001) else 0 betas <- runif(2, 0.9, 0.99) eps <- runif(1, 0.001, 0.002) amsgrad <- sample(c(TRUE, FALSE), 1) list(lr = lr, weight_decay = weight_decay, betas = betas, eps = eps, amsgrad = amsgrad) } sample_adamw_params <- sample_adam_params expect_state_dict_works <- function(optimizer_fn, ...) { f <- function(load = FALSE) { n <- nn_linear(1, 1) o <- optimizer_fn(n$parameters, ...) x <- torch_randn(10, 1) y <- torch_randn(10, 1) n$parameters$bias$requires_grad_(FALSE) s <- function() { o$zero_grad() loss <- mean((n(x) - y)^2) loss$backward() o$step() } replicate(2, s()) if (load) { o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) } replicate(2, s()) return(n$parameters) } w1 <- f(load = TRUE) w2 <- f(load = FALSE) expect_equal(w1, w2) } expect_ignite_can_change_param_groups <- function(optimizer_fn, ...) { n <- nn_linear(1, 1) o <- optimizer_fn(n$parameters, ...) for (nm in names(o$param_groups[-1L])) { if (is.numeric(o$param_groups[[nm]])) { o$param_groups[[nm]] = o$param_groups[[nm]] * 0.1 expect_equal(o$param_groups[[nm]], o$param_groups[[nm]] * 0.1) } else if (is.logical(o$param_groups[[nm]])) { o$param_groups[[nm]] = !o$param_groups[[nm]] expect_equal(o$param_groups[[nm]], !o$param_groups[[nm]]) } else { stop("Unknown type") } } } expect_ignite_can_add_param_group <- function(optimizer_fn, ...) { n <- nn_linear(1, 1) o <- optimizer_fn(n$parameters, lr = 0.1) n1 = nn_linear(1, 1) o$add_param_group(list(params = n1$parameters, lr = 19)) expect_equal(o$param_groups[[1]]$lr, 0.1) expect_equal(o$param_groups[[2]]$lr, 19) }