Attention module improves both performance and interpretability of four‐dimensional functional magnetic resonance imaging decoding neural network

Abstract Decoding brain cognitive states from neuroimaging signals is an important topic in neuroscience. In recent years, deep neural networks (DNNs) have been recruited for multiple brain state decoding and achieved good performance. However, the open question of how to interpret the DNN black box remains unanswered. Capitalizing on advances in machine learning, we integrated attention modules into brain decoders to facilitate an in‐depth interpretation of DNN channels. A four‐dimensional (4D) convolution operation was also included to extract temporo‐spatial interactions within the fMRI signal. The experiments showed that the proposed model obtains very high accuracy (97.4%) and outperforms previous studies on seven task benchmarks from the Human Connectome Project (HCP) dataset. The visualization analysis further illustrated the hierarchical emergence of task‐specific masks with depth. Finally, the model was retrained to regress individual traits within the HCP and to classify viewed images from the BOLD5000 dataset, respectively. Transfer learning also achieved good performance. Further visualization analysis showed that, after transfer learning, low‐level attention masks remained similar to the source domain, whereas high‐level attention masks changed adaptively. In conclusion, the proposed 4D model with attention modules performed well and facilitated the interpretation of DNNs, which is helpful for subsequent research.


| INTRODUCTION
For many years, decoding the brain's activities has been one of the major topics in neuroscience. Inferring brain states consists of predicting the tasks subjects performed and identifying brain regions related to specific cognitive functions (Friston et al., 1994; Lv et al., 2015; McKeown et al., 1998; Norman, Polyn, Detre, & Haxby, 2006). Deep learning (DL) methods based on a variety of artificial neural networks have gained considerable attention in the scientific community for more than a decade, breaking benchmark records in several domains, including vision, speech, and natural language processing (Krizhevsky, Sutskever, & Hinton, 2017; LeCun, Bengio, & Hinton, 2015). In this context, deep neural networks (DNNs), especially convolutional neural networks (CNNs), have been recruited for brain decoding (Li & Fan, 2018; Yin, Li, & Wu, 2020; Zhang, Tetrel, Thirion, & Bellec, 2021), and have achieved high accuracy (>90%) in multiple brain state decoding (Nguyen, Ng, Kaplan, & Ray, 2020; X. Wang et al., 2020). It is important to note, however, that several open challenges still need to be addressed when using deep learning to investigate functional magnetic resonance imaging (fMRI) data.
The first challenge is the abstraction of complex temporo-spatial features within the fMRI time series. An fMRI time series is four-dimensional (4D) data that consist of three-dimensional (3D) spatial and one-dimensional (1D) temporal information, meaning that brain regions engage and disengage over time during coherent cognitive activity (Chen, Kreutz-Delgado, Sereno, & Huang, 2019; Shine et al., 2016).
Inspired by this, Mao et al. (2019) developed a model of 3D CNN stacks and a long short-term memory (LSTM) network for spatial and temporal feature abstraction, respectively. A more principled approach would be to jointly leverage the inherent spatial-temporal information in fMRI data (Ismail Fawaz, Forestier, Weber, Idoumghar, & Muller, 2019). However, designing and optimizing architectures for 4D fMRI decoding is difficult due to the lack of systematic comparisons of various spatiotemporal processing approaches and the substantial explosion of computational and memory requirements.
The second challenge is the researchers' requirement for a higher degree of accountability of the model, which is the core of the feasibility and reproducibility of brain decoding (Lindsay, 2020). Deep learning is regarded as a black-box model, and recent efforts have been made to develop an interpretable brain decoding model through feature ranking (Li & Fan, 2019), visualizing the convolutional kernels (Vu, Kim, Jung, & Lee, 2020), guided back-propagation (X. Wang et al., 2020), and so on. Improved DNN interpretability in fMRI analysis could lead to more accountable usage, better algorithm maintenance and improvement, and more open science (Tjoa & Guan, 2021).
Another challenge is the conflict between the DNNs' requirement for large amounts of data and the relatively modest quantity of datasets in typical cognitive research (Yotsutsuji, Lei, & Akama, 2021).
Most fMRI experiments comprise tens to hundreds of participants due to experimental costs or participant selection. It is natural to use transfer learning to alleviate the data scarcity problem in the target domain (e.g., small sample datasets) by utilizing the knowledge acquired in the source domain (e.g., large cohorts; Gao, Zhang, Wang, Guo, & Zhang, 2019; Svanera et al., 2019; Thomas, Müller, & Samek, 2019; X. Wang et al., 2020). However, fMRI data vary across datasets (e.g., scanner, scanning parameters, task design, template space), so it remains an open question how far a DNN can transfer-learn in fMRI.
Inspired by these challenges, the main contributions of this article are threefold. First, we extended the problem of temporal modeling and spatial feature extraction to a 4D convolution module and compared various approaches to fMRI data processing. Second, we employed mixed attention modules to improve decoding performance, which not only enhanced the ability to distinguish and focus on specific features but also presented an in-depth interpretation of the CNN. Third, we explored the benefits of transfer learning in fMRI analysis under different problem definitions and task designs, demonstrating that a model that captures cognitive similarities can be extended to distinguish individual trait differences.

| Human Connectome Project dataset
The minimally preprocessed 3T data from the S1200 release of the Human Connectome Project (HCP; Glasser et al., 2013) were used in this research. The present study included task fMRI of 1,034 subjects during seven tasks: emotion, gambling, language, motor, relational, social, and working memory (WM). The seven tasks, which lasted for about 20-30 frames under different conditions during each block, provided a high degree of brain activation coverage (Barch et al., 2013). Thus, the parameter estimates of the model trained on this dataset contained similarities to multiple cognitive domains and were utilized as the source domain in the transfer learning experiment. The HCP S1200 dataset has been preprocessed with the HCP functional pipeline and normalized to the Montreal Neurological Institute's (MNI) 152 space. Following previous studies (X. Wang et al., 2020), only one condition was selected for each task (Table 1), resulting in 14,821 fMRI 4D instances across all subjects and tasks. To save computing memory, a bounding box with the size of [80,96,88] voxels was applied to each fMRI volume, and the blank parts that did not contain brain tissue were cropped out.
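The cropping step above can be sketched as follows. This is a minimal illustration only: the paper does not give the exact bounding-box coordinates, so a center crop is assumed, and the input shape is hypothetical.

```python
import numpy as np

def crop_to_bounding_box(volume, target=(80, 96, 88)):
    """Center-crop the last three (spatial) axes of a 3D volume or a
    4D (t, x, y, z) series to the target size, discarding blank
    margins that do not contain brain tissue. The center crop is an
    assumption; the paper only states the final [80, 96, 88] size."""
    slices = [slice(None)] * (volume.ndim - 3)      # keep leading axes
    for axis, size in zip(range(volume.ndim - 3, volume.ndim), target):
        start = max((volume.shape[axis] - size) // 2, 0)
        slices.append(slice(start, start + size))
    return volume[tuple(slices)]

# hypothetical padded grid: 20 frames of an MNI-space volume
vol = np.zeros((20, 96, 112, 96))
cropped = crop_to_bounding_box(vol)
print(cropped.shape)  # (20, 80, 96, 88)
```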

| BOLD5000 dataset
The BOLD5000 dataset (Chang et al., 2019) was also used for transfer learning of the proposed model. The dataset employed an event-related design to investigate visual perception, collecting fMRI data from four participants while they viewed 5,000 real-world images. Each image was presented for 1 s and followed by a 9 s blank screen with a fixation cross; thus, a single trial lasted five frames (repetition time, TR = 2 s). Two conditions of stimulus images were employed in this study: Scene, containing whole scenes, and ImageNet, focusing on a single object. Implicit image attributes can provide category selectivity in high-level visual regions. Using fMRIPrep (Esteban et al., 2017), the fMRI data were preprocessed with motion correction, distortion correction, and co-registration to the corresponding T1w image. Each volume was then cropped to the size of [80,96,88] voxels, and each segmented fMRI input covered the entire trial plus two extra TRs extended forward and backward.
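The trial segmentation described above (5-frame trial padded by two TRs on each side) can be sketched as follows; the function name, onset bookkeeping, and run shape are illustrative assumptions, not the paper's code.

```python
import numpy as np

def segment_trial(run, onset_frame, trial_len=5, pad=2):
    """Cut one event-related trial (trial_len frames at TR = 2 s) out
    of a preprocessed run, extending it by `pad` extra TRs forward and
    backward, clipped to the run boundaries."""
    start = max(onset_frame - pad, 0)
    stop = min(onset_frame + trial_len + pad, run.shape[0])
    return run[start:stop]

run = np.random.rand(100, 80, 96, 88)   # hypothetical cropped run
seg = segment_trial(run, onset_frame=30)
print(seg.shape[0])  # 9 frames: 5 trial frames + 2 before + 2 after
```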

| The proposed neural network
The proposed model consists of a 4D convolution layer and four 3D attention modules, followed by a fully-connected layer (Figure 1a).

| 4D convolution
The 4D convolution kernel K ∈ ℝ^(k_l × k_h × k_w × k_d × k_c) was applied to the input x ∈ ℝ^(l × h × w × d × c), where l is the temporal length, h the height, w the width, d the depth, and c the number of channels. The 4D convolution operation, Conv4D, was implemented by two loops over the native 3D convolution operation, Conv3D, of PyTorch (Paszke et al., 2019):

y_t = Σ_{τ = 0}^{k_l − 1} Conv3D(x_{t·s_t + τ}, K_τ),

where s_t is the temporal stride (s_t = 1, 2, …) and Conv3D employed 3D convolution with a spatial stride of s = 2. A stride >1 down-samples the designated dimension. After the 4D convolution, the temporal dimension was squeezed and flattened into the channel dimension of the subsequent 3D attention module.
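The two-loop construction can be sketched in PyTorch as follows; this is a minimal, unoptimized illustration with toy shapes and a random kernel (the function name and all sizes are ours, not the paper's released code).

```python
import torch
import torch.nn.functional as F

def conv4d(x, kernel, s_t=1, s=2):
    """Naive 4D convolution built from PyTorch's native conv3d:
    an outer loop slides over the temporal dimension with stride s_t,
    and an inner loop sums the k_l temporal taps of the kernel.
    x:      (batch, c_in, l, d, h, w)
    kernel: (c_out, c_in, k_l, k_d, k_h, k_w)
    Each conv3d uses a spatial stride of s (2 in the paper)."""
    l = x.shape[2]
    k_l = kernel.shape[2]
    outs = []
    for t in range(0, l - k_l + 1, s_t):        # slide over time
        acc = 0
        for tau in range(k_l):                  # temporal kernel taps
            acc = acc + F.conv3d(x[:, :, t + tau], kernel[:, :, tau],
                                 stride=s)
        outs.append(acc)
    return torch.stack(outs, dim=2)             # (b, c_out, l', d', h', w')

x = torch.randn(1, 1, 7, 16, 16, 16)            # toy 4D fMRI input
k = torch.randn(4, 1, 3, 3, 3, 3)               # random 4D kernel
y = conv4d(x, k)
print(y.shape)  # torch.Size([1, 4, 5, 7, 7, 7])
```

With spatial stride 2, each 16-voxel axis shrinks to floor((16 − 3)/2) + 1 = 7, while the temporal axis of length 7 yields 7 − 3 + 1 = 5 output frames.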

| The attention module
The attention mechanism in a DNN selects focused regions and thus enhances the discriminative representation of objects (Vaswani et al., 2017). The attention module is also beneficial for optimization, serving as a gradient update filter that suppresses noisy gradients. However, a naive dot product of the two branches degrades the value of the features. Attention residual learning eases this problem by constructing the attention branch as an identical mapping. Formally, the output of the attention module x_{i+1}, serving as the input of the next layer, is modified as:

x_{i+1} = (1 + A(x_i)) ⊙ F(x_i),

where F(x_i) is the output of the main branch and A(x_i) ∈ (0, 1) is the soft attention mask. The attention mask branch can thus be viewed as an identical mapping that changes adaptively as layers go deeper. What the neural network learns at each level can be demonstrated by the distribution of attention. The attention masks of each channel were visualized to present an in-depth interpretation of the network by upsampling the feature map corresponding to A(x) and mapping it onto the T1w image.
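The attention-residual formulation can be sketched as a small PyTorch module. This is a minimal illustration of output = (1 + A(x)) ⊙ F(x) only; the branch depths and layer widths are placeholder choices, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class AttentionModule3D(nn.Module):
    """Sketch of a 3D attention block with attention residual learning:
    output = (1 + A(x)) * F(x), where A(x) in (0, 1) is the soft mask
    branch and F(x) the main (trunk) branch. Single-conv branches are
    illustrative assumptions."""
    def __init__(self, channels):
        super().__init__()
        self.trunk = nn.Sequential(              # main branch F(x)
            nn.Conv3d(channels, channels, 3, padding=1),
            nn.ReLU(inplace=True))
        self.mask = nn.Sequential(               # attention branch A(x)
            nn.Conv3d(channels, channels, 3, padding=1),
            nn.Sigmoid())                        # mask weights in (0, 1)

    def forward(self, x):
        # (1 + A(x)) keeps the trunk signal even where the mask -> 0,
        # so the mask cannot zero out useful features or gradients.
        return (1 + self.mask(x)) * self.trunk(x)

att = AttentionModule3D(8)
y = att(torch.randn(2, 8, 10, 12, 11))
print(y.shape)  # torch.Size([2, 8, 10, 12, 11])
```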

| Training and evaluation
The implementation of the different model variants is based on the PyTorch framework. Training was performed on an NVIDIA GTX 1080Ti graphics card. To conduct a fair comparison, the batch size was set to 16 and each model was trained for 60 epochs using the Adam algorithm with the standard parameters (β 1 = 0.9 and β 2 = 0.999). The learning rate was initialized at 0.0001 and decayed by a factor of 5 when the validation loss plateaued after 15 epochs. The loss converged well and overfitting was not observed during validation experiments. Our validation strategy employed fivefold cross-validation across subjects, and the dataset was divided into a training set (70%), a validation set (10%), and a testing set (20%). Control experiments were conducted on various model variants (Table 2) to verify whether the 4D convolution and attention modules brought a substantial improvement. We also analyzed a set of 4DResNets with different sizes of 4D kernels and present comparison results using different numbers of frames as input. A segment of k continuous frames, randomly split from each instance, was used as input for training. During testing, the predictions for all segments of one instance were summed, and the task label with the majority vote was taken as the final class of the instance.
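The segment-level aggregation at test time can be sketched as follows; the function name and the toy scores are ours, assuming per-segment class scores are summed before the argmax, as described above.

```python
import numpy as np

def majority_vote(segment_logits):
    """Aggregate per-segment predictions for one fMRI instance:
    sum the class scores over all segments and return the index of
    the winning task label."""
    return int(np.argmax(np.sum(segment_logits, axis=0)))

# hypothetical scores for 3 segments of one instance over 7 HCP tasks
logits = np.array([[0.1, 2.0, 0.3, 0.0, 0.1, 0.2, 0.1],
                   [0.2, 1.5, 0.9, 0.0, 0.3, 0.1, 0.2],
                   [0.4, 0.3, 2.1, 0.1, 0.2, 0.2, 0.1]])
print(majority_vote(logits))  # 1 (largest summed column score)
```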

| Transfer learning
Transfer learning describes a process in which a network is trained on a source dataset and its learned parameters are subsequently reused on a target dataset. The key idea of this workflow is similar to that mentioned above. We fine-tuned the model to decode the binary type of stimulus images (scene vs. object) seen by subjects and employed leave-one-subject-out (LOSO) cross-validation, meaning that the data from three subjects were used for training and one for testing.
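The LOSO scheme for the four BOLD5000 participants can be sketched as follows (subject IDs are hypothetical placeholders):

```python
def loso_splits(subjects):
    """Leave-one-subject-out folds: each subject is held out once for
    testing while the remaining subjects are used for fine-tuning."""
    for i, test_subj in enumerate(subjects):
        train = subjects[:i] + subjects[i + 1:]
        yield train, test_subj

subs = ["sub-1", "sub-2", "sub-3", "sub-4"]   # hypothetical IDs
for train, test in loso_splits(subs):
    print(train, "->", test)   # e.g. ['sub-2', 'sub-3', 'sub-4'] -> sub-1
```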

| Performance evaluation on HCP dataset
The performance of various models was compared by the mean and SD of accuracy (Table 2). All of the proposed models effectively distinguished the seven tasks, with the 4DResNet-Att outperforming the others at an accuracy of 97.4% ± 0.4% (mean ± SD). Figure 2a shows the decoding performance of 4DResNet-Att on the seven cognitive tasks, and the confusion matrix shows a clear block-diagonal structure. The cognitive tasks were accurately identified, with accuracies of: emotion (96.2 ± 0.2%), gambling (99.4 ± 0.3%), language (98.7 ± 0.4%), motor (96.0 ± 0.4%), relational (93.6 ± 0.9%), social (99.4 ± 0.3%), and WM (98.9 ± 0.4%). Furthermore, the confusion matrix showed misclassifications between the relational and gambling, emotion and gambling, motor and gambling, and relational and WM tasks.
The superior performance of the 4DResNet-Att model in comparison to the 3DResNet (X. Wang et al., 2020) and other recent studies is possibly due to its capability to handle complex spatiotemporal dynamics in fMRI series via 4D convolution operations and its use of the attention mechanism to adaptively select focused locations.
Specifically, the 4DResNet is able to capture dynamic changes in the hemodynamic response along the temporal dimension and to integrate these representations from interconnected brain regions along the spatial dimensions. To evaluate whether the 4D CNN brings a substantial improvement over the 3D CNN, the 4DResNet-Att model was compared with the 3DResNet-Att model on the same brain decoding tasks using different lengths of frames as input (Figure 2b). Overall, the 4DResNet substantially enhanced classification performance compared to the 3DResNet, except in the 7-frame condition. The lower performance for shorter fMRI inputs could be caused by two factors: (1) there is little information in a short input, especially in series shorter than a hemodynamic response; and (2) the 4DResNet tends to measure relative dynamic change over a long range. In addition, we evaluated a set of 4DResNets with different sizes of 4D kernels to decode brain activity. Our results revealed that decoders with a short 4D kernel achieved lower decoding performance than decoders using a relatively longer 4D kernel (Figure 2c).
Furthermore, to establish whether the use of attention mechanisms could enhance fMRI decoding, we compared the 4DResNet with attention modules against the naive 4DResNet (Table 2).

| Visualization of attention mask on the HCP dataset
Previous studies have employed various visualization techniques to build interpretable brain decoding models in fMRI analysis (Vu et al., 2020; X. Wang et al., 2020; Yin et al., 2020). Here, we visualized the focused regions of the attention module in each convolution layer to present an in-depth interpretation of the DNN. Each channel produced seven attention masks for the different tasks, which were averaged across all of the input samples from all of the subjects.
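The mask-visualization step (upsampling A(x) back to the input grid so it can be overlaid on the T1w anatomy, as described in the attention-module section) can be sketched as follows; the function name and toy mask size are ours.

```python
import torch
import torch.nn.functional as F

def upsample_mask(mask, target_shape=(80, 96, 88)):
    """Trilinearly upsample one channel's low-resolution attention
    mask A(x) back to the cropped input grid for overlay on T1w."""
    m = mask[None, None]                    # add batch + channel dims
    up = F.interpolate(m, size=target_shape, mode="trilinear",
                       align_corners=False)
    return up[0, 0]                         # drop the added dims

mask = torch.rand(10, 12, 11)               # toy deep-stage mask
vol = upsample_mask(mask)
print(vol.shape)  # torch.Size([80, 96, 88])
```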
Overall, the resulting attention masks at the low level (first and second stages) have excellent coverage of the brain: they tend to highlight areas containing useful BOLD signal, such as the whole brain structure (Figure 3a), and to diminish noisy areas such as the brainstem and cerebrospinal fluid (Figure S1b,c). The masks also focused on several functional networks and cortical areas related to different cognitive functions (Figure S1), such as the default mode network, sensorimotor network, temporal lobe, and occipital lobe. The enhancement of gray matter areas helped to preserve important features that could be further refined to distinguish between different cognitive states at the high level.
The attention masks at the high level (third and fourth stages) become more focused, covering task-specific brain areas (Figure 3c). Notably, the focused layouts of the attention masks varied across tasks and were remarkably task-specific. A channel could generate specific focused regions for different tasks, such as the left motor cortex in the motor task; the ventral lateral prefrontal cortex and both superior and inferior temporal cortex in the language task; the prefrontal cortex in the relational task; and the temporal parietal junction and superior temporal cortex in the social task (Figures S2 and S3). At the fourth stage, the attention masks become more abstract due to the stride in the convolution operation (Figure 3d), and the attention weights have a narrower range, which could be because the masks also serve as gradient update filters: a small range of attention weights in a high-level feature map could prevent gradient problems.

| Transfer learning
Two different approaches were used to explore the benefits of transfer learning in fMRI analysis under different problem definitions or task designs. By contrast, the initial model, which used the same architecture and was trained from scratch with random weights, achieved a lower correlation coefficient in prediction (r_s = .306, p < .001).
The comparisons of predictions between the different models are shown in Table 3. Furthermore, the visualization analysis shows that the low-level attention masks remained distributed similarly to the source domain, whereas the high-level attention masks changed adaptively as knowledge transferred from group similarities to individual differences (Figure 4b).
Second, the pretrained model from the HCP dataset was fine-tuned to decode different types of stimulus images on BOLD5000.
The knowledge learned from the source domain is highly applicable to the target domain, and the transferred model achieved 77.6 ± 3.4% (4DResNet-Att), 73.5 ± 2.1% (4DResNet), and 64.3 ± 3.8% (3DResNet-Att) accuracy. However, all initial models trained from scratch failed to converge to satisfactory accuracy (<60%) across a wide range of hyper-parameter choices. Furthermore, the visualizations demonstrated that the attention masks changed adaptively to fit individual subjects' brain structures, despite the fact that the fMRI data were registered to the corresponding T1w space rather than the standard MNI152 space (Figure 5). As the model was fine-tuned to decode visual tasks, the attention masks at the high levels also changed adaptively to reweight task-related brain regions. Previous work extracted spatial features with 3D CNNs and passed these latent features to an LSTM network to account for the temporal dependencies within task-evoked brain activity (e.g., Mao et al., 2019). In contrast, the model we proposed includes a 4D convolution layer to detect temporo-spatial features and puts the features into the channel dimension of the following 3D layers to reduce memory consumption.
The above results suggest that the proposed model achieves a good balance between decoding performance and memory consumption.

| Attention module and interpretation of networks
The attention mechanism helps humans focus on the most useful information during perception. Inspired by this, attention mechanisms have been studied extensively in many deep learning fields (Vaswani et al., 2017; F. Wang et al., 2017; Woo et al., 2018). In this research, the proposed 3D mixed attention module consisted of a main branch and an attention branch and considered both channel and spatial features. The experimental results demonstrate that the attention modules have several advantages. For example, the architecture with attention modules converged faster and more easily during training and achieved better performance, which could be because the attention mechanism reweights the focused areas to enhance discriminative features. The attention module is also beneficial for optimization during back-propagation, serving as a gradient update filter that suppresses noisy gradients and enhances gradients from important regions.
Moreover, the attention modules not only improve decoding performance but also serve as a visualization tool to investigate how neural networks work in fMRI decoding. Cognitive neuroscience research requires a higher degree of accountability, while an end-to-end trainable network has long been regarded as a black box in neuroscience. Presenting an in-depth interpretation of a method can demonstrate the feasibility and reproducibility of fMRI studies (Li & Fan, 2019; Vu et al., 2020). A good visual explanation should not only be treated as a localization method but should also allow researchers to investigate how the neural network works. Our analysis shows that the low-level masks provide excellent coverage of the brain, highlighting useful structures while pruning noisy areas. As the layers go deeper, the attention masks become finer, covering various specific cortices. The high-level attention masks varied across tasks, re-weighting attention toward areas related to the specific target task. In addition, the attention masks adapted to fit different subjects' brain structures, which suggests that our architecture could be a suitable approach for handling individual variability across subjects in raw or minimally preprocessed fMRI series without spatial normalization. Finally, the attention areas, which could present biologically meaningful interpretations for cognitive neuroscience, demonstrated that the proposed CNN decoded states from task-related activations rather than from nuisance variables.

| Transfer learning
Transferability has been demonstrated to be a significant advantage of DL methods over traditional methods in fMRI decoding (Gao et al., 2019; X. Wang et al., 2020). To this end, we explored the benefits of transfer learning under various conditions. The transferred regression model yielded significant predictions of individual trait differences and achieved a better Spearman's correlation coefficient than a previous study (Greene et al., 2018). This could be because the previous study relied on the discriminative power of feature selection, and not all connectivity parameters are relevant for prediction, whereas the transferred model could automatically capture the full range of individual trait differences. This also suggests that the group cognitive similarities among intrinsic brain states can generally be reused to predict individual differences, which is important for precision medicine in clinical research. Furthermore, previous studies most commonly applied transfer learning between block-design datasets. On BOLD5000, the pretrained model from the HCP dataset was fine-tuned to decode different visual tasks and obtained 77.6% accuracy. Although the model was trained on a block-design dataset, the internal properties of human hemodynamic responses contained in its parameters are consistent and could be reused on an event-related dataset.

| Limitations and future applications
In this project, the proposed model outperformed other architectures.
Although the 4D convolution processes dynamic changes more effectively, some limitations remain, such as a substantial increase in computational and memory requirements. Moreover, we chose only one condition for each cognitive domain to be comparable with previous studies, while the BOLD signals might be a mixture of hemodynamic responses evoked by different task events. A decoding model with finer cognitive granularity would generalize similarities and differences among task-induced brain states from multiple cognitive domains, which is important for transfer learning. The visualization results demonstrated that the high decoding performance was driven by the response of biologically meaningful brain regions. However, the statistical properties of the attention mask remain unclear; our analysis is qualitative, and caution is warranted until its reliability and statistical properties are investigated further. The transfer learning method, which successfully extended similarities in brain activity to individual differences, shows potential for research in psychiatry and neurology. A pretrained model based on cognitive state could serve as a brain information retrieval system to distinguish differences in neurological diseases and to classify different psychiatric categories.

| CONCLUSION
In this study, we designed a 4DResNet with attention modules for brain decoding. After investigating the efficacy of several alternative classifiers, the proposed 4DResNet-Att achieved 97.4% accuracy on the HCP dataset. We further demonstrated the model's transferability to a variety of tasks and datasets and presented an in-depth interpretation of the network. The visualization analysis of attention distributions illustrated the hierarchical emergence of task-specific masks with depth. After transfer learning, the adaptively changed attention distributions demonstrated that the representation could be generally extended from cognitive similarities to individual differences.