# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Tests for listing compute functions and for building, registering and
# executing user-defined scalar functions.
#
# Every on.exit() below passes `add = TRUE` so that a later cleanup handler
# in the same test never silently replaces an earlier one.

test_that("list_compute_functions() works", {
  # The listing is a character vector with no name starting with "hash_"
  expect_type(list_compute_functions(), "character")
  expect_true(all(!grepl("^hash_", list_compute_functions())))
})

test_that("arrow_scalar_function() works", {
  # check in/out type as schema/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    schema(x = int32()),
    int64()
  )

  expect_equal(fun$in_type[[1]], schema(x = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as data type/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    int32(),
    int64()
  )

  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as field/data type
  # (the body refers to its own argument `a_name`, matching the field name)
  fun <- arrow_scalar_function(
    function(context, a_name) a_name$cast(int64()),
    field("a_name", int32()),
    int64()
  )

  expect_equal(fun$in_type[[1]], schema(a_name = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as lists
  fun <- arrow_scalar_function(
    function(context, x) x,
    list(int32(), int64()),
    list(int64(), int32()),
    auto_convert = TRUE
  )

  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$in_type[[2]][[1]], field("", int64()))
  expect_equal(fun$out_type[[1]](), int64())
  expect_equal(fun$out_type[[2]](), int32())

  # a NULL `fun` is rejected; the message is checked against a snapshot
  expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32()))
})

test_that("arrow_scalar_function() works with auto_convert = TRUE", {
  times_32_wrapper <- arrow_scalar_function(
    function(context, x) x * 32,
    float64(),
    float64(),
    auto_convert = TRUE
  )

  # Call the generated wrapper directly with a placeholder context and a
  # list of arguments
  dummy_kernel_context <- list()

  expect_equal(
    times_32_wrapper$wrapper_fun(dummy_kernel_context, list(Scalar$create(2))),
    Array$create(2 * 32)
  )
})

test_that("register_scalar_function() adds a compute function to the registry", {
  skip_if_not(CanRunWithCapturedR())
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  # The function is visible both in the package function cache and in the
  # compute function listing
  expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions))
  expect_true("times_32" %in% list_compute_functions())

  expect_equal(
    call_function("times_32", Array$create(1L, int32())),
    Array$create(32L, float64())
  )

  expect_equal(
    call_function("times_32", Scalar$create(1L, int32())),
    Scalar$create(32L, float64())
  )

  skip_if_not_available("acero")

  expect_identical(
    record_batch(a = 1L) %>%
      dplyr::mutate(b = times_32(a)) %>%
      dplyr::collect(),
    tibble::tibble(a = 1L, b = 32.0)
  )
})

test_that("arrow_scalar_function() with bad return type errors", {
  skip_if_not(CanRunWithCapturedR())

  # Declared output type is float64() but the function returns an int32 Array
  register_scalar_function(
    "times_32_bad_return_type_array",
    function(context, x) Array$create(x, int32()),
    int32(),
    float64()
  )
  on.exit(
    unregister_binding("times_32_bad_return_type_array", update_cache = TRUE),
    add = TRUE
  )

  expect_error(
    call_function("times_32_bad_return_type_array", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )

  # Same mismatch, but the function returns an int32 Scalar
  register_scalar_function(
    "times_32_bad_return_type_scalar",
    function(context, x) Scalar$create(x, int32()),
    int32(),
    float64()
  )
  # add = TRUE keeps the handler registered above for the *_array function
  on.exit(
    unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE),
    add = TRUE
  )

  expect_error(
    call_function("times_32_bad_return_type_scalar", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )
})

test_that("register_scalar_function() can register multiple kernels", {
  skip_if_not(CanRunWithCapturedR())

  # One kernel per input type; the output type is computed from the input
  # types by a function
  register_scalar_function(
    "times_32",
    function(context, x) x * 32L,
    in_type = list(int32(), int64(), float64()),
    out_type = function(in_types) in_types[[1]],
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  expect_equal(
    call_function("times_32", Scalar$create(1L, int32())),
    Scalar$create(32L, int32())
  )

  expect_equal(
    call_function("times_32", Scalar$create(1L, int64())),
    Scalar$create(32L, int64())
  )

  expect_equal(
    call_function("times_32", Scalar$create(1L, float64())),
    Scalar$create(32L, float64())
  )
})

test_that("register_scalar_function() errors for unsupported specifications", {
  # Empty in/out type lists
  expect_error(
    register_scalar_function(
      "no_kernels",
      function(...) NULL,
      list(),
      list()
    ),
    "Can't register user-defined scalar function with 0 kernels"
  )

  # `fun` takes one argument where two are expected
  expect_error(
    register_scalar_function(
      "wrong_n_args",
      function(x) NULL,
      int32(),
      int32()
    ),
    "Expected `fun` to accept 2 argument\\(s\\)"
  )

  # Input specifications with differing numbers of fields
  expect_error(
    register_scalar_function(
      "var_kernels",
      function(...) NULL,
      list(float64(), schema(x = float64(), y = float64())),
      float64()
    ),
    "Kernels for user-defined function must accept the same number of arguments"
  )
})

test_that("user-defined functions work during multi-threaded execution", {
  skip_if_not(CanRunWithCapturedR())
  skip_if_not_available("dataset")
  # Skip on linux devel because:
  # TODO(ARROW-17283): Snappy has a UBSan issue that is fixed in the dev version
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  n_rows <- 10000
  n_partitions <- 10
  example_df <- expand.grid(
    part = letters[seq_len(n_partitions)],
    value = seq_len(n_rows),
    stringsAsFactors = FALSE
  )

  # make sure values are different for each partition and that every row has
  # a row number, which is used below to put collected results back in order
  example_df$row_num <- seq_len(nrow(example_df))
  example_df$value <- example_df$value + match(example_df$part, letters)

  tf_dataset <- tempfile()
  tf_dest <- tempfile()
  # Both paths are used as dataset locations; unlink() does not delete
  # directories unless recursive = TRUE
  on.exit(unlink(c(tf_dataset, tf_dest), recursive = TRUE), add = TRUE)
  write_dataset(example_df, tf_dataset, partitioning = "part")

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  # add = TRUE keeps the temporary-file cleanup registered above
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  # check a regular collect()
  result <- open_dataset(tf_dataset) %>%
    dplyr::mutate(fun_result = times_32(value)) %>%
    dplyr::collect() %>%
    dplyr::arrange(row_num)

  expect_identical(result$fun_result, example_df$value * 32)

  # check a write_dataset()
  open_dataset(tf_dataset) %>%
    dplyr::mutate(fun_result = times_32(value)) %>%
    write_dataset(tf_dest)

  result2 <- open_dataset(tf_dest) %>%
    dplyr::collect() %>%
    dplyr::arrange(row_num)

  expect_identical(result2$fun_result, example_df$value * 32)
})

test_that("nested exec plans can contain user-defined functions", {
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  # Query using the UDF, converted to a record batch reader and then to a
  # Table
  stream_plan_with_udf <- function() {
    record_batch(a = 1:1000) %>%
      dplyr::mutate(b = times_32(a)) %>%
      as_record_batch_reader() %>%
      as_arrow_table()
  }

  # Query using the UDF with head() applied before collect()
  collect_plan_with_head <- function() {
    record_batch(a = 1:1000) %>%
      dplyr::mutate(fun_result = times_32(a)) %>%
      head(11) %>%
      dplyr::collect()
  }

  expect_equal(
    stream_plan_with_udf(),
    record_batch(a = 1:1000) %>%
      dplyr::mutate(b = times_32(a)) %>%
      dplyr::collect(as_data_frame = FALSE)
  )

  result <- collect_plan_with_head()
  expect_equal(nrow(result), 11)
})

test_that("head() on exec plan containing user-defined functions", {
  # Unconditionally skipped; see ARROW-18101
  skip("ARROW-18101")
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  result <- record_batch(a = 1:1000) %>%
    dplyr::mutate(b = times_32(a)) %>%
    as_record_batch_reader() %>%
    head(11) %>%
    dplyr::collect()

  expect_equal(nrow(result), 11)
})