test_that("can make simple batch request", {
  chat <- chat_databricks(
    system_prompt = "Be as terse as possible; no punctuation"
  )
  resp <- chat$chat("What is 1 + 1?", echo = FALSE)
  expect_match(resp, "2")
  expect_equal(chat$last_turn()@tokens > 0, c(TRUE, TRUE))
})

test_that("can make simple streaming request", {
  chat <- chat_databricks(
    system_prompt = "Be as terse as possible; no punctuation"
  )
  resp <- coro::collect(chat$stream("What is 1 + 1?"))
  expect_match(paste0(unlist(resp), collapse = ""), "2")
})

# Common provider interface -----------------------------------------------

test_that("defaults are reported", {
  # Setting a dummy host ensures we don't skip this test, even if there are no
  # Databricks credentials available.
  withr::local_envvar(DATABRICKS_HOST = "https://example.cloud.databricks.com")
  expect_snapshot(. <- chat_databricks())
})

test_that("all tool variations work", {
  test_tools_simple(chat_databricks)
  test_tools_async(chat_databricks)
  test_tools_parallel(chat_databricks, total_calls = 6)
  test_tools_sequential(chat_databricks, total_calls = 6)
})

test_that("can extract data", {
  test_data_extraction(chat_databricks)
})

test_that("can use images", {
  # Skip explicitly (with a reason) rather than leaving an empty test body.
  # Re-enable these if Databricks models gain image support:
  #
  # test_images_inline(chat_databricks)
  # test_images_remote(chat_databricks)
  skip("Databricks models don't support images")
})

# Auth --------------------------------------------------------------------

test_that("Databricks PATs are detected correctly", {
  withr::local_envvar(
    DATABRICKS_HOST = "https://example.cloud.databricks.com",
    DATABRICKS_TOKEN = "token"
  )
  credentials <- default_databricks_credentials()
  expect_equal(credentials(), list(Authorization = "Bearer token"))
})

test_that("Databricks CLI tokens are detected correctly", {
  # Emulate a ~/.databrickscfg file written by other tooling. The file is
  # removed automatically when the test finishes.
  cfg_file <- withr::local_tempfile(pattern = "databricks", fileext = ".cfg")
  writeLines(
    c(
      "[DEFAULT]",
      "host = https://example.cloud.databricks.com",
      "auth_type = databricks-cli",
      ";This profile is autogenerated by the Databricks Extension for VS Code",
      "[ellmer]",
      "host=https://example2.cloud.databricks.com/",
      "auth_type=databricks-cli"
    ),
    cfg_file
  )
  # Clear ambient credentials and profile selection so that only the config
  # file above (plus the mocked CLI) can supply the workspace and token.
  withr::local_envvar(
    DATABRICKS_HOST = NA,
    DATABRICKS_TOKEN = NA,
    DATABRICKS_CONFIG_FILE = cfg_file,
    DATABRICKS_CONFIG_PROFILE = NA,
    DATABRICKS_CLIENT_ID = NA,
    DATABRICKS_CLIENT_SECRET = NA
  )
  local_mocked_bindings(
    databricks_cli_token = function(path, host) {
      stopifnot(startsWith(host, "https://"))
      "cli_token"
    }
  )

  # Ensure we can get the Workspace out of the config file.
  workspace <- databricks_workspace()
  expect_equal(workspace, "https://example.cloud.databricks.com")

  # And for a non-default profile, too.
  withr::local_envvar(DATABRICKS_CONFIG_PROFILE = "ellmer")
  expect_equal(databricks_workspace(), "https://example2.cloud.databricks.com/")

  # Ensure we call the CLI for credentials.
  credentials <- default_databricks_credentials(workspace)
  expect_equal(credentials(), list(Authorization = "Bearer cli_token"))
})

test_that("Workbench-managed Databricks credentials are detected correctly", {
  # Emulate a databricks.cfg file written by Workbench, in a temporary
  # directory that is removed when the test finishes.
  db_home <- tempfile("posit-workbench")
  dir.create(db_home)
  withr::defer(unlink(db_home, recursive = TRUE))
  writeLines(
    c(
      "[workbench]",
      "host = https://example.cloud.databricks.com",
      "token = token"
    ),
    file.path(db_home, "databricks.cfg")
  )
  # Unset DATABRICKS_TOKEN so an ambient PAT cannot stand in for the token
  # read from the Workbench config file.
  withr::local_envvar(
    DATABRICKS_CONFIG_FILE = file.path(db_home, "databricks.cfg"),
    DATABRICKS_CONFIG_PROFILE = "workbench",
    DATABRICKS_HOST = "https://example.cloud.databricks.com",
    DATABRICKS_TOKEN = NA,
    DATABRICKS_CLIENT_ID = NA,
    DATABRICKS_CLIENT_SECRET = NA
  )
  credentials <- default_databricks_credentials()
  expect_equal(credentials(), list(Authorization = "Bearer token"))
})

test_that("M2M authentication requests look correct", {
  # Unset DATABRICKS_TOKEN so an ambient PAT cannot bypass the M2M flow.
  withr::local_envvar(
    DATABRICKS_HOST = "https://example.cloud.databricks.com",
    DATABRICKS_TOKEN = NA,
    DATABRICKS_CLIENT_ID = "id",
    DATABRICKS_CLIENT_SECRET = "secret"
  )
  local_mocked_responses(function(req) {
    # Snapshot relevant fields of the outgoing request.
    expect_snapshot(
      list(url = req$url, headers = req$headers, body = req$body$data)
    )
    response_json(body = list(access_token = "token"))
  })
  credentials <- default_databricks_credentials()
  expect_equal(credentials(), list(Authorization = "Bearer token"))
})

test_that("workspace detection handles URLs with and without an https prefix", {
  withr::with_envvar(
    c(DATABRICKS_HOST = "example.cloud.databricks.com"),
    expect_equal(
      databricks_workspace(),
      "https://example.cloud.databricks.com"
    )
  )
  withr::with_envvar(
    c(DATABRICKS_HOST = "https://example.cloud.databricks.com"),
    expect_equal(
      databricks_workspace(),
      "https://example.cloud.databricks.com"
    )
  )
})

test_that("the user agent respects SPARK_CONNECT_USER_AGENT when set", {
  withr::with_envvar(
    c(SPARK_CONNECT_USER_AGENT = NA),
    expect_match(databricks_user_agent(), "^r-ellmer")
  )
  withr::with_envvar(
    c(SPARK_CONNECT_USER_AGENT = "testing"),
    expect_match(databricks_user_agent(), "^testing r-ellmer")
  )
})

test_that("tokens can be requested from a Connect server", {
  skip_if_not_installed("connectcreds")

  # Unset other credential sources and mock a token value that appears
  # nowhere else, so the assertion can only pass if the credential really
  # came from the (mocked) Connect server rather than from a PAT.
  withr::local_envvar(
    DATABRICKS_HOST = "https://example.cloud.databricks.com",
    DATABRICKS_TOKEN = NA,
    DATABRICKS_CLIENT_ID = NA,
    DATABRICKS_CLIENT_SECRET = NA
  )
  connectcreds::local_mocked_connect_responses(token = "connect_token")
  credentials <- default_databricks_credentials()
  expect_equal(credentials(), list(Authorization = "Bearer connect_token"))
})

test_that("chat_databricks() serializes tools correctly", {
  withr::local_envvar(
    DATABRICKS_HOST = "https://example.cloud.databricks.com",
    DATABRICKS_TOKEN = "token"
  )
  chat <- chat_databricks(model = "databricks-claude-3-7-sonnet")
  provider <- chat$get_provider()

  # A tool with no arguments serializes without a `parameters` entry.
  expect_equal(
    as_json(
      provider,
      tool(
        function() format(Sys.Date()),
        .name = "current_date",
        .description = "Returns the current date in ISO 8601 format."
      )
    ),
    list(
      type = "function",
      "function" = list(
        name = "current_date",
        description = "Returns the current date in ISO 8601 format."
      )
    )
  )

  # A tool with arguments serializes them as a JSON-schema object.
  expect_equal(
    as_json(
      provider,
      tool(
        function(person) {
          if (person == "Joe") "sage green" else "red"
        },
        .name = "favourite_colour",
        .description = "Returns a person's favourite colour.",
        person = type_string("Name of a person")
      )
    ),
    list(
      type = "function",
      "function" = list(
        name = "favourite_colour",
        description = "Returns a person's favourite colour.",
        parameters = list(
          type = "object",
          description = "",
          properties = list(
            person = list(
              type = "string",
              description = "Name of a person"
            )
          ),
          required = list("person"),
          additionalProperties = FALSE
        )
      )
    )
  )
})