cito: An R package for training neural networks using torch

Deep Neural Networks (DNN) have become a central method for regression and classification tasks. Some packages exist that allow users to fit DNN directly in R, but those are rather limited in their functionality. Most current deep learning applications rely on one of the major deep learning frameworks, in particular PyTorch or TensorFlow, to build and train DNN. Using these frameworks, however, requires substantially more training and time than typical regression or machine learning functions in the R environment. Here, we present 'cito', a user-friendly R package for deep learning that allows users to specify deep neural networks in the familiar formula syntax used in many R packages. To fit the models, 'cito' uses 'torch', taking advantage of the numerically optimized torch library, including the ability to switch between training models on CPUs or GPUs. Moreover, 'cito' includes many user-friendly functions for model plotting and analysis, including optional confidence intervals (CIs) based on bootstraps on predictions as well as explainable AI (xAI) metrics for effect sizes and variable importance with CIs and p-values. To showcase a typical analysis pipeline using 'cito', including its built-in xAI features to explore the trained DNN, we build a species distribution model of the African elephant. We hope that by providing a user-friendly R framework to specify, deploy and interpret deep neural networks, 'cito' will make this interesting model class more accessible to ecological data analysis. A stable version of 'cito' can be installed from the Comprehensive R Archive Network (CRAN).


Introduction
Deep neural networks (DNN) are increasingly used in ecology and evolution for regression and classification tasks such as species distribution models, image classification or sound analysis (Christin et al., 2019; Joseph, 2020; Pichler and Hartig, 2023a; Strydom et al., 2021). State-of-the-art DNN are almost exclusively implemented and trained in specialized deep learning (DL) frameworks such as PyTorch or TensorFlow (Abadi et al., 2016; Paszke et al., 2019). These frameworks, most of which are implemented in Python, provide flexible and performant functions and classes that allow users to implement and train complex DL architectures, such as large language models (e.g., GPT-3 (Brown et al., 2020), RoBERTa (Liu et al., 2019)) or complex object detection models (e.g., Mask R-CNN (He et al., 2017), DeepViT (Zhou et al., 2021)). Their high level of flexibility is appealing to "power users", but the complexity of these frameworks can be prohibitive or at least off-putting for scientists with limited knowledge in the field who merely want to use neural networks in standard applications.
As a response to this problem, several simplified frontends for the major DL frameworks have been developed. Many of those are also available in R, the language used by most ecologists for practical data analysis. Well-known examples are 'Keras' for TensorFlow and 'luz' for 'torch' (Allaire and Chollet, 2022; Falbel, 2022). However, while these frontends indeed simplify the model building process considerably, their general structure and syntax still resemble those of the major Python frameworks rather than those of popular R packages for regression or classification tasks that specify models using the formula syntax, such as 'ranger' for training random forests or 'lme4' for training mixed-effect models (Bates et al., 2015; Wright and Ziegler, 2017). Moreover, DL frontends such as 'Keras' or 'luz' mainly concentrate on model fitting and include only a very limited set of the plots and convenience functions that are common to most R packages. As a result, working with these frontends still requires a considerable amount of training for users who are so far only familiar with standard R packages, especially because users have to choose or program code for downstream tasks such as bootstrapping, plots or explainable AI (xAI) metrics by hand.
Besides the mentioned frontends to the major DL frameworks, some specialized R packages for training DNN exist that more closely adhere to the syntax used in most popular R packages, in particular the formula syntax to specify the model structure. However, those packages often lack crucial functionalities, and most of them do not make use of state-of-the-art DL frameworks for model fitting. This limits their use for large DNN because of their numerical inefficiency or their inability to train the models on GPUs. Established R packages such as 'nnet' or 'neuralnet' (Fritsch et al., 2019; Venables and Ripley, 2002) do not support modern DL techniques, such as regularization techniques (e.g. dropout) to control the bias-variance tradeoff, or modern training techniques such as early stopping or learning rate schedulers that help to achieve convergence. The 'h2o' package comes with its own Java backend, and while it allows specifying models with the standard formula syntax, its use in R is cumbersome due to its inability to work with default R objects (LeDell et al., 2022). The 'brulee' R package (Kuhn and Falbel, 2022), which uses 'torch' to train the DNN specified in standard R syntax, is very similar to the package presented here, but still lacks some critical features (see section 'Performance comparison and validation of cito').
Here, we present 'cito', an R package for training fully-connected neural networks using the standard R formula syntax for model specification. Based on the 'torch' DL framework, 'cito' allows flexible specification of fully-connected neural network architectures, supports many modern DL techniques (e.g. dropout and elastic net regularization, learning rate schedulers), can take advantage of CPU and GPU hardware for parallelization, and, despite its simple user interface, optionally offers a high degree of customization such as user-defined loss functions. Moreover, 'cito' supports many downstream functionalities, such as the possibility to continue the training of existing DNN with modified training parameters for fine-tuning, or the application of explainable AI (xAI) methods to interpret the trained models. As such, 'cito' provides a user-friendly but nevertheless complete analysis pipeline for building neural networks in R.
In the remainder of the paper, we introduce the design principles of 'cito' in more detail, show validation and performance analysis, and showcase the application of cito using the example of a species distribution model of the African elephant.

Design of the cito package
Torch backend
'cito' uses 'torch', a variant of PyTorch, as its backend to represent and train the specified neural networks. Until recently, R users who wanted to use PyTorch and TensorFlow had to call their Python bindings through the 'reticulate' package. R packages that relied on this pipeline were thus dependent on appropriate Python installations (e.g. Pichler and Hartig, 2021), which often created dependency issues. This issue was solved with the release of 'torch', a native implementation of the torch libraries with an R frontend (Falbel and Luraschi, 2022).

Building and training neural networks in cito
With 'torch', R users can essentially use PyTorch natively in R, which solves dependency issues, but not the problem that specifying a DNN with 'torch' is complex.
'cito' addresses this problem by providing one simple command, dnn(), which combines everything needed to build and train a fully-connected neural network in one line of code (see package vignette('A-Introduction_to_cito') for more details). The dnn() function includes options to modify the network architecture, the training process and the monitoring (e.g. by visualization) of the training and validation loss (Table 1), including a baseline loss (based on intercept-only models) that helps to diagnose convergence problems due to inappropriately chosen training hyperparameters (e.g., learning rate and epochs).
The dnn() function returns an S3 object that can be used, for example, with the continue_training() function to continue training for additional epochs (iterations) with the same or modified training hyperparameters or data. Moreover, many standard R functions such as summary(), predict() or residuals() are implemented for the trained models, and additional specialized explainable AI (xAI) functions are available for interpreting the fitted networks. More details on these and other functions are available in the R package vignettes that come with the 'cito' package.
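A minimal sketch of the workflow described above, assuming the 'cito' and 'torch' packages are installed. The simulated data and all object names are purely illustrative, and the argument names follow the package documentation as we understand it, so they should be checked against ?dnn for the installed version. The cito calls are guarded so that the script also runs where the packages are not available.

```r
# Illustrative simulated regression data (not from the paper)
set.seed(42)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- 1.5 * sim$x1 - 0.5 * sim$x2^2 + rnorm(n, sd = 0.3)

# Only run the cito part if cito and the torch backend are available
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)

  # Build and train a fully-connected network from a formula in one call
  fit <- dnn(y ~ x1 + x2, data = sim,
             hidden = c(20L, 20L),  # two hidden layers with 20 nodes each
             loss = "mse",
             lr = 0.01, epochs = 50L,
             validation = 0.2,      # hold out 20% of the rows to monitor validation loss
             device = "cpu",        # "cuda" would request GPU training
             plot = FALSE, verbose = FALSE)

  # Standard S3 methods work on the returned object
  print(summary(fit))
  pred <- predict(fit, newdata = sim)

  # Continue training the same model for additional epochs
  fit <- continue_training(fit, epochs = 20L)
}
```

Setting plot = TRUE (which we understand to be the default) displays the training and validation loss during fitting.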
The lack of uncertainties (standard errors) is an often-raised concern for DNN. In 'cito', we provide an option to automatically calculate confidence intervals for all outputs (including xAI metrics and predictions) using bootstrapping. As bootstrapping can be computationally expensive, this option is disabled by default. Bootstrapping can be enabled via the bootstrap argument of the dnn() function, e.g., dnn(…, bootstrap = 50). Bootstrap standard errors are then automatically propagated through all downstream methods and are also used to generate p-values wherever obvious null hypotheses exist. We recommend starting without bootstrapping to optimize the training procedure (Fig. 2) and then enabling the bootstrap for the final model after the training pipeline has been finalized.
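The idea behind bootstrap-based uncertainties can be sketched in base R. This is a conceptual illustration only: lm() is used as a fast stand-in for the neural network, the data are simulated, and the normal-approximation interval is our assumption. It is not a description of how 'cito' computes its intervals internally.

```r
# Illustrative simulated data (not from the paper)
set.seed(1)
n <- 100
dat <- data.frame(x = runif(n, -2, 2))
dat$y <- 1 + 2 * dat$x + rnorm(n, sd = 0.5)
newdata <- data.frame(x = c(-1, 0, 1))

n_boot <- 50  # number of bootstrap refits, analogous to bootstrap = 50

# Refit the model on resampled data and predict for the same new observations
boot_pred <- replicate(n_boot, {
  idx <- sample.int(n, replace = TRUE)   # resample rows with replacement
  fit <- lm(y ~ x, data = dat[idx, ])    # stand-in for refitting the DNN
  predict(fit, newdata = newdata)
})

# boot_pred has one row per new observation and one column per bootstrap refit
pred_mean <- rowMeans(boot_pred)
pred_se   <- apply(boot_pred, 1, sd)     # bootstrap standard error per prediction

# Approximate 95% interval (normal approximation, an assumption of this sketch)
ci <- cbind(lower = pred_mean - 1.96 * pred_se,
            upper = pred_mean + 1.96 * pred_se)
print(cbind(newdata, mean = pred_mean, se = pred_se, ci))
```

Because every bootstrap replicate requires a full refit, the cost grows linearly with the number of replicates, which is why tuning the training procedure first and bootstrapping only the final model is the cheaper order.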

Performance comparison and validation of cito
After explaining the design of 'cito', we briefly compare its performance and functionality with other packages for implementing neural networks in R. We consider in particular 'nnet' and 'neuralnet', which each have their own backend and are not based on modern DL frameworks (Fritsch et al., 2019; Venables and Ripley, 2002), 'h2o', which possesses a much broader toolkit for training neural networks than the previous two packages (LeDell et al., 2022), and 'brulee' (Kuhn and Falbel, 2022), which, similar to 'cito', uses the 'torch' DL framework as a backend.
Our comparison shows that 'cito' implements more options than other packages, in particular GPU support, the possibility to continue training, custom loss functions and, most importantly, tools to interpret the trained DNN models (Table 2). Looking at computational performance, measured by the time it takes to train the networks, we find that some of the older packages, in particular 'neuralnet', perform better than the torch-based packages (including 'cito') for small networks (Figure 1). This is probably due to the smaller overhead of these more specialized packages. However, when moving to larger networks (large and especially wide networks are often beneficial for achieving low generalization errors (Belkin et al., 2019)), 'cito' can exploit one of the main advantages of modern ML frameworks, which is GPU support. On the GPU, training time in 'cito' is practically independent of the size of the network, confirming the consensus that training large networks requires GPU resources. On a CPU, 'cito' performs on par with 'brulee', the other torch-based package, but somewhat worse than 'neuralnet'. We interpret these results as showing that for a simple problem, there is still some overhead of using 'torch' as opposed to a native C implementation. Nevertheless, we would argue that the added flexibility and functionality of 'cito' outweigh this advantage of 'neuralnet'. Moreover, our results suggest that the difference between the torch packages and 'neuralnet' lies mainly in the constant overhead needed to set up the models. For large models, their performance is roughly equal.

Workflow and case study
So far, we have mainly discussed the process of model training, which is arguably the core of any machine learning project. Now, we want to comment on the entire workflow when using 'cito' to build and interpret a predictive model. This workflow usually consists of model specification, training, and interpretation and predictions (Figure 2). To make the discussion more accessible to the reader, we illustrate this workflow with the example (based on Ryo et al. (2021)) of building a species distribution model (SDM) for the African elephant (Loxodonta africana).
SDMs are niche models that correlate environment with species occurrence data (see Elith & Leathwick, 2009). As occurrence data, we use records of African elephant presence from Ryo et al. (2021), which were based on Angelov (2020), who compiled data from different studies available on GBIF (INaturalist Contributors, 2022a, 2022b; Jlegind, 2021; Musila et al., 2019; Navarro, 2022). Those presence-only data were supplemented by Angelov (2020) with randomly sampled background points (pseudo-absences) to generate a presence-absence signal for the classifier. As predictors, we used all 19 bioclimatic variables from WorldClim v2 (Fourcade et al., 2018), which were centered and standardized. While it is common in statistical modelling to sample more pseudo-absences than presences, such unbalanced class numbers can be harmful for machine learning algorithms. We therefore randomly undersampled pseudo-absences to match the number of presence observations (another option would be to oversample presences, but in our example, this resulted in lower accuracy in interim results).
The trained models can be used with a range of built-in functions. The predict() function can be used to predict the occurrence probability of the elephant (Fig. 3a). The summary() function provides an overview of influential variables by calculating their importances (Fisher et al., 2019) as well as average conditional effects (which are an approximation of linear effects, see Pichler and Hartig, 2023b) (Fig. 4a). Partial dependency plot (PDP) and accumulated local effect plot (ALE) functions can be used to display the effect of specific features on the response, in this case the occurrence probability of the elephant (Fig. 4b, c). If bootstrapping is enabled, 'cito' automatically uses the bootstrap samples to calculate confidence intervals (CI, as standard errors) for the predictions (Fig. 3a), CIs and p-values for the xAI metrics (Fig. 4a), and CIs for the PDP and ALE plots (Fig. 4b, c).
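The data preparation described for the case study (centering and standardizing the predictors, then undersampling pseudo-absences to balance the classes) can be sketched in base R as follows. The data frame, its column names and the class sizes are hypothetical placeholders rather than the elephant data, and the order of the two steps is an assumption of this sketch.

```r
# Hypothetical presence / pseudo-absence data with two predictors
set.seed(123)
occ <- data.frame(
  label = c(rep(1L, 100), rep(0L, 400)),  # 1 = presence, 0 = pseudo-absence
  bio1  = rnorm(500, mean = 20, sd = 5),
  bio12 = rnorm(500, mean = 800, sd = 200)
)

# Center and standardize each predictor (mean 0, standard deviation 1)
pred_cols <- c("bio1", "bio12")
occ[pred_cols] <- lapply(occ[pred_cols], function(v) (v - mean(v)) / sd(v))

# Randomly undersample pseudo-absences to match the number of presences
pres_idx <- which(occ$label == 1L)
abs_idx  <- which(occ$label == 0L)
keep_abs <- sample(abs_idx, size = length(pres_idx))
balanced <- occ[c(pres_idx, keep_abs), ]

# Both classes now have the same number of rows
print(table(balanced$label))
```

Undersampling discards part of the background information, so the alternative of oversampling presences mentioned above may be worth comparing on a given data set.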

Conclusion
'cito' is a powerful and versatile R package for building and training fully-connected neural networks with a formula syntax. The package seamlessly integrates into the R regression ecosystem and removes many hurdles in using neural networks for inexperienced users, but also saves programming time for experienced users who just want to build simple neural networks. The unique combination of features provided by 'cito', such as training on a GPU, custom loss functions, a baseline loss, confidence intervals, and modern DL training techniques such as continued training, learning rate schedulers or early stopping, cannot be found in other packages. Future releases of 'cito' aim to implement additional functionalities such as internal cross-validation for hyperparameter optimization, gradient-based methods for hyperparameter tuning and the integration of recurrent and convolutional neural networks.

Figure 1: Runtime comparison of different deep learning R software packages ('brulee', 'h2o', 'neuralnet', and 'cito' (CPU and GPU)) on different network sizes on an Intel Xeon 6128 and an NVIDIA RTX 2080 Ti. The networks consisted of five equally sized layers (50 to 1000 nodes with a step size of 50) and were trained on a simulated data set with 1000 observations. Panel (A) shows the runtime of the different packages and panel (B) shows the average root mean square error (RMSE) of the models on a holdout of 1000 observations (RMSE was averaged over different network sizes). Each network was trained 20 times (the dataset was resampled each time).

Figure 3: Predictions and standard errors of prediction for the African elephant from a DNN trained by 'cito'. Panel (A) shows the predicted probability of occurrence of the African elephant. Panel (B) shows the standard error for the predicted probabilities (confidence interval).

Figure 4: xAI metrics with bootstrap confidence intervals (+/- 1 se) from a model trained by 'cito'. Panel (A) shows (permutation) feature importances and average conditional effects (approximation of linear effects) from the summary() output for the 19 Bioclim variables. Panels (B) and (C) show the accumulated local effect plots (ALE), i.e., the change of the predicted occurrence probability, for the Bioclim variables 3 (Isothermality) and 4 (Temperature Seasonality).

Table 1: Hyperparameters for fully-connected neural networks and their default values in 'cito'. Defaults for all parameters are set to sensible values; however, some parameters typically need to be tuned. Detailed guidance on this is provided in the help file of the dnn() function or in the cito R package vignette 'Training neural networks'.

Table 2: Feature comparison of R packages used to build fully-connected neural networks.
The plot of the training and validation loss can be used to diagnose convergence problems, for example if the training loss does not decrease over time or does not fall below the baseline loss. In this case, it would be advisable to abort and restart the training with different hyperparameters (e.g., a smaller learning rate), use a learning rate scheduler, or perform a systematic hyperparameter tuning. We provide extensive help on this topic in the documentation and in a vignette (vignette('B-Training_neural_networks')). Here we show an example where we restart the training with a smaller learning rate and a learning rate scheduler that automatically reduces the learning rate if the loss does not decrease for 8 consecutive epochs (patience = 8) to achieve a better fit:
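A hedged sketch of such a restart, assuming the 'cito' and 'torch' packages are installed. The simulated presence-absence data, the learning rate, the number of epochs and the reduction factor are illustrative choices rather than the values used in the case study. The call config_lr_scheduler("reduce_on_plateau", ...) follows the package documentation as we understand it and should be checked against ?config_lr_scheduler.

```r
# Illustrative simulated presence-absence data (not the elephant data)
set.seed(7)
n <- 300
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$label <- rbinom(n, size = 1, prob = plogis(2 * sim$x1 - sim$x2))

# Only run the cito part if cito and the torch backend are available
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)

  # Reduce the learning rate when the loss has not improved for 8 epochs;
  # factor = 0.5 (halving the learning rate) is an illustrative choice
  scheduler <- config_lr_scheduler("reduce_on_plateau",
                                   patience = 8, factor = 0.5)

  # Restart training with a smaller learning rate and the scheduler
  fit <- dnn(label ~ x1 + x2, data = sim,
             loss = "binomial",
             lr = 0.01,             # smaller than in a hypothetical failed first run
             epochs = 100L,
             validation = 0.2,      # monitor validation loss on 20% of the rows
             lr_scheduler = scheduler,
             plot = FALSE, verbose = FALSE)
}
```

With plot = TRUE, the loss curves of the restarted run can be compared against the baseline loss to check whether the changed hyperparameters resolved the convergence problem.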