@@ -780,6 +780,11 @@ def test_load_checkpoint_resharding_jax(self):
780780 if len (devices ) < 1 :
781781 self .skipTest ("Test requires at least 1 JAX device" )
782782
783+ # Skip test if there are more than 2 devices, as these tests are
784+ # designed for 2-device scenarios and may not work with more devices
785+ if len (devices ) > 2 :
786+ self .skipTest (f"Test for 2 devices, but { len (devices )} available" )
787+
783788 num_devices = min (2 , len (devices ))
784789
785790 # Configure JAX to use virtual devices if needed
@@ -796,7 +801,7 @@ def test_load_checkpoint_resharding_jax(self):
796801 # Set up distribution based on available devices
797802 if num_devices >= 2 :
798803 # Multi-device distribution
799- device_mesh = DeviceMesh ((2 ,), axis_names = ["data" ])
804+ device_mesh = DeviceMesh ((num_devices ,), axis_names = ["data" ])
800805 layout_map = LayoutMap (device_mesh )
801806 layout_map ["dense_layer/kernel" ] = TensorLayout (
802807 axes = ("data" , None )
@@ -930,6 +935,11 @@ def test_distributed_checkpoint_directory_structure(self):
930935 if len (devices ) < 1 :
931936 self .skipTest ("Test requires at least 1 JAX device" )
932937
938+ # Skip test if more than 2 devices, as these tests are designed
939+ # for 2-device scenarios and may not work correctly with more devices
940+ if len (devices ) > 2 :
941+ self .skipTest (f"Test requires 2 devices, found { len (devices )} " )
942+
933943 num_devices = min (2 , len (devices ))
934944
935945 # Configure JAX to use virtual devices if needed
@@ -946,7 +956,7 @@ def test_distributed_checkpoint_directory_structure(self):
946956 # Set up distribution based on available devices
947957 if num_devices >= 2 :
948958 # Multi-device distribution for distributed checkpointing test
949- device_mesh = DeviceMesh ((2 ,), axis_names = ["data" ])
959+ device_mesh = DeviceMesh ((num_devices ,), axis_names = ["data" ])
950960 layout_map = LayoutMap (device_mesh )
951961 layout_map ["dense_layer/kernel" ] = TensorLayout (
952962 axes = ("data" , None )
0 commit comments