Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# The name for the entire test suite run.
suite_name: "Llama 3.1 70B slice 16"
# NOTE(review): `num_of_repeats` is also set under the benchmark options at
# the end of this config; presumably this key is the suite-level repeat count
# and the other one the in-benchmark repeat count - confirm against the runner.
num_repeats: 1

mesh_config:
mesh_axes: ["replica", "model"]
# Should match reference_sharding_path.
ici_parallelism: {"replica": 1, "model": 64}
# The replica count here equals the slice count named in suite_name.
dcn_parallelism: {"replica": 16}

# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# NOTE(review): the generator configured above is `ReplicaParallelMultislice`,
# not `V1Benchmark` - confirm which options class these keys must match.
async_enabled: true
use_ocdbt: true
use_zarr3: true
# NOTE(review): a list value presumably makes the generator produce one
# benchmark per listed value - confirm.
use_replica_parallel: [true, false]
use_compression: true
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
use_load_and_broadcast: true # speed up the loading
num_of_repeats: 20
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# The name for the entire test suite run.
suite_name: "Llama 3.1 70B slice 2"
# NOTE(review): presumably the number of times the suite is repeated -
# confirm against the benchmark runner.
num_repeats: 20

mesh_config:
mesh_axes: ["replica", "model"]
# Should match reference_sharding_path.
ici_parallelism: {"replica": 1, "model": 64}
# The replica count here equals the slice count named in suite_name.
dcn_parallelism: {"replica": 2}

# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# NOTE(review): the generator configured above is `ReplicaParallelMultislice`,
# not `V1Benchmark` - confirm which options class these keys must match.
async_enabled: true
use_ocdbt: true
use_zarr3: true
# NOTE(review): a list value presumably makes the generator produce one
# benchmark per listed value - confirm.
use_replica_parallel: [true, false]
use_compression: true
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
use_load_and_broadcast: true # speed up the loading
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# The name for the entire test suite run.
suite_name: "Llama 3.1 70B slice 32"
# NOTE(review): `num_of_repeats` is also set under the benchmark options at
# the end of this config; presumably this key is the suite-level repeat count
# and the other one the in-benchmark repeat count - confirm against the runner.
num_repeats: 1

mesh_config:
mesh_axes: ["replica", "model"]
# Should match reference_sharding_path.
ici_parallelism: {"replica": 1, "model": 64}
# The replica count here equals the slice count named in suite_name.
dcn_parallelism: {"replica": 32}

# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# NOTE(review): the generator configured above is `ReplicaParallelMultislice`,
# not `V1Benchmark` - confirm which options class these keys must match.
async_enabled: true
use_ocdbt: true
use_zarr3: true
# NOTE(review): a list value presumably makes the generator produce one
# benchmark per listed value - confirm.
use_replica_parallel: [true, false]
use_compression: true
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
use_load_and_broadcast: true # speed up the loading
num_of_repeats: 20
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# The name for the entire test suite run.
suite_name: "Llama 3.1 70B slice 4"
# NOTE(review): presumably the number of times the suite is repeated -
# confirm against the benchmark runner.
num_repeats: 20

mesh_config:
mesh_axes: ["replica", "model"]
# Should match reference_sharding_path.
ici_parallelism: {"replica": 1, "model": 64}
# The replica count here equals the slice count named in suite_name.
dcn_parallelism: {"replica": 4}

# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# NOTE(review): the generator configured above is `ReplicaParallelMultislice`,
# not `V1Benchmark` - confirm which options class these keys must match.
async_enabled: true
use_ocdbt: true
use_zarr3: true
# NOTE(review): a list value presumably makes the generator produce one
# benchmark per listed value - confirm.
use_replica_parallel: [true, false]
use_compression: true
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
use_load_and_broadcast: true # speed up the loading
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# The name for the entire test suite run.
suite_name: "Llama 3.1 70B slice 8"
# NOTE(review): the inline note presumably refers to `num_of_repeats` under
# the benchmark options at the end of this config - confirm.
num_repeats: 1 # depends on the number of repeats in the benchmark.

mesh_config:
mesh_axes: ["replica", "model"]
# Should match reference_sharding_path.
ici_parallelism: {"replica": 1, "model": 64}
# The replica count here equals the slice count named in suite_name.
dcn_parallelism: {"replica": 8}

# Note: checkpoint_config field not specified.

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
options:
# --- Generator Options ---
# These keys must match the attributes of the `V1BenchmarkOptions` class
# associated with the `V1Benchmark` generator.
# NOTE(review): the generator configured above is `ReplicaParallelMultislice`,
# not `V1Benchmark` - confirm which options class these keys must match.
async_enabled: true
use_ocdbt: true
use_zarr3: true
# NOTE(review): a list value presumably makes the generator produce one
# benchmark per listed value - confirm.
use_replica_parallel: [true, false]
use_compression: true
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
use_load_and_broadcast: true # speed up the loading
num_of_repeats: 20
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for multi-slice benchmarks."""

from __future__ import annotations

from typing import Any

from absl import logging
from etils import epath
import jax
from orbax.checkpoint import v1 as ocp
from orbax.checkpoint._src.multihost import multislice
from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation


def get_multi_slice_abstract_state(
    context: ocp.Context,
    global_mesh: jax.sharding.Mesh,
    *,
    reference_checkpoint_path: epath.Path,
    reference_sharding_path: epath.Path,
) -> Any:
  """Returns the abstract state with shardings expanded to all replicas.

  The pytree metadata of the reference checkpoint is read under `context`, and
  a single-replica abstract state is built from the sharding config using the
  devices of replica 0 (replica axis index 0) of `global_mesh`. Every leaf is
  then re-created with a `NamedSharding` whose mesh holds all devices of
  `global_mesh`, reshaped to `(-1, *single_replica_mesh_shape)`, with a leading
  "replica" axis. The partition spec is copied unchanged, so it never names
  the added "replica" axis.

  Args:
    context: Orbax context active while the checkpoint metadata is read.
    global_mesh: Mesh spanning all replicas; its devices are reshaped to form
      the multi-replica mesh.
    reference_checkpoint_path: Checkpoint whose pytree metadata is read.
    reference_sharding_path: Sharding config describing the single-replica
      shardings.

  Returns:
    A pytree of `jax.ShapeDtypeStruct` with the shapes and dtypes of the
    single-replica abstract state and multi-replica shardings.

  Raises:
    TypeError: If a leaf of the single-replica abstract state does not carry a
      `jax.sharding.NamedSharding`.
  """
  with ocp.Context(context=context):
    metadata = ocp.pytree_metadata(reference_checkpoint_path)
    # Abstract tree has shardings on a single replica.
    single_replica_abstract_state = (
        checkpoint_generation.get_abstract_state_from_sharding_config(
            reference_sharding_path,
            metadata.metadata,
            devices=multislice.replica_devices(
                global_mesh, replica_id=0, replica_axis_index=0
            ).tolist(),
        )
    )

    # Blow shardings up to all replicas.
    def _multi_replica_sharding(
        abstract_arr: jax.ShapeDtypeStruct,
    ) -> jax.ShapeDtypeStruct:
      """Rebuilds one leaf with a sharding over the multi-replica mesh."""
      logging.info(
          "Original (single-replica) sharding: %s", abstract_arr.sharding
      )
      # Explicit check instead of `assert`: assertions are stripped under
      # `python -O`, which would turn a wrong sharding type into an opaque
      # attribute error further down.
      if not isinstance(abstract_arr.sharding, jax.sharding.NamedSharding):
        raise TypeError(
            "Expected a jax.sharding.NamedSharding on every leaf of the"
            " single-replica abstract state, got"
            f" {type(abstract_arr.sharding).__name__}."
        )
      single_replica_mesh = abstract_arr.sharding.mesh
      single_replica_partition_spec = abstract_arr.sharding.spec
      multi_replica_sharding = jax.sharding.NamedSharding(
          jax.sharding.Mesh(
              # Leading dimension is inferred (-1): all devices of the global
              # mesh divided by the single-replica mesh size.
              devices=global_mesh.devices.reshape(
                  -1, *single_replica_mesh.devices.shape
              ),
              axis_names=["replica", *single_replica_mesh.axis_names],
          ),
          spec=jax.sharding.PartitionSpec(*single_replica_partition_spec),
      )
      logging.info("Multi-replica sharding: %s", multi_replica_sharding)
      return jax.ShapeDtypeStruct(
          shape=abstract_arr.shape,
          dtype=abstract_arr.dtype,
          sharding=multi_replica_sharding,
      )

    return jax.tree.map(
        _multi_replica_sharding,
        single_replica_abstract_state,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import v1 as ocp
from orbax.checkpoint._src.testing.benchmarks.v1 import multi_slice_util


_REQUIRED_DEVICE_COUNT = 16


class MultiSliceUtilTest(parameterized.TestCase):
  """Tests for `multi_slice_util.get_multi_slice_abstract_state`."""

  def setUp(self):
    """Forces the required host device count and prepares a temp directory."""
    self._prev_xla_flags = os.environ.get('XLA_FLAGS')
    # Register the restore as a cleanup before anything that can skip or fail:
    # unittest does not call tearDown() when setUp() raises (skipTest
    # included), but it still runs registered cleanups. This keeps the
    # modified XLA_FLAGS from leaking into the rest of the process.
    self.addCleanup(self._restore_xla_flags)
    os.environ['XLA_FLAGS'] = (
        self._prev_xla_flags or ''
    ) + f' --xla_force_host_platform_device_count={_REQUIRED_DEVICE_COUNT}'
    super().setUp()
    if jax.local_device_count() != _REQUIRED_DEVICE_COUNT:
      self.skipTest(
          f'Test requires {_REQUIRED_DEVICE_COUNT} local devices, but only'
          f' {jax.local_device_count()} are available. Set XLA_FLAGS='
          f'"--xla_force_host_platform_device_count={_REQUIRED_DEVICE_COUNT}"'
          ' before JAX initializes.'
      )
    self.directory = epath.Path(self.create_tempdir().full_path)

  def _restore_xla_flags(self):
    """Puts XLA_FLAGS back to the value it had before setUp() ran."""
    if self._prev_xla_flags is None:
      os.environ.pop('XLA_FLAGS', None)
    else:
      os.environ['XLA_FLAGS'] = self._prev_xla_flags

  def test_get_multi_slice_abstract_state(self):
    """Checks mesh shape and partition spec of the expanded shardings."""
    # Setup real checkpoint and sharding config
    pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}}
    ref_ckpt_path = self.directory / 'ref_ckpt'
    ocp.save_pytree(ref_ckpt_path, pytree)

    sharding_config = {
        'a': {
            'shape': [32],
            'dtype': 'int32',
            'sharding': {
                'mesh': {'shape': [4], 'axes': ['model']},
                'spec': ['model'],
            },
        },
        'b.c': {
            'shape': [8, 8],
            'dtype': 'float32',
            'sharding': {
                'mesh': {'shape': [4], 'axes': ['model']},
                'spec': [None, 'model'],
            },
        },
    }
    sharding_config_path = self.directory / 'sharding_config.json'
    sharding_config_path.write_text(json.dumps(sharding_config))
    # 16 devices arranged as 4 replicas x 4 model shards.
    global_mesh = jax.sharding.Mesh(
        np.array(jax.devices()).reshape((4, 4)), ('replica', 'model')
    )

    abstract_pytree = multi_slice_util.get_multi_slice_abstract_state(
        context=ocp.Context(),
        global_mesh=global_mesh,
        reference_checkpoint_path=ref_ckpt_path,
        reference_sharding_path=sharding_config_path,
    )
    self.assertEqual(
        {'replica': 4, 'model': 4}, abstract_pytree['a'].sharding.mesh.shape
    )
    self.assertEqual(
        jax.sharding.PartitionSpec('model'), abstract_pytree['a'].sharding.spec
    )
    self.assertEqual(
        {'replica': 4, 'model': 4},
        abstract_pytree['b']['c'].sharding.mesh.shape,
    )
    self.assertEqual(
        jax.sharding.PartitionSpec(None, 'model'),
        abstract_pytree['b']['c'].sharding.spec,
    )


if __name__ == '__main__':
absltest.main()
Loading
Loading