Shakeri-Lab/SoTra

Soft-Token Trajectory Forecasting: Addressing Teacher Forcing in Risk-Aware Autoregressive Time Series Forecasting

This repository accompanies our paper SoTra: Soft-Token Trajectory Forecasting.
It contains all of the code needed to preprocess the data, train our proposed model, run ablation studies, and compare against several state-of-the-art baselines. The code has been tested with Python ≥ 3.9 and PyTorch ≥ 2.0 on Linux, but should run on any platform with minor adjustments. We also provide SLURM scripts for running on HPC.


Repository layout

  • SoTra – the implementation of our SoTra model together with dataset-specific scripts, SLURM job files and evaluation tools. Each sub-folder corresponds to a dataset used in the paper (e.g. BP-mean, BP-sys, dclp3-exp, pedap-exp, Ohio-exp) and has its own training script (train.py), testing script (test.py), ablation script (test-ablation.py), job-submission files (train6.slurm, train12.slurm, …), and directories for logs, models and results. The .npy files (zone_matrix_mean.npy, risk_matrix_systolic.npy, etc.) store the zone boundaries and risk weights used for Clarke-style error-grid analysis.
  • baselines – code for the baseline models against which SoTra is compared: PatchTST, iTransformer, DLinear and Chronos. The trainer.py script implements distributed training and evaluation and is driven by the train_gpu_dist.sh and inference.sh SLURM scripts. The data folder holds streaming dataloader utilities, layers contains model layers, and utils provides configuration management and logging.
  • dataset-preparation – scripts for downloading and preprocessing the raw vital-sign data from the VitalDB API and merging the processed segments into a single CSV compatible with our dataloaders. The vital_process.py script downloads minute-wise arterial blood-pressure tracks, removes invalid samples and splits long recordings into valid continuous segments, while merge_pedap.py merges the resulting per-case CSV files into a single file. An example track list (trks.csv) and shell script (run.sh) are provided. The blood glucose datasets need to be merged too.
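The segment-splitting and merging steps described for the dataset-preparation scripts can be illustrated with a minimal sketch. This is not the code in vital_process.py or merge_pedap.py: the function names, the plausibility range (20 to 250 mmHg), the minimum segment length and the CSV column names are all assumptions made for illustration.

```python
import csv
import io
import math
from typing import Dict, List, Optional, Tuple


def split_valid_segments(
    values: List[Optional[float]],
    low: float = 20.0,    # assumed lower plausibility bound (mmHg)
    high: float = 250.0,  # assumed upper plausibility bound (mmHg)
    min_len: int = 3,     # assumed minimum usable segment length
) -> List[List[float]]:
    """Split a minute-wise series into runs of consecutive valid samples."""
    segments: List[List[float]] = []
    current: List[float] = []
    for v in values:
        valid = v is not None and not math.isnan(v) and low <= v <= high
        if valid:
            current.append(v)
            continue
        # An invalid sample closes the current run; short runs are dropped.
        if len(current) >= min_len:
            segments.append(current)
        current = []
    if len(current) >= min_len:
        segments.append(current)
    return segments


def merge_segments(
    cases: Dict[str, List[List[float]]]
) -> List[Tuple[str, int, int, float]]:
    """Flatten per-case segments into (case_id, segment_id, step, value) rows."""
    rows = []
    for case_id, segments in cases.items():
        for seg_id, seg in enumerate(segments):
            for step, value in enumerate(seg):
                rows.append((case_id, seg_id, step, value))
    return rows


# Toy recording with a NaN, an implausible zero and a missing sample.
raw = [80.0, 82.0, 81.0, float("nan"), 0.0, 90.0, 91.0, 92.0, 93.0, None, 85.0]
segments = split_valid_segments(raw)
rows = merge_segments({"case_1": segments})

# Write the merged rows as a single CSV (in memory here, a file in practice).
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["case_id", "segment_id", "step", "value"])  # hypothetical columns
writer.writerows(rows)
```

Keeping a segment identifier alongside each value lets a downstream dataloader avoid building windows that straddle a gap in the recording.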

The archive you received only includes empty dataset, logs, models and results folders in the dataset-specific sub-directories. You will populate these when you prepare data and run experiments. We provide code to download the blood pressure datasets programmatically. The DCLP3 and PEDAP datasets can be downloaded from here, and the Ohio dataset can be downloaded from here.

Training the SoTra model

Each dataset folder in SoTra contains a self-contained implementation of SoTra along with scripts to run the experiments. The files of interest are:

  • train.py – trains the model for a specified prediction horizon. The script takes an optional command-line argument specifying the trajectory length / horizon (e.g. python3 train.py 12 trains a model to predict 12 future time-steps). It reads the merged CSV from dataset_directory/dataset_name (default values are dataset/merged_processed_mbp.csv or dataset/merged_processed_sbp.csv), constructs PyTorch datasets, and writes checkpoints to the models folder.
  • test.py – loads a trained checkpoint and evaluates the model on the test split. It outputs the root-mean-square error (RMSE) and a risk-weighted loss based on the corresponding zone matrix. The results are written into the results directory and can be further analysed with the ablation scripts.
  • test-ablation.py – runs ablation variants of the model. The ablations correspond to the configurations reported in our paper’s tables; the numbers appended to the resulting files indicate the configuration number.
  • clarke.py – plots the Clarke-style error grid.
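The two metrics reported by test.py can be sketched as follows. This is an illustrative assumption, not the repository's implementation: the layout of the .npy zone and risk matrices is not described here, so the sketch assumes a square risk matrix indexed by (bin of the true value, bin of the predicted value) and defines the risk-weighted score as the mean of the looked-up weights. The bin edges, the weights and all function names are hypothetical.

```python
import math
from bisect import bisect_right
from typing import List, Sequence

# Hypothetical bin edges splitting the value range into low / normal / high.
BIN_EDGES = [70.0, 180.0]

# Hypothetical risk weights: rows are true-value bins, columns are predicted bins.
RISK = [
    [0.0, 1.0, 4.0],
    [1.0, 0.0, 1.0],
    [4.0, 1.0, 0.0],
]


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root-mean-square error between two equally long sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


def to_bin(value: float, edges: Sequence[float]) -> int:
    """Index of the bin containing value, given ascending bin edges."""
    return bisect_right(edges, value)


def mean_risk(
    y_true: Sequence[float],
    y_pred: Sequence[float],
    risk: List[List[float]],
    edges: Sequence[float],
) -> float:
    """Average risk weight over all (true bin, predicted bin) pairs."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    total = sum(
        risk[to_bin(t, edges)][to_bin(p, edges)] for t, p in zip(y_true, y_pred)
    )
    return total / len(y_true)


y_true = [100.0, 60.0, 200.0, 150.0]
y_pred = [110.0, 80.0, 170.0, 150.0]
error = rmse(y_true, y_pred)
risk_score = mean_risk(y_true, y_pred, RISK, BIN_EDGES)
```

Unlike RMSE, the risk-based score penalises a prediction only when it falls in a different zone from the true value, so a numerically large error inside one zone can cost less than a small error that crosses a zone boundary.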

Citation

This repository currently cites the arXiv version of the accepted paper. The citation will be updated once the final published version is available.

@misc{namazi2025mitigatingexposurebiasriskaware,
      title={Mitigating Exposure Bias in Risk-Aware Time Series Forecasting with Soft Tokens},
      author={Alireza Namazi and Amirreza Dolatpour Fathkouhi and Heman Shakeri},
      year={2025},
      eprint={2512.10056},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2512.10056},
}
