R Under development (unstable) (2024-08-16 r87026 ucrt) -- "Unsuffered Consequences" Copyright (C) 2024 The R Foundation for Statistical Computing Platform: x86_64-w64-mingw32/x64 R is free software and comes with ABSOLUTELY NO WARRANTY. You are welcome to redistribute it under certain conditions. Type 'license()' or 'licence()' for distribution details. R is a collaborative project with many contributors. Type 'contributors()' for more information and 'citation()' on how to cite R or R packages in publications. Type 'demo()' for some demos, 'help()' for on-line help, or 'help.start()' for an HTML browser interface to help. Type 'q()' to quit R. > suppressWarnings(RNGversion("3.5.2")) > > library("partykit") Loading required package: grid Loading required package: libcoin Loading required package: mvtnorm > stopifnot(require("party")) Loading required package: party Loading required package: modeltools Loading required package: stats4 Loading required package: strucchange Loading required package: zoo Attaching package: 'zoo' The following objects are masked from 'package:base': as.Date, as.Date.numeric Loading required package: sandwich Attaching package: 'party' The following objects are masked from 'package:partykit': cforest, ctree, ctree_control, edge_simple, mob, mob_control, node_barplot, node_bivplot, node_boxplot, node_inner, node_surv, node_terminal, varimp > set.seed(29) > > ### regression > airq <- airquality[complete.cases(airquality),] > > mtry <- ncol(airq) - 1L > ntree <- 25 > > cf_partykit <- partykit::cforest(Ozone ~ ., data = airq, + ntree = ntree, mtry = mtry) > > w <- do.call("cbind", cf_partykit$weights) > > cf_party <- party::cforest(Ozone ~ ., data = airq, + control = party::cforest_unbiased(ntree = ntree, mtry = mtry), + weights = w) > > p_partykit <- predict(cf_partykit) > p_party <- predict(cf_party) > > stopifnot(max(abs(p_partykit - p_party)) < sqrt(.Machine$double.eps)) > > prettytree(cf_party@ensemble[[1]], inames = names(airq)[-1]) 1) Wind <= 5.7; criterion = 1, statistic = 30.75 2)* weights = 0 1) Wind > 5.7 3) Temp <= 84; criterion = 1, statistic = 30.238 4) Temp <= 77; criterion = 0.999, statistic = 10.471 5) Wind <= 9.2; criterion = 0.895, statistic = 2.632 6)* weights = 0 5) Wind > 9.2 7) Solar.R <= 112; criterion = 0.907, statistic = 2.823 8)* weights = 0 7) Solar.R > 112 9)* weights = 0 4) Temp > 77 10) Day <= 13; criterion = 0.981, statistic = 5.479 11)* weights = 0 10) Day > 13 12)* weights = 0 3) Temp > 84 13)* weights = 0 > party(cf_partykit$nodes[[1]], data = model.frame(cf_partykit)) [1] root | [2] Wind <= 5.7: * | [3] Wind > 5.7 | | [4] Temp <= 84 | | | [5] Temp <= 77 | | | | [6] Wind <= 9.2: * | | | | [7] Wind > 9.2 | | | | | [8] Solar.R <= 112: * | | | | | [9] Solar.R > 112: * | | | [10] Temp > 77 | | | | [11] Day <= 13: * | | | | [12] Day > 13: * | | [13] Temp > 84: * > > v_party <- do.call("rbind", lapply(1:5, function(i) party::varimp(cf_party))) > > v_partykit <- do.call("rbind", lapply(1:5, function(i) partykit::varimp(cf_partykit))) > > summary(v_party) Solar.R Wind Temp Month Min. :22.87 Min. :146.3 Min. :760.9 Min. :0.5159 1st Qu.:25.06 1st Qu.:152.8 1st Qu.:784.3 1st Qu.:0.5236 Median :26.11 Median :176.0 Median :806.2 Median :0.6119 Mean :26.90 Mean :171.9 Mean :813.8 Mean :0.7391 3rd Qu.:26.26 3rd Qu.:189.3 3rd Qu.:841.5 3rd Qu.:0.9886 Max. :34.18 Max. :195.1 Max. :875.9 Max. :1.0556 Day Min. :2.051 1st Qu.:2.512 Median :2.689 Mean :3.409 3rd Qu.:3.487 Max. :6.304 > summary(v_partykit) Solar.R Wind Temp Month Min. :23.35 Min. :161.7 Min. :760.8 Min. :-2.446 1st Qu.:24.81 1st Qu.:190.1 1st Qu.:763.4 1st Qu.: 2.983 Median :26.93 Median :199.4 Median :768.7 Median : 3.440 Mean :29.65 Mean :195.5 Mean :777.1 Mean : 2.662 3rd Qu.:31.46 3rd Qu.:205.0 3rd Qu.:769.2 3rd Qu.: 4.575 Max. :41.69 Max. :221.5 Max. :823.4 Max. : 4.757 Day Min. :-1.1396 1st Qu.:-0.4362 Median :24.3535 Mean :17.7578 3rd Qu.:31.8914 Max. :34.1200 > > party::varimp(cf_party, conditional = TRUE) Solar.R Wind Temp Month Day 16.7190604 100.7812597 534.9587763 -0.2538655 4.4848324 > partykit::varimp(cf_partykit, conditional = TRUE) Solar.R Wind Temp Month Day 27.520179 144.897612 476.407961 0.308407 -0.655686 > > > ### classification > set.seed(29) > mtry <- ncol(iris) - 1L > ntree <- 25 > > cf_partykit <- partykit::cforest(Species ~ ., data = iris, + ntree = ntree, mtry = mtry) > > w <- do.call("cbind", cf_partykit$weights) > > cf_party <- party::cforest(Species ~ ., data = iris, + control = party::cforest_unbiased(ntree = ntree, mtry = mtry), + weights = w) > > p_partykit <- predict(cf_partykit, type = "prob") > p_party <- do.call("rbind", treeresponse(cf_party)) > > stopifnot(max(abs(unclass(p_partykit) - unclass(p_party))) < sqrt(.Machine$double.eps)) > > prettytree(cf_party@ensemble[[1]], inames = names(iris)[-5]) 1) Petal.Length <= 1.9; criterion = 1, statistic = 86.933 2)* weights = 0 1) Petal.Length > 1.9 3) Petal.Width <= 1.6; criterion = 1, statistic = 42.075 4) Sepal.Width <= 2.5; criterion = 0.931, statistic = 3.316 5)* weights = 0 4) Sepal.Width > 2.5 6)* weights = 0 3) Petal.Width > 1.6 7) Petal.Length <= 5.1; criterion = 0.774, statistic = 1.466 8)* weights = 0 7) Petal.Length > 5.1 9)* weights = 0 > party(cf_partykit$nodes[[1]], data = model.frame(cf_partykit)) [1] root | [2] Petal.Length <= 1.9: * | [3] Petal.Length > 1.9 | | [4] Petal.Width <= 1.6 | | | [5] Sepal.Width <= 2.5: * | | | [6] Sepal.Width > 2.5: * | | [7] Petal.Width > 1.6 | | | [8] Petal.Length <= 5.1: * | | | [9] Petal.Length > 5.1: * > > v_party <- do.call("rbind", lapply(1:5, function(i) party::varimp(cf_party))) > > v_partykit <- do.call("rbind", lapply(1:5, function(i) + partykit::varimp(cf_partykit, risk = "mis"))) > > summary(v_party) Sepal.Length Sepal.Width Petal.Length Petal.Width Min. :0 Min. :0 Min. :0.3786 Min. :0.3014 1st Qu.:0 1st Qu.:0 1st Qu.:0.3807 1st Qu.:0.3029 Median :0 Median :0 Median :0.4000 Median :0.3050 Mean :0 Mean :0 Mean :0.3941 Mean :0.3111 3rd Qu.:0 3rd Qu.:0 3rd Qu.:0.4036 3rd Qu.:0.3121 Max. :0 Max. :0 Max. :0.4079 Max. :0.3343 > summary(v_partykit) Sepal.Width Petal.Length Petal.Width Min. :0 Min. :0.3869 Min. :0.2971 1st Qu.:0 1st Qu.:0.3921 1st Qu.:0.3036 Median :0 Median :0.3966 Median :0.3057 Mean :0 Mean :0.3952 Mean :0.3117 3rd Qu.:0 3rd Qu.:0.4003 3rd Qu.:0.3179 Max. :0 Max. :0.4003 Max. :0.3343 > > party::varimp(cf_party, conditional = TRUE) Sepal.Length Sepal.Width Petal.Length Petal.Width 0.0000000 0.0000000 0.2778571 0.1014286 > partykit::varimp(cf_partykit, risk = "misclass", conditional = TRUE) Sepal.Width Petal.Length Petal.Width 0.0000000 0.2782738 0.1171429 > > ### mean aggregation > set.seed(29) > > ### fit forest > cf <- partykit::cforest(dist ~ speed, data = cars, ntree = 100) > > ### prediction; scale = TRUE introduced in 1.2-1 > pr <- predict(cf, newdata = cars[1,,drop = FALSE], type = "response", scale = TRUE) > ### this is equivalent to > w <- predict(cf, newdata = cars[1,,drop = FALSE], type = "weights") > stopifnot(isTRUE(all.equal(pr, sum(w * cars$dist) / sum(w), + check.attributes = FALSE))) > > ### check if this is the same as mean aggregation > > ### compute terminal node IDs for first obs > nd1 <- predict(cf, newdata = cars[1,,drop = FALSE], type = "node") > ### compute terminal nide IDs for all obs > nd <- predict(cf, newdata = cars, type = "node") > ### random forests weighs > lw <- cf$weights > > ### compute mean predictions for each tree > ### and extract mean for terminal node containing > ### first observation > np <- vector(mode = "list", length = length(lw)) > m <- numeric(length(lw)) > > for (i in 1:length(lw)) { + np[[i]] <- tapply(lw[[i]] * cars$dist, nd[[i]], sum) / + tapply(lw[[i]], nd[[i]], sum) + m[i] <- np[[i]][as.character(nd1[i])] + } > > stopifnot(isTRUE(all.equal(mean(m), sum(w * cars$dist) / sum(w)))) > > ### check parallel variable importance (make this reproducible) > if(.Platform$OS.type == "unix") { + RNGkind("L'Ecuyer-CMRG") + v1 <- partykit::varimp(cf_partykit, risk = "misclass", conditional = TRUE, cores = 2) + v2 <- partykit::varimp(cf_partykit, risk = "misclass", conditional = TRUE, cores = 2) + stopifnot(all.equal(v1, v2)) + } > > ### check weights argument > cf_partykit <- partykit::cforest(Species ~ ., data = iris, + ntree = ntree, mtry = 4) > w <- do.call("cbind", cf_partykit$weights) > cf_2 <- partykit::cforest(Species ~ ., data = iris, + ntree = ntree, mtry = 4, weights = w) > stopifnot(max(abs(predict(cf_2, type = "prob") - + predict(cf_partykit, type = "prob"))) < sqrt(.Machine$double.eps)) > > > proc.time() user system elapsed 4.92 0.31 5.18