Introduction
The study by Sun et al. (“Predicting Human Brain States with Transformer”) advances the exploration of whether future human brain states—ascertained via functional magnetic resonance imaging (fMRI)—can be accurately forecast using deep-learning models founded on transformer architectures. The guiding principle behind this endeavor is that the human brain, despite its complexity and ceaseless dynamism, has underlying temporal patterns in its resting-state neural activity that might be learned via self-attention mechanisms.
Human neural mechanisms remain, to a considerable extent, enigmatic. Cognition, affect, language, and memory all emerge from an intricate web of neuronal and synaptic interactions. Many brain disorders, such as Alzheimer’s disease and schizophrenia, are closely linked to dysfunctional patterns of neural connectivity and dynamics, making functional connectivity (FC) critical for diagnostic and prognostic models. Resting-state fMRI (rs-fMRI) captures these network-level phenomena without requiring explicit task performance, offering a non-invasive avenue for studying the underlying neural substrates.
Sun et al. question whether the near-future fMRI time points—brain states that occur seconds downstream—can be systematically predicted from previously observed activity. This question has immediate ramifications. If successful, it could eventually reduce the scanning time required for clinical assessments, particularly in populations for whom extended scanning can be stressful or unfeasible. Furthermore, it might elucidate new directions for real-time brain-computer interfaces (BCI), as dynamic brain-state forecasts could empower more robust interactions between human neural activity and technological systems.
Within this broader context, the authors adopt the transformer architecture, notable for its multi-headed self-attention, to tackle the challenge of brain-state prediction. Transformers have reshaped sequential data analysis in language modeling, music generation, and computer vision. Their capacity to capture long-range dependencies makes them a particularly appealing candidate for modeling time series of neural data, where the state at a given moment can be related to events many seconds prior.
While fMRI data, especially from extensive cohorts such as the Human Connectome Project (HCP), provide large-scale coverage of the brain’s activity, the inherent complexity and noise of brain signals can confound naive models. However, by leveraging an auto-regressive formulation of the transformer architecture, the authors demonstrate that transformers can, indeed, generate meaningful predictions of brain states. They show, for instance, that one can accurately predict roughly 5 seconds of future brain activity using only 21.6 seconds of past data. Moreover, their results indicate that these predicted sequences preserve fundamental network-level structures akin to the typical human functional connectome.
In what follows, we dissect the data and preprocessing pipeline, the specifics of the transformer implementation, the training and validation protocols, and the final results demonstrating the feasibility of fMRI-based brain-state forecasting. Links and sources—particularly the authors’ GitHub repository—are included to enable direct access to their code and data processing scripts.

Data and Preprocessing
Data Source
Sun et al. relied on the well-known Human Connectome Project (HCP) young adult dataset, which aggregates high-quality imaging data (among other modalities) from over one thousand healthy participants in early adulthood. The authors used the 3 Tesla rs-fMRI data from 1,003 healthy young adults after excluding 110 participants with missing or incomplete resting-state sessions. Each participant had four resting-state fMRI scans of 1,200 time points each, acquired with a temporal resolution (TR) of 0.72 seconds and an isotropic spatial resolution of 2 mm.
Further details on HCP protocols can be found at the Human Connectome Project website: https://www.humanconnectome.org/study/hcp-young-adult
Preprocessing Steps
Though the HCP dataset already includes “minimally preprocessed” versions of the data, the authors applied additional preprocessing measures:
- Spatial Smoothing
A Gaussian filter with a 6 mm full width at half maximum (FWHM) was applied in CIFTI space to boost the signal-to-noise ratio (SNR) and attenuate higher-frequency artifacts. Spatial smoothing is commonly applied to highlight coherent activations in contiguous areas and reduce voxel-level noise.
- Bandpass Filtering
The signals were bandpass-filtered between 0.01 and 0.1 Hz. This frequency range is often targeted because resting-state neural fluctuations (including slow spontaneous dynamics) lie largely in this band.
- Z-Score Transformation
A z-score normalization was performed for each time series, forcing the mean to 0 and the standard deviation to 1, thereby ensuring that all participants’ data align to a consistent scale.
- Regional Parcellation
A multi-modal parcellation (MMP) atlas subdivides the gray matter into 360 cortical and 19 subcortical areas, yielding 379 total regions of interest (ROIs). For each time point, the authors computed the average BOLD intensity in each ROI. Consequently, every “brain state” is a 379-dimensional vector capturing the BOLD signal across the entire brain.
Thus, each subject’s resting-state session became a time series of 1,200 vectors, where each vector encodes 379 region-specific signal intensities.
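The filtering and normalization steps above can be sketched with NumPy/SciPy on synthetic data. This is an illustration, not the authors' actual pipeline: the function name, the Butterworth filter, and the filter order are our assumptions; the real scripts live in their repository.

```python
import numpy as np
from scipy.signal import butter, filtfilt

TR = 0.72          # HCP repetition time in seconds
FS = 1.0 / TR      # sampling frequency (~1.39 Hz)

def bandpass_zscore(ts, low=0.01, high=0.1, order=4):
    """Bandpass-filter (0.01-0.1 Hz) and z-score each ROI time series.

    ts: array of shape (n_timepoints, n_rois)
    """
    nyq = FS / 2.0
    b, a = butter(order, [low / nyq, high / nyq], btype="band")
    filtered = filtfilt(b, a, ts, axis=0)   # zero-phase filtering along time
    # Z-score each ROI so that mean = 0 and std = 1
    return (filtered - filtered.mean(axis=0)) / filtered.std(axis=0)

# Toy example: 1,200 time points x 379 ROIs of synthetic data
rng = np.random.default_rng(0)
data = rng.standard_normal((1200, 379))
clean = bandpass_zscore(data)
print(clean.shape)  # (1200, 379)
```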
Model and Method
Auto-Regressive Framing
The central problem is conceived in an auto-regressive fashion: if $\mathbf{x}_t \in \mathbb{R}^{379}$ denotes the brain state at time $t$, then given a historical window of brain states $\mathbf{x}_{t-n}, \mathbf{x}_{t-n+1}, \ldots, \mathbf{x}_{t-1}$, the task is to predict $\mathbf{x}_t$. By iterating single-step predictions, one can generate an entire synthetic sequence of future brain states.
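In practice, this framing amounts to slicing each session into sliding windows of $n$ past states paired with the state that follows. A minimal NumPy sketch (the function name is ours, not the authors'):

```python
import numpy as np

def make_windows(series, n=30):
    """Slice a (T, 379) series into (window, target) training pairs.

    Returns X of shape (T - n, n, 379) and y of shape (T - n, 379):
    each window of n past states predicts the state that follows it.
    """
    X = np.stack([series[i:i + n] for i in range(len(series) - n)])
    y = series[n:]
    return X, y

# Toy example: 100 time points of 379-dimensional brain states
series = np.random.default_rng(0).standard_normal((100, 379))
X, y = make_windows(series, n=30)
print(X.shape, y.shape)  # (70, 30, 379) (70, 379)
```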
Transformer Architecture
Sun et al. adapted the time series transformer architecture, building on the seminal work of Vaswani et al. and later developments in time series forecasting with transformers. Transformers emphasize multi-headed self-attention, a mechanism that computes pairwise correlations (attention weights) among tokens in a sequence. This approach can capture long-range dependencies more effectively than recurrent networks that rely on gating mechanisms.
- Positional Encoding
Since self-attention alone does not intrinsically encode the order of tokens, sinusoidal positional encodings are injected into the input to convey time-step information. This ensures that the model is sensitive to the sequential ordering of the input states.
- Encoder
The authors’ encoder is a stack of four layers, each comprising:
  - A multi-head self-attention block (with 8 attention heads).
  - A feed-forward network.
The encoder processes the input window of size $n$ (e.g., 30 time points of 379-dimensional vectors) and transforms it into a latent representation that captures salient temporal dependencies.
- Decoder
The decoder also has four layers of multi-head self-attention (again with 8 attention heads) and feed-forward networks. It accepts two inputs:
  - The encoder output.
  - The last time point of the input sequence (i.e., $\mathbf{x}_{t-1}$).
The decoder attempts to predict the subsequent brain state, $\mathbf{x}_t$, from these inputs.
- Final Linear Mapping
At the end of the decoder, a fully connected layer translates the internal feature representation into the predicted 379-dimensional vector of the next time point.
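The sinusoidal positional encoding mentioned above follows the standard Vaswani et al. formula and can be computed directly. This is a generic sketch, not code from the authors' repository:

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d_model):
    """Standard sinusoidal encodings (Vaswani et al.):
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    """
    pos = np.arange(seq_len)[:, None]       # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]   # (1, ceil(d_model / 2))
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    n_cos = pe[:, 1::2].shape[1]            # handles odd d_model (e.g. 379)
    pe[:, 1::2] = np.cos(angles[:, :n_cos])
    return pe

# One encoding vector per time step of the 30-point input window
pe = sinusoidal_positional_encoding(seq_len=30, d_model=379)
```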
Unlike other sequential forecasting implementations (e.g., music generation or language modeling) that employ a look-ahead masking scheme to predict multiple future steps in a single forward pass, Sun et al. simplify the problem by predicting one time step at a time. This approach bypasses the need for complex masking strategies, making the training objective more direct: minimize the mean squared error (MSE) between $\widehat{\mathbf{x}}_t$ and the ground-truth $\mathbf{x}_t$.
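A minimal sketch of this encoder-decoder setup using PyTorch's built-in `nn.Transformer` follows. It is not the authors' implementation: the model dimension `d_model=512` is an assumption, and positional encoding is omitted for brevity.

```python
import torch
import torch.nn as nn

class BrainStateTransformer(nn.Module):
    """Sketch: 4 encoder + 4 decoder layers, 8 heads, one-step prediction."""

    def __init__(self, n_rois=379, d_model=512):
        super().__init__()
        self.embed = nn.Linear(n_rois, d_model)   # project ROI vectors to model dim
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=8,
            num_encoder_layers=4, num_decoder_layers=4,
            batch_first=True,
        )
        self.head = nn.Linear(d_model, n_rois)    # map back to 379 ROIs

    def forward(self, window):
        # window: (batch, n, 379) past brain states
        src = self.embed(window)                  # encoder input
        tgt = self.embed(window[:, -1:, :])       # decoder input: last observed state
        out = self.transformer(src, tgt)          # (batch, 1, d_model)
        return self.head(out).squeeze(1)          # predicted next state: (batch, 379)

model = BrainStateTransformer()
x_hat = model(torch.randn(2, 30, 379))   # batch of 2 windows of 30 time points
print(x_hat.shape)  # torch.Size([2, 379])
```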
Training Setup
Window Size, Batch Size, and Optimizer
- Window Size
During a preliminary set of experiments, multiple window sizes ($n = 5, 10, 15, 20, 30, 40, 50$) were tested using 40 subjects from the dataset. The authors found that $n = 30$ was optimal both for single-step predictions and for generating longer synthetic sequences.
- Optimizer and Learning Rate
The authors employed the Adam optimizer with an initial learning rate of $1 \times 10^{-4}$. The MSE loss served as the objective, meaning the model parameters were tuned to minimize the squared error between predictions and actual fMRI signals.
- Epochs
Preliminary findings indicated that training for around 20 epochs was sufficient for convergence on a smaller subset of the data. However, given the much larger dataset used for final training, the authors settled on 10 epochs within a 10-fold cross-validation framework: in each fold, approximately 90% of the 1,003 subjects were used for training, while the remaining 10% were reserved for validation.
- Batch Size
To expedite training, they used a batch size of 512 sequences, each sequence being an extract of 30 consecutive time points from the resting-state data.
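One training step with these hyperparameters might look as follows. The model here is a stand-in linear layer, purely to illustrate the Adam/MSE setup on a 512-window batch; the paper's transformer would plug in its place.

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; the paper's transformer would replace it.
model = nn.Sequential(nn.Flatten(), nn.Linear(30 * 379, 379))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # paper's settings
loss_fn = nn.MSELoss()

# One synthetic batch: 512 windows of 30 consecutive time points
windows = torch.randn(512, 30, 379)   # inputs
targets = torch.randn(512, 379)       # the time point following each window

optimizer.zero_grad()
pred = model(windows)
loss = loss_fn(pred, targets)         # MSE training objective
loss.backward()
optimizer.step()
print(float(loss))
```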
Ten-Fold Cross-Validation
With 1,003 subjects, a robust 10-fold cross-validation was applied. By repeatedly rotating which 90% of subjects formed the training set and which 10% formed the validation set, the authors ensured that any idiosyncrasies were not overemphasized, thereby boosting generalizability. The final reported metrics (e.g., MSE, correlations, functional connectivity comparisons) were aggregated across these folds, giving a comprehensive view of the model’s performance.
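The subject-level splits can be sketched in a few lines of NumPy (the function name is ours; the authors' exact split logic is in their repository). Splitting by subject, rather than by window, is what keeps validation subjects truly unseen.

```python
import numpy as np

def subject_kfold(n_subjects=1003, k=10, seed=0):
    """Yield (train_ids, val_ids) subject splits for k-fold cross-validation.

    Splitting by subject ensures that no participant's data appears in
    both the training and validation sets of the same fold.
    """
    ids = np.random.default_rng(seed).permutation(n_subjects)
    for fold in np.array_split(ids, k):
        train = np.setdiff1d(ids, fold)
        yield train, fold

folds = list(subject_kfold())
# ~90% train / ~10% validation in each of the 10 folds
print(len(folds), len(folds[0][0]), len(folds[0][1]))
```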

Evaluation Protocol
Sun et al. carried out multiple levels of evaluation:
- Single-Step Prediction
- The model was tested on previously “unseen” subjects.
- For each time point $t$ (from $t = 51$ to $t = 1200$), the input window $\{\mathbf{x}_{t-30}, \ldots, \mathbf{x}_{t-1}\}$ was provided to the model to produce $\widehat{\mathbf{x}}_t$.
- The mean squared error (MSE) between $\widehat{\mathbf{x}}_t$ and $\mathbf{x}_t$ was aggregated over all $t$.
- Shuffled Sequence Test
- The same input windows were “shuffled,” effectively destroying temporal ordering.
- If the model genuinely learns sequential dynamics, it should yield much higher errors on randomized sequences.
- The authors performed a paired t-test on these MSE distributions (true sequence vs. random sequence) to confirm significance.
- Long-Range Iterative Predictions
- Having validated the single-step predictions, the authors then tested how well the model could generate an extended synthetic sequence—without continuing access to ground-truth data.
- The process is iterative: after predicting $\widehat{\mathbf{x}}_t$, that prediction is appended to the “input window,” shifting the window by one time step. Thus, after a few steps, the model is exclusively feeding on its own predictions (i.e., “closed loop”).
- By continuing this process until reaching the original sequence length, one obtains a full predicted time series.
- The authors used MSE and Spearman’s correlation between $\widehat{\mathbf{x}}_t$ and $\mathbf{x}_t$ to quantify accuracy over time.
- Functional Connectivity (FC) Analysis
- The authors computed the pairwise Pearson’s correlation across all 379 ROIs in the predicted vs. real time series. This yields an FC matrix capturing how strongly each region correlates with every other region over time.
- Functional connectivity (FC) is integral to capturing large-scale network organization in the brain. If the model’s synthetic sequences preserve the typical adjacency relationships, one should expect predicted FC to overlap significantly with actual FC.
- To judge similarity, they calculated:
- Mean absolute difference between predicted FC and the group-average FC from real data.
- Spatial correlation of the upper triangular portion of the FC matrix (since FC is symmetric) relative to the true group-average FC.
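Both FC similarity measures reduce to simple NumPy operations on the correlation matrices. The sketch below uses noisy synthetic data as a stand-in for predicted sequences; the function names are ours.

```python
import numpy as np

def fc_matrix(ts):
    """Pairwise Pearson correlation across ROIs; ts is (T, n_rois)."""
    return np.corrcoef(ts.T)                   # (n_rois, n_rois), symmetric

def fc_similarity(fc_pred, fc_true):
    """Compare two FC matrices on their upper triangles (FC is symmetric)."""
    iu = np.triu_indices_from(fc_true, k=1)    # exclude the diagonal
    mad = np.mean(np.abs(fc_pred[iu] - fc_true[iu]))          # mean absolute difference
    spatial_r = np.corrcoef(fc_pred[iu], fc_true[iu])[0, 1]   # spatial correlation
    return mad, spatial_r

rng = np.random.default_rng(0)
real = rng.standard_normal((1200, 379))
pred = real + 0.5 * rng.standard_normal((1200, 379))   # noisy stand-in for predictions
mad, r = fc_similarity(fc_matrix(pred), fc_matrix(real))
```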
Results
Model Selection
The preliminary experiments with 40 subjects revealed that:
- Window Size
When predicting a single next time point, window sizes of 15 and 30 performed best in terms of low MSE, but for generating full time series, a window size of 20 or 30 was preferable. Ultimately, the authors selected 30 as the final window size for its balanced overall performance.
- Epochs
At around 20 epochs on the small subset, the model converged in single-step predictions. For the full dataset, with a 10-fold approach, the authors fixed 10 epochs per fold (noting that by the 6th epoch, the losses stabilized and overfitting was absent).
Single Next Time Point Prediction
When tested on unseen subjects, the model achieved a low average MSE of 0.0013 for single-step forecasting when the input sequences were kept in their true order. When the temporal order of the input was randomized, however, the MSE rose to 0.97, an increase of more than a factor of 700, and the difference was highly significant ($p < 10^{-10}$). This disparity indicates that the transformer’s success hinges on capturing genuine temporal dependencies in the data rather than relying on simpler distributional statistics.
Long-Horizon Prediction
To assess how the prediction error evolves as the model repeatedly feeds its own synthetic outputs back into itself, Sun et al. observed that the earliest portion of the predicted sequence remains highly accurate:
- First 7 Time Points
The MSE remains below 0.15, and Spearman’s correlation with the ground truth is above 0.85. In effect, the model can robustly predict about 5.04 seconds of future brain states (since each time point spans 0.72 seconds) from just 21.6 seconds of initial data.
- Accumulating Error
Past this window of about 5 seconds, errors begin to accumulate in the iterative forecasts, a well-known phenomenon in auto-regressive models. By 20 predicted steps, the MSE is roughly 0.26, and it eventually plateaus around 0.80 when predicting extended sequences purely from self-generated data.
- Comparison with BrainLM
The authors cite results from BrainLM, which also sees an MSE rise after several predicted time points, with a 20-step MSE around 0.568—higher than the comparable measure reported by Sun et al.
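The closed-loop rollout that produces these long-horizon forecasts can be sketched generically. The "model" below is a trivial persistence baseline (repeat the last state), used only to make the sliding-window mechanics concrete; the function name is ours.

```python
import numpy as np

def closed_loop_forecast(model, window, n_steps):
    """Iteratively roll a one-step model forward on its own predictions.

    model: callable mapping an (n, 379) window to the next (379,) state
    window: the last n observed brain states, shape (n, 379)
    """
    window = window.copy()
    preds = []
    for _ in range(n_steps):
        x_hat = model(window)
        preds.append(x_hat)
        # Slide the window: drop the oldest state, append the prediction
        window = np.vstack([window[1:], x_hat])
    return np.array(preds)

# Stand-in "model" for illustration: persistence (repeat the last state)
persistence = lambda w: w[-1]
start = np.random.default_rng(0).standard_normal((30, 379))
future = closed_loop_forecast(persistence, start, n_steps=7)
print(future.shape)  # (7, 379)
```

With a trained one-step predictor in place of `persistence`, this loop is exactly where small early errors compound over long horizons.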
Functional Connectivity Preservation
A hallmark of fMRI analysis is to assess whether predicted states preserve the typical correlation architecture among the 379 ROIs. Indeed, Sun et al. computed the group-average FC matrix from the real dataset versus the group-average FC matrix from the synthetic sequences. They also inspected single-subject FCs in the predicted data:
- Visually Similar FC
The predicted group-average FC resembled the real group-average FC in global structure, though with a smaller dynamic range in correlation values.
- Quantitative Overlap
- Mean Absolute Differences between predicted FC and the real group-average FC typically ranged from 0.35 to 0.45.
- Spatial Correlations hovered between 0.50 and 0.60.
These findings verify that, despite error accumulation, the forecasted time series reflect the canonical functional connectome organization at a group level. In other words, the model captures not only local transitions from one time point to the next but also the broader functional relationships across brain regions—albeit with constraints in replicating fine-grained individual-level nuances for very long forecast windows.
Discussion
Transformer Advantages in Brain Modeling
The multi-headed self-attention mechanism is particularly advantageous for fMRI time series analysis because:
- It can capture long-range interactions without the vanishing gradients that plague recurrent neural networks.
- It naturally aligns with the graph-theoretical perspective, where each time point can be considered a node that influences many subsequent nodes. This parallels how neural activity in widely separated brain regions can remain functionally connected.
- It can exploit robust positional encodings to incorporate the inherent temporality of brain signals while still preserving the capacity to identify distant dependencies.
Error Accumulation
When the model’s outputs recursively become part of the input stream, prediction errors tend to balloon. Even small deviations in earlier predictions can lead to compounding discrepancies in later time steps. Although the first 5–7 predicted steps are impressively accurate, performance degrades for extended sequences, a limitation reminiscent of many generative modeling tasks. Resolving this will require further architectural innovations, for instance:
- Scheduled Sampling: Gradually replace ground truth data with model predictions during training.
- Backpropagation Through Time: Though more feasible in recurrent networks, some variant might be adapted in self-attention frameworks to model multi-step consistency.
- Diffusion or Denoising Approaches: Borrowing from generative modeling in images might help reduce iterative drift.
Potential Clinical and Research Implications
- Reduced Scan Times
The capacity to generate plausible future brain states from short segments could ease clinical burdens. In populations for whom extended scanning triggers anxiety or motion artifacts (e.g., pediatric populations, patients with movement disorders), obtaining just ~20 seconds of data might suffice for certain analyses.
- Brain-Computer Interfaces
Predictive modeling of ongoing brain states could feed into real-time BCI frameworks, where knowledge of near-future neural patterns could optimize or refine external prosthetic control or feedback loops.
- Neuroscientific Insights
Understanding how a transformer “pays attention” to specific time points (or patterns) may open interpretive pathways into which neural configurations serve as robust drivers or attractors of future activity. This could shed light on fundamental questions about network-based cognition at rest.
- Applications to Pathological States
Although the authors focus on healthy young adults, the same architecture could be extended to patient populations with brain disorders. The transformation of functional connectivity over short windows might be indicative of pathological disruptions. Thus, the model could help forecast seizure onset or track progressive neurodegeneration, if appropriately trained on relevant cohorts.
Conclusion and Future Directions
Sun et al. demonstrate that a transformer-based model, trained on large-scale HCP data, can indeed predict near-future brain states from short windows of preceding fMRI data. Remarkably, up to 5.04 seconds of future activity can be forecast with minimal error, and even though errors accumulate over long sequences, the globally emergent functional connectivity patterns remain in line with canonical human connectomes.
Looking ahead, the authors emphasize the following enhancements:
- Improving Accuracy Over Extended Horizons
Strategies are needed to mitigate cascading errors; potential methods could borrow ideas from advanced sequence modeling or generative denoising.
- Personalization
The authors propose exploring transfer learning to tailor a pre-trained “group” model to specific individuals, thus capturing idiosyncratic connectivity features. This personalized approach could lead to more accurate forecasts at the single-subject level, especially for clinical populations.
- Interpretability
A key desideratum is to make these deep models more transparent. By extracting attention maps, neuroscientists may glean which time points or regions exert the greatest influence on predicted future states, potentially illuminating large-scale principles of cortical-subcortical coordination.
- Extending to Broader Domains
Although the results center on resting-state data, similar approaches might be adapted for task-based fMRI, electroencephalography (EEG), and other high-dimensional signals. The fundamental auto-regression on neural time series is widely applicable.
Links and Sources
- Project Code Repository
The authors have made their code publicly available at https://github.com/syf0122/brain_state_pred. This repository hosts scripts for data preprocessing (adapted to the HCP minimal preprocessing pipelines), model training, cross-validation, and functional connectivity analysis.
- Human Connectome Project Data
The dataset employed in this study is available through the WU-Minn Human Connectome Project (Young Adult). Additional details can be found at https://www.humanconnectome.org/study/hcp-young-adult.
Final Remarks
Sun et al. have showcased a compelling application of transformer-based architectures to a dataset of considerable size and quality (HCP). The fact that they can extrapolate about 5 seconds of future resting-state activity with impressive fidelity paves the way for shortened scans and more adaptive BCI designs. Although the approach’s accuracy diminishes in longer auto-regressive sequences due to compounding errors, the large-scale functional connectivity remains close to that of real data in terms of group-average correlation patterns.
By rendering the code open-source, the authors invite further refinement of this method, including personalization strategies, interpretative analysis of self-attention maps, and expansions of the approach to broader imaging modalities or multi-modal fusion with structural MRI, diffusion MRI, and beyond. The continuing evolution of such models may unveil the intricate tapestry of neural interactions that underlie cognition and disease processes, heralding new vistas in computational neuroscience, precision medicine, and advanced AI-driven analyses of human brain function.