# Soft-Token Trajectory Forecasting: Addressing Teacher Forcing in Risk-Aware Autoregressive Time Series Forecasting
This repository accompanies our paper SoTra: Soft-Token Trajectory Forecasting.
It contains all of the code needed to preprocess the data, train our proposed model, run ablation studies, and compare against several state-of-the-art baselines. The code has been tested with Python ≥ 3.9 and PyTorch ≥ 2.0 on Linux, but should run on any platform with minor adjustments. We also provide SLURM scripts for running on HPC.
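Beyond Python ≥ 3.9 and PyTorch ≥ 2.0, the exact dependency pins ship with the repository; as a rough sketch, a hypothetical `requirements.txt` might look like the following (the `numpy` and `pandas` entries are assumptions inferred from the `.npy` risk matrices and CSV dataloaders, not a verified dependency list):

```text
torch>=2.0
numpy
pandas
```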
| directory | purpose |
|---|---|
| `SoTra` | contains the implementation of our SoTra model together with dataset-specific scripts, SLURM job files and evaluation tools. Each sub-folder corresponds to a dataset used in the paper (e.g. BP-mean, BP-sys, dclp3-exp, pedap-exp, Ohio-exp) and has its own training script (`train.py`), testing script (`test.py`), ablation script (`test-ablation.py`), job-submission files (`train6.slurm`, `train12.slurm`, …), and directories for logs, models and results. The `.npy` files (`zone_matrix_mean.npy`, `risk_matrix_systolic.npy`, etc.) store zone boundaries and risk weights used for Clarke-style error-grid analysis. |
| `baselines` | code for the baseline models against which SoTra is compared. The baselines include PatchTST, iTransformer, DLinear and Chronos. The `trainer.py` script implements distributed training and evaluation and is driven via the `train_gpu_dist.sh` and `inference.sh` SLURM scripts. A `data` folder holds streaming dataloader utilities, `layers` contains model layers, and `utils` provides configuration management and logging. |
| `dataset-preparation` | scripts for downloading and preprocessing the raw vital-sign data from the VitalDB API and merging the processed segments into a single CSV compatible with our dataloaders. The `vital_process.py` script downloads minute-wise arterial blood-pressure tracks, removes invalid samples and splits long recordings into valid continuous segments, while `merge_pedap.py` merges the resulting per-case CSV files into a single file. An example track list (`trks.csv`) and shell script (`run.sh`) are provided. The blood glucose datasets need to be merged in the same way. |
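As a rough illustration of the merge step performed by `merge_pedap.py`, per-case CSV files sharing a common header can be concatenated into one file as sketched below. This is a minimal sketch only: the function name, the glob pattern, and the assumption that every per-case file shares an identical header are ours, and the repository's actual script may handle columns and case identifiers differently.

```python
import csv
import glob


def merge_case_csvs(pattern, out_path):
    """Concatenate per-case CSV files (sharing one header) into a single CSV.

    Illustrative only: merge_pedap.py in this repository may use a different
    column layout or attach per-case metadata.
    """
    paths = sorted(glob.glob(pattern))
    header = None
    with open(out_path, "w", newline="") as out:
        writer = csv.writer(out)
        for path in paths:
            with open(path, newline="") as f:
                reader = csv.reader(f)
                file_header = next(reader)
                if header is None:
                    header = file_header
                    writer.writerow(header)  # write the shared header once
                for row in reader:  # append every data row from this case
                    writer.writerow(row)
    return out_path
```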
The archive you received only includes empty `dataset`, `logs`, `models` and `results` folders in the dataset-specific sub-directories. You will populate these when you prepare data and run experiments. We have provided code to programmatically download the blood-pressure datasets. The DCLP3 and PEDAP datasets can be downloaded from here, and the Ohio dataset can be downloaded from here.
Each dataset folder in SoTra contains a self-contained implementation of SoTra along with scripts to run the experiments. The files of interest are:
- `train.py` – trains the model for a specified prediction horizon. The script takes an optional command-line argument specifying the trajectory length / horizon (e.g. `python3 train.py 12` trains a model to predict 12 future time-steps). It reads the merged CSV from `dataset_directory/dataset_name` (default values are `dataset/merged_processed_mbp.csv` or `dataset/merged_processed_sbp.csv`), constructs PyTorch datasets, and writes checkpoints to the `models` folder.
- `test.py` – loads a trained checkpoint and evaluates the model on the test split. It outputs the root-mean-square error (RMSE) and a risk-weighted loss based on the corresponding zone matrix. The results are written into the `results` directory and can be further analysed with the ablation scripts.
- `test-ablation.py` – runs ablation variants of the model. The ablations correspond to the configurations reported in our paper's tables; the numbers appended to the resulting files indicate the configuration number.
- `clarke.py` – plots the error grid.
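The risk-weighted loss reported by `test.py` is derived from the zone boundaries and risk weights stored in the `.npy` files. One plausible formulation is sketched below; the zone discretisation, the matrix layout, and the idea of weighting squared error by the (target zone, predicted zone) entry are all assumptions on our part, not the repository's exact implementation.

```python
import numpy as np


def risk_weighted_mse(pred, target, zone_edges, risk_matrix):
    """Squared error weighted by a risk looked up from (target zone, pred zone).

    Illustrative only: the actual boundaries and weights in this repository
    live in files such as zone_matrix_mean.npy, and the exact formulation
    may differ.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    # Map each value to a discrete zone index via the zone boundaries.
    pred_zone = np.digitize(pred, zone_edges)
    target_zone = np.digitize(target, zone_edges)
    # Look up a per-sample risk weight from the (target, prediction) zone pair.
    weights = risk_matrix[target_zone, pred_zone]
    return float(np.mean(weights * (pred - target) ** 2))
```

Under this formulation, an error that crosses a zone boundary (e.g. predicting a safe value when the true value is in a risky zone) is penalised more heavily than an equal-magnitude error within one zone.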
This repository currently cites the arXiv version of the accepted paper. The citation will be updated once the final published version is available.
```bibtex
@misc{namazi2025mitigatingexposurebiasriskaware,
  title={Mitigating Exposure Bias in Risk-Aware Time Series Forecasting with Soft Tokens},
  author={Alireza Namazi and Amirreza Dolatpour Fathkouhi and Heman Shakeri},
  year={2025},
  eprint={2512.10056},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2512.10056},
}
```