R Under development (unstable) (2024-12-01 r87412 ucrt) -- "Unsuffered Consequences" Copyright (C) 2024 The R Foundation for Statistical Computing Platform: x86_64-w64-mingw32/x64 R is free software and comes with ABSOLUTELY NO WARRANTY. You are welcome to redistribute it under certain conditions. Type 'license()' or 'licence()' for distribution details. R is a collaborative project with many contributors. Type 'contributors()' for more information and 'citation()' on how to cite R or R packages in publications. Type 'demo()' for some demos, 'help()' for on-line help, or 'help.start()' for an HTML browser interface to help. Type 'q()' to quit R. > suppressWarnings(RNGversion("3.5.2")) > > ## packages > library("partykit") Loading required package: grid Loading required package: libcoin Loading required package: mvtnorm > library("rpart") > > ## data-generating process > dgp <- function(n) + data.frame(y = gl(4, n), x1 = rnorm(4 * n), x2 = rnorm(4 * n)) > > ## rpart check > learn <- dgp(100) > fit <- as.party(rpart(y ~ ., data = learn)) > test <- dgp(100000) > system.time(id <- fitted_node(node_party(fit), test)) user system elapsed 0.08 0.11 0.19 > system.time(yhat <- predict_party(fit, id = id, newdata = test)) user system elapsed 0.03 0.00 0.03 > > ### predictions in info slots > tmp <- data.frame(x = rnorm(100)) > pfit <- party(node = partynode(1L, split = partysplit(1L, breaks = 0), + kids = list(partynode(2L, info = -0.5), partynode(3L, info = 0.5))), data = tmp) > pfit [1] root | [2] x <= 0: -0.5 | [3] x > 0: 0.5 > p <- predict(pfit, newdata = tmp) > p 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 3 2 2 3 3 3 2 3 3 3 2 3 2 3 3 3 2 2 3 2 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 2 2 2 3 3 3 3 2 2 3 3 3 3 3 3 3 2 3 3 3 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 2 3 3 2 2 2 2 3 2 3 2 2 3 3 2 2 2 2 3 3 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 2 2 2 2 3 2 2 2 2 2 2 2 2 2 3 3 2 3 2 2 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 3 3 2 3 3 2 3 3 2 2 3 3 2 2 3 3 2 2 3 2 > table(p, sign(tmp$x)) p -1 1 2 51 0 3 0 49 > > proc.time() user system elapsed 1.87 0.46 2.32