context("nn-multihead_attention")

test_that("nn_multihead_attention", {
  t1 <- torch_randn(5, 8, 32)
  t2 <- torch_randn(5, 8, 32)
  t3 <- torch_randn(5, 8, 32)
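  # embed_dim = 32 must divide evenly across num_heads = 8 (4 dims per head);
  # inputs follow the module's default (seq_len, batch, embed_dim) layout, here (5, 8, 32)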
  attn <- nn_multihead_attention(32, 8)

  # q,k,v all the same:
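  # attn() returns a two-element list: the attention output and the attention
  # weights, which are averaged over heads unless avg_weights = FALSE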
  out <- attn(t1, t1, t1)

  expect_identical(out[[1]]$size(), c(5L, 8L, 32L))
  expect_identical(out[[2]]$size(), c(8L, 5L, 5L))

  # per-head attention weights: avg_weights = FALSE keeps the head dimension,
  # giving (batch, num_heads, tgt_len, src_len)
  out <- attn(t1, t1, t1, avg_weights = FALSE)

  expect_identical(out[[1]]$size(), c(5L, 8L, 32L))
  expect_identical(out[[2]]$size(), c(8L, 8L, 5L, 5L))

  # q different from k,v:
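  # (the output length follows the query; weights stay (batch, tgt_len, src_len))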
  out <- attn(t1, t2, t2)

  expect_identical(out[[1]]$size(), c(5L, 8L, 32L))
  expect_identical(out[[2]]$size(), c(8L, 5L, 5L))

  # q,k,v all different:
  out <- attn(t1, t2, t3)

  expect_identical(out[[1]]$size(), c(5L, 8L, 32L))
  expect_identical(out[[2]]$size(), c(8L, 5L, 5L))

  # strictly upper-triangular causal mask; TRUE marks positions that may not be attended to
  mask <- torch_ones(c(5, 5)) - torch_tril(torch_ones(c(5, 5)))
  mask <- mask$to(dtype = torch_bool())
  # random (batch, seq_len) padding mask; TRUE marks keys to ignore
  pad_mask <- torch_bernoulli(torch_ones(c(8, 5)) * 0.5)$to(dtype = torch_bool())
  out2 <- attn(t1, t1, t1, attn_mask = mask, key_padding_mask = pad_mask)

  expect_identical(out2[[1]]$size(), c(5L, 8L, 32L))
  expect_identical(out2[[2]]$size(), c(8L, 5L, 5L))

})