BordAX

A High-Performance JAX Framework for Programmatic Reinforcement Learning



Overview

BordAX is a research-focused framework for Programmatic Reinforcement Learning (PRL). It combines the speed of JAX with support for both standard neural network policies and structured, interpretable ones, including boolean functions and decision trees.

Key Features

  • High Performance — Fully JIT-compiled training pipelines leveraging JAX's XLA compilation
  • Modular Architecture — Clean separation between agents, algorithms, environments, and training
  • Multiple Policy Types — MLPs, boolean functions (HyperBool), and decision trees (DTSemNet)
  • Flexible Algorithms — Built-in PPO (on-policy) and DQN (off-policy) with easy extensibility
  • Environment Agnostic — Supports both Gymnax (JIT-compiled) and Gymnasium environments
  • Production Ready — Checkpointing, logging, WandB integration, and comprehensive tests

Performance

BordAX achieves high performance through the following techniques (a minimal sketch of the pattern appears after the list):

  • Full JIT compilation for jittable environments (Gymnax)
  • Vectorized environments via jax.vmap
  • Efficient loops using jax.lax.scan
  • Pure functional design compatible with XLA optimization
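As an illustration, here is a minimal sketch of the vmap + scan pattern in plain JAX. This is a toy environment, not BordAX's actual rollout code; env_step and rollout are illustrative names.

import jax
import jax.numpy as jnp

# Toy environment: the state drifts by the action; reward is -|state|.
def env_step(state, action):
    new_state = state + action
    return new_state, -jnp.abs(new_state)

# jax.lax.scan compiles the whole rollout loop into a single XLA op.
def rollout(init_state, actions):
    _, rewards = jax.lax.scan(env_step, init_state, actions)
    return rewards

# jax.vmap vectorizes the rollout across parallel environments.
batched_rollout = jax.jit(jax.vmap(rollout))

num_envs, horizon = 4, 8
states = jnp.zeros(num_envs)
actions = 0.1 * jnp.ones((num_envs, horizon))
rewards = batched_rollout(states, actions)  # shape (num_envs, horizon)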

JIT Compilation Strategy

Environment   Algorithm    JIT Scope
Gymnax        On-policy    Entire train_step
Gymnax        Off-policy   update only
Gymnasium     Any          update only
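The distinction matters because a Gymnasium environment steps in Python and cannot be traced by XLA. A minimal sketch of the two scopes, using toy stand-ins for the rollout and update rather than BordAX's actual signatures:

import jax
import jax.numpy as jnp

# Toy stand-ins for a rollout and a gradient step.
def collect(params, key):
    return jax.random.normal(key, (32,))

def update(params, batch):
    return params - 0.01 * batch.mean()

# Gymnax + on-policy: rollout and update trace into one compiled program.
@jax.jit
def train_step(params, key):
    return update(params, collect(params, key))

# Gymnasium: the environment steps in Python outside XLA,
# so only the update is compiled.
jitted_update = jax.jit(update)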

Benchmark: BordAX vs Stable-Baselines3

PPO on CartPole-v1 with identical hyperparameters (5 seeds, 51k timesteps):

Framework                     Training Time     Throughput        Speedup
BordAX + Gymnax (Full JIT)    4.26s ± 0.12s     12,027 steps/s    3.2x
BordAX + Gymnasium            6.38s ± 0.27s     8,021 steps/s     2.2x
Stable-Baselines3             13.79s ± 0.55s    3,714 steps/s     1.0x

[Figure: Benchmark comparison chart]

With Gymnax (fully JIT-compiled), BordAX is 3.2x faster than Stable-Baselines3. Even with Gymnasium (Python environment), BordAX is 2.2x faster.

Run the benchmark yourself:

pip install stable-baselines3
python compare_sb3.py

Installation

# Clone the repository
git clone https://github.com/SynthesisLab/bordax.git
cd bordax

# Install with uv (recommended)
uv sync

# Or with pip
pip install -e .

# With optional dependencies (WandB, visualization)
pip install -e ".[all]"

Verify Installation

python -c "from bordax.training.trainer import Trainer; print('BordAX installed successfully')"

Quick Start

Train PPO on CartPole

python train_ppo.py

  • Solves CartPole-v1 (reward = 500) in ~400k steps
  • Training time: ~18 seconds on CPU
  • Throughput: ~23,000 steps/s

[Figure: PPO training curve]

Train DQN on CartPole

python train_dqn.py

  • Solves CartPole-v1 in ~50k steps
  • Training time: ~36 seconds on CPU
  • Includes 1,000 step warmup phase

[Figure: DQN training curve]

Custom Training Script

import jax
from bordax.training.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent

# Setup environments
env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=4)
eval_env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=1)

# Create agent with MLP policy and value networks
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [64, 64],
    "value_layers": [64, 64],
})

# Configure PPO algorithm
algorithm = make_algo("ppo", {
    "lr": 3e-4,
    "rollout_length": 2048,
    "gamma": 0.99,
    "_lambda": 0.95,
    "clip_schedule": lambda _: 0.2,
    "vf_schedule": lambda _: 0.5,
    "ent_schedule": lambda _: 0.01,
    "num_minibatches": 16,
    "num_sgd_steps": 10,
})

# Setup trainer
config = TrainerConfig(
    num_checkpoints=100,
    epochs_per_checkpoint=1,
    evaluation_episodes=32,
    debug=True,
)

trainer = Trainer(env, eval_env, agent, algorithm, config)

# Train
key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)
eval_data = trainer.run(train_key)

Architecture

BordAX uses a modular pipeline architecture that cleanly separates concerns:

Trainer
  └─> Algorithm (Collector + BatchBuilder + Updater)
       ├─> Collector: Generates environment transitions
       ├─> BatchBuilder: Constructs training batches
       └─> Updater: Computes gradients and updates parameters
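In pseudocode, one epoch composes the three stages roughly as follows. These are illustrative names, not BordAX's exact interfaces, and in the real framework such loops would presumably run under jax.lax.scan where the environment allows it.

from typing import Callable, NamedTuple

class Algorithm(NamedTuple):
    collect: Callable        # Collector: (state, key) -> (state, transitions)
    build_batches: Callable  # BatchBuilder: transitions -> iterable of batches
    update: Callable         # Updater: (params, batch) -> (params, metrics)

def train_epoch(algo: Algorithm, state, params, key):
    # Collect transitions, slice them into batches, apply updates.
    state, transitions = algo.collect(state, key)
    for batch in algo.build_batches(transitions):
        params, metrics = algo.update(params, batch)
    return state, params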

Core Components

Component      Purpose                                      Examples
Agent          Defines policy and value networks            MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent
Algorithm      Bundles training pipeline components         ppo_algo(), dqn_algo()
Collector      Generates transitions via env interaction    OnPolicyCollector, EpsGreedyCollector
BatchBuilder   Transforms data into training batches        FullBufferBatch, MiniBatch, UniformReplayBatch
Updater        Updates parameters using gradients           SGDUpdate, DQNUpdater
Trainer        Orchestrates the full training loop          Trainer

Supported Algorithms

Algorithm   Type         Collector            Batch Strategy
PPO         On-policy    OnPolicyCollector    FullBufferBatch + MiniBatch
DQN         Off-policy   EpsGreedyCollector   UniformReplayBatch

Policy Representations

Standard Neural Networks

MLP Policy-Value (mlp/mlp):

agent = make_agent("mlp/mlp", env, {
    "policy_layers": [128, 128, 64],
    "value_layers": [128, 128, 64],
})

Programmatic Policies

HyperBool — Boolean function-based policies (boolean/mlp):

agent = make_agent("boolean/mlp", env, {
    "n": 4,  # Number of boolean variables
    "value_layers": [128, 64, 32],
})

DTSemNet — Decision tree policies (dt/mlp):

agent = make_agent("dt/mlp", env, {
    "tree_depth": 4,
    "value_layers": [64, 64],
})

DQN Agent

Q-Network (dqn):

agent = make_agent("dqn", env, {
    "layers": [64, 64],
})

Project Structure

bordax/
├── bordax/
│   ├── agents/              # Agent implementations
│   │   ├── base.py          # MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent
│   │   ├── components.py    # Neural modules (MLP, DTSemNet, BooleanFunction)
│   │   └── utils.py         # make_agent() factory
│   ├── algorithms/          # RL algorithms
│   │   ├── base.py          # Algorithm class, ppo_algo(), dqn_algo()
│   │   ├── losses.py        # PPOLoss, DQNLoss
│   │   └── utils.py         # make_algo() factory
│   ├── data/                # Data collection and batching
│   │   ├── collectors.py    # OnPolicyCollector, EpsGreedyCollector
│   │   ├── batchbuilders.py # Batch transformations
│   │   └── buffer.py        # ReplayBuffer
│   ├── environments/        # Environment adapters
│   │   └── utils.py         # EnvAdapter, make_env()
│   ├── training/            # Training infrastructure
│   │   ├── trainer.py       # Main Trainer class
│   │   ├── evaluation.py    # Evaluator
│   │   ├── logging.py       # Logger with WandB support
│   │   ├── checkpointing.py # Model checkpointing (Orbax)
│   │   └── updaters.py      # SGDUpdate, DQNUpdater
│   └── types.py             # Core type definitions
├── tests/                   # Test suite (48 tests, 77% coverage)
│   ├── unit/                # Fast component tests
│   ├── integration/         # Pipeline tests
│   └── slow/                # Learning verification tests
├── train_ppo.py             # PPO training example
├── train_dqn.py             # DQN training example
└── compare_sb3.py           # Stable-Baselines3 benchmark

Testing

BordAX has a comprehensive test suite with 48 tests achieving 77% code coverage.

# Run all tests (excluding slow)
uv run pytest tests/ -m "not slow" -v

# Run slow learning tests
uv run pytest tests/ -m slow -v

# Run with coverage
uv run pytest tests/ --cov=bordax --cov-report=term-missing

Test Categories

Category      Tests   Purpose
Unit          44      Fast component tests
Integration   2       Full pipeline verification
Slow          2       Learning verification

Dependencies

Package     Version     Purpose
JAX         >=0.8.0     Core computation
Flax        >=0.12.0    Neural networks
Optax       >=0.2.6     Optimizers
Gymnax      >=0.0.9     JAX environments
Gymnasium   >=1.2.0     Standard environments
Distrax     >=0.1.7     Distributions
Orbax       >=0.11.32   Checkpointing

Optional: WandB (experiment tracking), Matplotlib/Seaborn (visualization)


Restoring from Checkpoints

# Restore last checkpoint and continue training
python train_ppo.py --restore-last

License

BordAX is released under the MIT License.


Acknowledgments

BordAX builds on excellent work from the JAX ecosystem:

  • JAX — High-performance numerical computing
  • Flax — Neural network library
  • Gymnax — JAX-compatible RL environments
  • Optax — Gradient processing and optimization
  • Distrax — Probability distributions

Built with JAX for speed and interpretability
