# Tests for softmax(): regression checks against stored reference fits, plus
# structural checks on the fitted objects and their coef()/fitted()/predict()
# methods.
context("test - softmax")

# Presumably provides `softmax_result_alpha0_correct` and
# `softmax_result_alpha3_correct`, the reference fits compared below.
load("softmax-correctRes.RData")

set.seed(5)
beta <- c(-0.5, 0.7, -0.9, 1.1)

# Training data: 70 bags of 3 instances (210 instance rows).
# Test data: 30 bags of 3 instances (90 instance rows).
n_bags_train <- 70
n_bags_test <- 30
n_inst_per_bag <- 3
miData <- DGP(n_bags_train, n_inst_per_bag, beta)
miData_test <- DGP(n_bags_test, n_inst_per_bag, beta)

softmax_result_alpha0 <- softmax(miData$Z, miData$X, miData$ID, alpha = 0)
softmax_result_alpha3 <- softmax(miData$Z, miData$X, miData$ID, alpha = 3)

# Regression check of a fit against its stored reference. Each comparison must
# be an expectation: a bare all.equal() only returns a value and can never
# fail a test. coef() is used on both sides instead of a hand-typed field name
# so a misspelled `$` lookup cannot silently compare NULL with NULL.
expect_matches_reference <- function(fit, reference) {
  expect_equal(fit$alpha, reference$alpha)
  expect_equal(fit$loglik, reference$loglik)
  expect_equal(coef(fit), coef(reference))
}

# Structural checks shared by every fitted softmax object: class, stored
# alpha, number of coefficients, and the sizes of in-sample and
# out-of-sample fitted/predicted values at bag and instance level.
expect_softmax_fit <- function(fit, alpha) {
  n_inst_train <- n_bags_train * n_inst_per_bag
  n_inst_test <- n_bags_test * n_inst_per_bag

  expect_s3_class(fit, "softmax")
  expect_equal(fit$alpha, alpha)
  expect_length(coef(fit), 4L)

  # In-sample values (no new data supplied).
  expect_length(fitted(fit, type = "bag"), n_bags_train)
  expect_length(fitted(fit, type = "instance"), n_inst_train)
  expect_length(predict(fit, type = "bag"), n_bags_train)
  expect_length(predict(fit, type = "instance"), n_inst_train)

  # Out-of-sample predictions on the held-out test data.
  expect_length(
    predict(fit, miData_test$X, miData_test$ID, type = "bag"),
    n_bags_test
  )
  expect_length(
    predict(fit, miData_test$X, miData_test$ID, type = "instance"),
    n_inst_test
  )
}

test_that("softmax matches the stored reference fits", {
  expect_matches_reference(
    softmax_result_alpha0, softmax_result_alpha0_correct
  )
  expect_matches_reference(
    softmax_result_alpha3, softmax_result_alpha3_correct
  )
})

test_that("predict() errors when only one of newdata/bag_newdata is given", {
  expect_error(predict(softmax_result_alpha0, newdata = miData_test$X))
  expect_error(predict(softmax_result_alpha0, bag_newdata = miData_test$ID))
})

test_that("softmax fits expose coef/fitted/predict of the expected sizes", {
  expect_softmax_fit(softmax_result_alpha0, alpha = 0)
  expect_softmax_fit(softmax_result_alpha3, alpha = 3)
})