# Fits a brulee MLP to simulated two-class data through both the formula and
# the recipe interfaces, then checks the structure of the predictions and that
# the Brier score clears a loose bound.
# NOTE(review): no `optimize` argument is passed here, so this presumably
# exercises the package default optimizer (LBFGS per the test name) -- confirm.
test_that("basic binomial mlp LBFGS", {
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")

  suppressPackageStartupMessages({
    library(dplyr)
    library(recipes)
  })

  # Simulated data: training set first, then test set, from one seeded stream
  set.seed(585)
  train_data <- modeldata::sim_classification(5000)
  test_data <- modeldata::sim_classification(1000)

  n_levels <- nlevels(train_data$class)
  norm_rec <- step_normalize(
    recipe(class ~ ., data = train_data),
    all_predictors()
  )

  # Formula interface: only required to fit without error
  expect_no_error({
    set.seed(392)
    torch::torch_manual_seed(392)
    formula_fit <- brulee_mlp(
      class ~ .,
      train_data,
      hidden_units = 5,
      epochs = 200,
      learn_rate = 0.1,
      rate_schedule = "cyclic",
      device = "cpu"
    )
  })

  # Recipe interface: this fit is the one used for prediction below
  expect_no_error({
    set.seed(392)
    torch::torch_manual_seed(392)
    recipe_fit <- brulee_mlp(
      norm_rec,
      train_data,
      hidden_units = 5,
      epochs = 200,
      learn_rate = 0.1,
      rate_schedule = "cyclic",
      device = "cpu"
    )
  })

  # Hard class predictions, class probabilities and the observed outcome
  expect_no_error({
    class_pred <- predict(recipe_fit, test_data)
    prob_pred <- predict(recipe_fit, test_data, type = "prob")
    pred_tbl <- select(
      bind_cols(class_pred, prob_pred, test_data),
      starts_with(".pred"),
      class
    )
  })

  # Zero-row prototype that the prediction tibble must match
  empty_class <- factor(character(0), levels = c("class_1", "class_2"))
  expected_empty <- structure(
    list(
      .pred_class = empty_class,
      .pred_class_1 = numeric(0),
      .pred_class_2 = numeric(0),
      class = empty_class
    ),
    row.names = integer(0),
    class = c("tbl_df", "tbl", "data.frame")
  )

  expect_equal(pred_tbl[0, ], expected_empty)
  expect_equal(nrow(pred_tbl), nrow(test_data))

  # Did it learn anything? The Brier score has to fall below (1 - 1/k)^2,
  # where k is the number of outcome levels.
  brier <- yardstick::brier_class(pred_tbl, class, .pred_class_1)
  brier_bound <- (1 - 1 / n_levels)^2
  expect_true(brier$.estimate < brier_bound)
})

# Same checks as the test above, but with the optimizer explicitly set to
# "SGD" together with a batch size, momentum and dropout.
test_that("basic binomial mlp SGD", {
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")

  suppressPackageStartupMessages({
    library(dplyr)
    library(recipes)
  })

  # Simulated data: training set first, then test set, from one seeded stream
  set.seed(585)
  train_data <- modeldata::sim_classification(5000)
  test_data <- modeldata::sim_classification(1000)

  n_levels <- nlevels(train_data$class)
  norm_rec <- step_normalize(
    recipe(class ~ ., data = train_data),
    all_predictors()
  )

  # Formula interface: only required to fit without error
  expect_no_error({
    set.seed(392)
    torch::torch_manual_seed(392)
    formula_fit <- brulee_mlp(
      class ~ .,
      train_data,
      optimize = "SGD",
      batch_size = 64L,
      momentum = 0.5,
      learn_rate = 0.1,
      epochs = 200,
      hidden_units = 5,
      dropout = 0.1,
      penalty = 0,
      device = "cpu"
    )
  })

  # Recipe interface: this fit is the one used for prediction below
  expect_no_error({
    set.seed(392)
    torch::torch_manual_seed(392)
    recipe_fit <- brulee_mlp(
      norm_rec,
      train_data,
      optimize = "SGD",
      batch_size = 64L,
      momentum = 0.5,
      learn_rate = 0.1,
      epochs = 200,
      hidden_units = 5,
      dropout = 0.1,
      penalty = 0,
      device = "cpu"
    )
  })

  # Hard class predictions, class probabilities and the observed outcome
  expect_no_error({
    class_pred <- predict(recipe_fit, test_data)
    prob_pred <- predict(recipe_fit, test_data, type = "prob")
    pred_tbl <- select(
      bind_cols(class_pred, prob_pred, test_data),
      starts_with(".pred"),
      class
    )
  })

  # Zero-row prototype that the prediction tibble must match
  empty_class <- factor(character(0), levels = c("class_1", "class_2"))
  expected_empty <- structure(
    list(
      .pred_class = empty_class,
      .pred_class_1 = numeric(0),
      .pred_class_2 = numeric(0),
      class = empty_class
    ),
    row.names = integer(0),
    class = c("tbl_df", "tbl", "data.frame")
  )

  expect_equal(pred_tbl[0, ], expected_empty)
  expect_equal(nrow(pred_tbl), nrow(test_data))

  # Did it learn anything? The Brier score has to fall below (1 - 1/k)^2,
  # where k is the number of outcome levels.
  brier <- yardstick::brier_class(pred_tbl, class, .pred_class_1)
  brier_bound <- (1 - 1 / n_levels)^2
  expect_true(brier$.estimate < brier_bound)
})

# Fits the same model twice from identical seeds, once with `class_weights = 10`
# and once without, and requires the weighted fit to predict "class_2" more
# often on the test set.
# NOTE(review): presumably "class_2" is the class that receives the extra
# weight under `intercept = 1` -- confirm against the brulee documentation.
test_that("binomial mlp case weights", {
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")

  suppressPackageStartupMessages({
    library(dplyr)
    library(recipes)
  })

  # Simulated data: training set first, then test set, from one seeded stream
  set.seed(585)
  train_data <- modeldata::sim_classification(5000, intercept = 1)
  test_data <- modeldata::sim_classification(1000, intercept = 1)

  # NOTE(review): `n_levels` is not referenced by any expectation in this test.
  n_levels <- nlevels(train_data$class)
  norm_rec <- step_normalize(
    recipe(class ~ ., data = train_data),
    all_predictors()
  )

  # Fit and predict with class weights
  expect_no_error({
    set.seed(391)
    torch::torch_manual_seed(391)
    weighted <- brulee_mlp(
      norm_rec,
      train_data,
      class_weights = 10,
      hidden_units = 5,
      epochs = 5,
      learn_rate = 0.01,
      rate_schedule = "cyclic",
      device = "cpu"
    )
  })

  expect_no_error({
    weighted_class <- predict(weighted, test_data)
    weighted_prob <- predict(weighted, test_data, type = "prob")
    weighted_pred <- select(
      bind_cols(weighted_class, weighted_prob, test_data),
      starts_with(".pred"),
      class
    )
  })

  # Fit and predict without class weights, re-seeding so both fits start alike
  expect_no_error({
    set.seed(391)
    torch::torch_manual_seed(391)
    unweighted <- brulee_mlp(
      norm_rec,
      train_data,
      hidden_units = 5,
      epochs = 5,
      learn_rate = 0.01,
      rate_schedule = "cyclic",
      device = "cpu"
    )
  })

  expect_no_error({
    unweighted_class <- predict(unweighted, test_data)
    unweighted_prob <- predict(unweighted, test_data, type = "prob")
    unweighted_pred <- select(
      bind_cols(unweighted_class, unweighted_prob, test_data),
      starts_with(".pred"),
      class
    )
  })

  # The weighted model must call "class_2" strictly more often
  n_weighted <- sum(weighted_pred$.pred_class == "class_2")
  n_unweighted <- sum(unweighted_pred$.pred_class == "class_2")
  expect_true(n_weighted > n_unweighted)
})

# Regression check that a model with `activation = "linear"` can be fit and
# comes back as a "brulee_mlp" object.
test_that("linear activations", {
  skip_on_cran()
  # See https://github.com/tidymodels/brulee/issues/68
  skip_if(!torch::torch_is_installed())
  skip_if_not_installed("modeldata")

  data(bivariate, package = "modeldata")

  set.seed(20)
  torch::torch_manual_seed(20)

  # try() captures any fitting error, so a failure shows up through the class
  # expectation below instead of aborting the test.
  linear_fit <- try(
    brulee_mlp(
      Class ~ log(A) + log(B),
      data = bivariate_train,
      activation = "linear",
      hidden_units = 3,
      epochs = 150,
      device = "cpu"
    ),
    silent = TRUE
  )

  expect_s3_class(linear_fit, "brulee_mlp")
})