# Parity tests for the TabICL attention primitives (R/tabicl-attention.R):
# the qassmax-mlp-elementwise scalable softmax and multi-head attention with
# optional RoPE / SSMax. Fixtures are generated by
# dev/tabicl/dump_primitives.py and cover the four attention configurations
# the released checkpoints use:
#   - mha_rope:        self-attention + RoPE, no SSMax (row interactor)
#   - mha_ssmax_self:  self-attention + SSMax, no RoPE (ICL transformer)
#   - mha_ssmax_cross: cross-attention + SSMax (column-embedding ISAB attn1)
#   - mha_plain_cross: cross-attention, plain (column-embedding ISAB attn2)

# ------------------------------------------------------------------------------
# Shared expectation

# Assert that the R output `out` matches the Python reference `golden`:
# identical dimensions, and a maximum absolute elementwise difference
# (as computed by tabicl_max_abs_diff()) strictly below `tolerance`.
expect_tabicl_parity <- function(out, golden, tolerance = 1e-5) {
  expect_equal(dim(out), dim(golden))
  expect_lt(tabicl_max_abs_diff(out, golden), tolerance)
}

# ------------------------------------------------------------------------------
# SSMax

test_that("tabicl_qassmax matches the Python reference", {
  skip_if_no_tabicl_fixtures("ssmax")
  f <- tabicl_load_fixture("ssmax")
  meta <- tabicl_fixture_meta("ssmax")

  ssmax <- brulee:::tabicl_qassmax(
    num_heads = meta$num_heads,
    head_dim = meta$head_dim,
    n_hidden = meta$n_hidden
  )
  # Presumably loads the reference parameters stored in the fixture into the
  # module so both implementations share weights (see tabicl_copy_ssmax()).
  tabicl_copy_ssmax(ssmax, f)

  # NOTE(review): `meta$n` is passed as the second argument; presumably the
  # sequence length SSMax scales by -- confirm against tabicl_qassmax().
  out <- ssmax(f$q, meta$n)
  expect_tabicl_parity(out, f$out)
})

# ------------------------------------------------------------------------------
# Multi-head attention

# Build a tabicl_mha module for fixture `name`, load the fixture's weights into
# it, optionally attach RoPE (when the fixture metadata sets `use_rope`), run
# the forward pass on the fixture's query/key/value, and return
# list(out = <R output>, golden = <Python output>).
run_mha_fixture <- function(name) {
  f <- tabicl_load_fixture(name)
  meta <- tabicl_fixture_meta(name)

  mha <- brulee:::tabicl_mha(
    embed_dim = meta$embed_dim,
    num_heads = meta$num_heads,
    ssmax = meta$ssmax
  )
  tabicl_copy_mha(mha, f)

  rope <- NULL
  if (isTRUE(meta$use_rope)) {
    freqs <- f[["rope.freqs"]]
    # Fail loudly rather than with an opaque copy_() error on NULL.
    if (is.null(freqs)) {
      stop(
        "Fixture '", name, "' sets `use_rope` but has no 'rope.freqs' entry.",
        call. = FALSE
      )
    }
    rope <- brulee:::tabicl_rope(
      dim = meta$embed_dim %/% meta$num_heads,
      theta = 100000
    )
    # Overwrite the R-computed frequencies with the exact values dumped from
    # Python so the comparison isolates the attention computation itself.
    torch::with_no_grad(rope$freqs$copy_(freqs))
  }

  out <- mha(f$query, f$key, f$value, rope = rope)
  list(out = out, golden = f$out)
}

test_that("tabicl_mha self-attention with RoPE matches Python", {
  skip_if_no_tabicl_fixtures("mha_rope")
  res <- run_mha_fixture("mha_rope")
  expect_tabicl_parity(res$out, res$golden)
})

test_that("tabicl_mha self-attention with SSMax matches Python", {
  skip_if_no_tabicl_fixtures("mha_ssmax_self")
  res <- run_mha_fixture("mha_ssmax_self")
  expect_tabicl_parity(res$out, res$golden)
})

test_that("tabicl_mha cross-attention with SSMax matches Python", {
  skip_if_no_tabicl_fixtures("mha_ssmax_cross")
  res <- run_mha_fixture("mha_ssmax_cross")
  expect_tabicl_parity(res$out, res$golden)
})

test_that("tabicl_mha plain cross-attention matches Python", {
  skip_if_no_tabicl_fixtures("mha_plain_cross")
  res <- run_mha_fixture("mha_plain_cross")
  expect_tabicl_parity(res$out, res$golden)
})