# Tests for the momentum schedule factories. Each factory returns a closure
# that maps an iteration number to a momentum value.
context("Momentum")

test_that("constant momentum returns the same value at every iteration", {
  for (value in c(0.5, 0.25, 0.95)) {
    constant_m <- make_constant(value)
    expect_equal(constant_m(1), value)
    expect_equal(constant_m(500), value)
    expect_equal(constant_m(1000), value)
  }
})

test_that("switch momentum changes from init to final value at switch_iter", {
  step_m <- make_switch(
    init_value = 0.2,
    final_value = 0.6,
    switch_iter = 100
  )
  # Before switch_iter the initial value is used...
  expect_equal(step_m(1), 0.2)
  expect_equal(step_m(99), 0.2)
  # ...and from switch_iter onwards (inclusive) the final value.
  expect_equal(step_m(100), 0.6)
  expect_equal(step_m(101), 0.6)
  expect_equal(step_m(1000), 0.6)
})

test_that("ramp momentum interpolates from init to final value", {
  linear_m <- make_ramp(init_value = 0.1, final_value = 0.8)
  expect_equal(linear_m(1, max_iter = 1000), 0.1)
  # The midpoint iteration is only approximately halfway, hence the tolerance.
  # Spelled out as `tolerance`: `tol` is not a formal of expect_equal() (it
  # would be swallowed by `...` because partial matching does not apply to
  # arguments after `...`).
  expect_equal(linear_m(500, max_iter = 1000), 0.45, tolerance = 1e-3)
  expect_equal(linear_m(1000, max_iter = 1000), 0.8)
})

# The expected Nesterov values below all agree with 1 - 3 / (t + 5), where t is
# the number of iterations since burn-in ended.
test_that("nesterov approximation without burn-in", {
  nest_m <- make_nesterov_convex_approx(burn_in = 0, use_init_mu = FALSE)
  # Without use_init_mu the very first iteration has zero momentum.
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0.5)
  expect_equal(nest_m(5), 0.7)
  expect_equal(nest_m(10), 0.8)
  expect_equal(nest_m(20), 0.88)
  expect_equal(nest_m(50), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(500), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1000), 0.9970, tolerance = 0.0001)

  nest_m <- make_nesterov_convex_approx(burn_in = 0, use_init_mu = TRUE)
  # With use_init_mu the first iteration uses the formula's value at t = 0.
  expect_equal(nest_m(0), 0.4)
  expect_equal(nest_m(1), 0.5)
  expect_equal(nest_m(5), 0.7)
  expect_equal(nest_m(10), 0.8)
  expect_equal(nest_m(20), 0.88)
  expect_equal(nest_m(50), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(500), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1000), 0.9970, tolerance = 0.0001)
})

test_that("nesterov approximation with burn-in shifts the schedule", {
  nest_m <- make_nesterov_convex_approx(burn_in = 1, use_init_mu = FALSE)
  # Momentum is zero during burn-in and at the first post-burn-in iteration.
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0)
  # Afterwards the schedule matches the burn_in = 0 case, offset by one.
  expect_equal(nest_m(2), 0.5)
  expect_equal(nest_m(6), 0.7)
  expect_equal(nest_m(11), 0.8)
  expect_equal(nest_m(21), 0.88)
  expect_equal(nest_m(51), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(501), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1001), 0.9970, tolerance = 0.0001)

  nest_m <- make_nesterov_convex_approx(burn_in = 1, use_init_mu = TRUE)
  # Momentum is zero during burn-in; the first post-burn-in iteration then
  # uses the initial value.
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0.4)
  expect_equal(nest_m(2), 0.5)
  expect_equal(nest_m(6), 0.7)
  expect_equal(nest_m(11), 0.8)
  expect_equal(nest_m(21), 0.88)
  expect_equal(nest_m(51), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(501), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1001), 0.9970, tolerance = 0.0001)
})