# Parity test for the TabICL row-interaction stage (R/tabicl-interaction.R)
# against the Python reference (RowInteraction). The fixtures are generated by
# dev/tabicl/dump_primitives.py with 3 blocks, so they exercise both the
# self-attention blocks and the final CLS-token cross-attention block, plus the
# output LayerNorm and CLS flattening.
#
# Both LayerNorm configurations are covered: the classifier checkpoint keeps
# LayerNorm biases, while the regressor uses bias_free_ln = TRUE.
for (fixture in c("row_interaction", "row_interaction_biasfree")) {
  local({
    # local() keeps fixture_name scoped to this loop iteration.
    fixture_name <- fixture

    test_that(
      paste0(
        "tabicl_row_interaction matches the Python reference (",
        fixture_name,
        ")"
      ),
      {
        skip_if_no_tabicl_fixtures(fixture_name)

        f <- tabicl_load_fixture(fixture_name)
        meta <- tabicl_fixture_meta(fixture_name)

        ri <- brulee:::tabicl_row_interaction(
          embed_dim = meta$embed_dim,
          num_blocks = meta$num_blocks,
          nhead = meta$nhead,
          dim_feedforward = meta$dim_feedforward,
          num_cls = meta$num_cls,
          rope_base = meta$rope_base,
          # isTRUE() treats a missing (NULL) flag as FALSE.
          bias_free_ln = isTRUE(meta$bias_free_ln)
        )
        ri$eval()
        tabicl_copy_row_interaction(ri, f)

        # Inference only: no autograd graph is needed for the comparison.
        out <- torch::with_no_grad(ri(f$input))

        expect_equal(dim(out), dim(f$out))
        expect_lt(tabicl_max_abs_diff(out, f$out), 1e-5)
      }
    )
  })
}

test_that("tabicl_row_interaction output dimension is num_cls * embed_dim", {
  skip_on_cran()
  skip_if_not_installed("torch")
  if (!torch::torch_is_installed()) {
    skip("libtorch not installed")
  }

  embed_dim <- 16
  num_cls <- 3
  n_rows <- 4
  n_features <- 5

  # Check both LayerNorm configurations, mirroring the parity fixtures above.
  for (bias_free_ln in c(FALSE, TRUE)) {
    ri <- brulee:::tabicl_row_interaction(
      embed_dim = embed_dim,
      num_blocks = 2,
      nhead = 4,
      dim_feedforward = 32,
      num_cls = num_cls,
      bias_free_ln = bias_free_ln
    )
    ri$eval()

    # Input: 1 table, n_rows rows, n_features feature slots + num_cls CLS slots.
    input <- torch::torch_randn(1, n_rows, n_features + num_cls, embed_dim)
    out <- torch::with_no_grad(ri(input))

    # The num_cls CLS outputs are flattened into one vector per row.
    expect_equal(dim(out), c(1, n_rows, num_cls * embed_dim))
  }
})