library(testthat)
library(ranger)
library(arf)

# Shared fixture ----
# Split iris into 100 training rows and 50 held-out rows, fit an adversarial
# random forest on the training rows, encode them, and project the held-out
# rows into the same embedding. set.seed() fixes the split drawn by sample()
# and the RNG state seen by the subsequent fitting calls.
# NOTE(review): `k = 2` in encode() is presumably the embedding
# dimensionality; confirm against the arf documentation.
set.seed(42)
trn <- sample(seq_len(nrow(iris)), 100)
tst <- setdiff(seq_len(nrow(iris)), trn)
arf <- adversarial_rf(iris[trn, ], num_trees = 20, parallel = FALSE)
emap <- encode(arf, iris[trn, ], k = 2)
emb_tst <- predict(emap, arf, iris[tst, ])

test_that("decode_knn returns correct structure", {
  # NOTE(review): an earlier comment said this exercises the eForest scheme
  # (train_decoder); confirm against decode_knn's implementation.
  out <- decode_knn(arf, emap, emb_tst, k = 5, parallel = FALSE)
  expect_type(out, "list")
  expect_named(out, c("x_hat", "x_tilde"))
  expect_s3_class(out$x_hat, "data.frame")
  # One reconstructed row per held-out test row.
  expect_equal(nrow(out$x_hat), length(tst))
  # Every iris column is reconstructed, including the Species label.
  expect_equal(ncol(out$x_hat), ncol(iris))
})

test_that("decode_knn handles k=1 (nearest neighbor only)", {
  out <- decode_knn(arf, emap, emb_tst, k = 1, parallel = FALSE)
  expect_equal(nrow(out$x_hat), length(tst))
  # Every decoded label must be a known Species level. %in% returns FALSE
  # (never NA) for missing values, so an NA label also fails this check.
  expect_true(all(out$x_hat$Species %in% levels(iris$Species)))
})