R version 4.4.0 beta (2024-04-10 r86396 ucrt) -- "Puppy Cup" Copyright (C) 2024 The R Foundation for Statistical Computing Platform: x86_64-w64-mingw32/x64 R is free software and comes with ABSOLUTELY NO WARRANTY. You are welcome to redistribute it under certain conditions. Type 'license()' or 'licence()' for distribution details. R is a collaborative project with many contributors. Type 'contributors()' for more information and 'citation()' on how to cite R or R packages in publications. Type 'demo()' for some demos, 'help()' for on-line help, or 'help.start()' for an HTML browser interface to help. Type 'q()' to quit R. > library(deepNN) > > N <- 1000 > d <- matrix(rnorm(5*N),ncol=5) > > fun <- function(x){ + lp <- 2*x[2] + pr <- exp(lp) / (1 + exp(lp)) + ret <- c(0,0) + ret[1+rbinom(1,1,pr)] <- 1 + return(ret) + } > > d <- lapply(1:N,function(i){return(d[i,])}) > > truth <- lapply(d,fun) > > net <- network( dims = c(5,10,2), + activ=list(ReLU(),softmax())) > > netwts <- train( dat=d, + truth=truth, + net=net, + eps=0.01, + tol=100, # run for 100 iterations + batchsize=10, # note this is not enough + loss=multinomial(), # for convergence + stopping="maxit") [1] 6.943888 0.000000 [1] 6.933191 0.000000 [1] 7.032016 0.000000 [1] 6.989746 0.000000 [1] 6.94473 0.00000 [1] 6.93126 0.00000 [1] 6.924207 0.000000 [1] 6.878339 0.000000 [1] 6.804929 0.000000 [1] 6.842238 0.000000 [1] 6.948685 0.000000 [1] 6.947105 0.000000 [1] 7.157745 0.000000 [1] 6.997445 0.000000 [1] 6.970845 0.000000 [1] 6.946958 0.000000 [1] 6.936468 0.000000 [1] 6.962108 0.000000 [1] 6.930813 0.000000 [1] 6.907072 0.000000 [1] 6.976217 0.504552 [1] 6.9525927 0.5043667 [1] 6.9310563 0.5047823 [1] 6.9111286 0.5054169 [1] 6.9326443 0.5055459 [1] 6.8952349 0.5052593 [1] 6.9357291 0.5058216 [1] 7.0389860 0.5052832 [1] 6.9310304 0.5057209 [1] 6.9404746 0.5048241 [1] 6.9085857 0.5034731 [1] 6.9844696 0.5014907 [1] 6.9034801 0.5031606 [1] 7.0253702 0.5024427 [1] 6.9272416 0.5016678 [1] 6.9077345 0.5014947 [1] 7.0190759 0.5008033 [1] 6.9300916 0.5008765 [1] 6.8915674 0.5005828 [1] 6.9919556 0.4996046 [1] 6.9328957 0.4997447 [1] 6.8297960 0.4996346 [1] 6.8499712 0.5009858 [1] 6.9436685 0.5022507 [1] 7.146680 0.500801 [1] 6.933249 0.499650 [1] 6.8257340 0.5007447 [1] 6.8489394 0.5011794 [1] 6.9426278 0.5030873 [1] 6.9401977 0.5017766 [1] 6.7487812 0.5007416 [1] 6.9483056 0.5024404 [1] 6.698150 0.504517 [1] 6.6430786 0.5060727 [1] 7.3546492 0.5077867 [1] 7.0959394 0.5095324 [1] 6.8311074 0.5106092 [1] 7.0746581 0.5126448 [1] 7.0436260 0.5114198 [1] 6.9363123 0.5110042 [1] 7.0054073 0.5108349 [1] 6.8794410 0.5098413 [1] 6.9316004 0.5102436 [1] 6.8640991 0.5084315 [1] 6.7815058 0.5062398 [1] 6.710277 0.506078 [1] 6.9482787 0.5085365 [1] 7.0935504 0.5061651 [1] 6.8276493 0.5037181 [1] 6.8115175 0.5040053 [1] 6.8046144 0.5069212 [1] 6.7900560 0.5062854 [1] 6.6263585 0.5067395 [1] 6.9644205 0.5028933 [1] 6.7810564 0.5032769 [1] 6.9696115 0.5034135 [1] 7.1565737 0.4995546 [1] 7.271969 0.496418 [1] 7.0505095 0.4962658 [1] 7.0926441 0.4985442 [1] 6.9217549 0.4989787 [1] 6.9212817 0.5003502 [1] 6.9192657 0.5004377 [1] 6.9490857 0.4992437 [1] 6.9240124 0.5001761 [1] 6.9404202 0.5010556 [1] 6.8551292 0.5028274 [1] 7.0580443 0.5004227 [1] 7.0226837 0.5025313 [1] 7.0921951 0.5032965 [1] 6.8928034 0.5018612 [1] 6.8957730 0.5000211 [1] 6.8766057 0.5020528 [1] 6.8638582 0.5006472 [1] 6.8295882 0.5037857 [1] 6.9946041 0.5025404 [1] 6.941964 0.500780 [1] 6.9112299 0.5034342 [1] 6.8076352 0.5054724 > > pred <- NNpredict( net=net, + param=netwts$opt, + newdata=d, + newtruth=truth, + record=TRUE, + plot=TRUE) 1000 > > proc.time() user system elapsed 2.09 0.31 2.39