Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 70 additions & 6 deletions backend/python/mlx/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ def is_int(s):
except ValueError:
return False

def parse_rdma_options(options_list):
"""
Parse RDMA-specific options from the options list.

Expected format: "mlx_rdma.enabled=true", "mlx_rdma.ctrl_host=0.0.0.0", etc.
Returns a dict if RDMA is enabled, None otherwise.
"""
rdma_config = {}
rdma_enabled = False

for opt in options_list:
if ":" not in opt:
continue
key, value = opt.split(":", 1)

# Check if this is an RDMA option
if key.startswith("mlx_rdma."):
# Extract the actual option name (e.g., "mlx_rdma.enabled" -> "enabled")
option_name = key[len("mlx_rdma."):]

# Convert value to appropriate type
if is_float(value):
value = float(value)
elif is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"

rdma_config[option_name] = value

# Track if RDMA is explicitly enabled
if option_name == "enabled" and value:
rdma_enabled = True

return rdma_config if rdma_enabled else None

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
Expand Down Expand Up @@ -98,6 +134,30 @@ async def LoadModel(self, request, context):

print(f"Options: {self.options}", file=sys.stderr)

# Parse RDMA options and initialize distributed training if enabled
rdma_config = parse_rdma_options(options)
if rdma_config:
print(f"Initializing MLX distributed training with RDMA: {rdma_config}", file=sys.stderr)

# Set environment variables for mlx-jaccl-cluster if provided
if "ctrl_host" in rdma_config:
os.environ["JACCL_CTRL_HOST"] = str(rdma_config["ctrl_host"])
if "ctrl_port" in rdma_config:
os.environ["JACCL_CTRL_PORT"] = str(rdma_config["ctrl_port"])

# Initialize MLX distributed backend (JACCL for RDMA)
try:
mx.distributed.init(backend="jaccl")
print(f"MLX distributed training initialized. Rank: {mx.distributed.rank()}, World size: {mx.distributed.world_size()}", file=sys.stderr)
self.use_rdma = True
self.rank = mx.distributed.rank()
self.world_size = mx.distributed.world_size()
except Exception as e:
print(f"Failed to initialize MLX distributed training: {e}", file=sys.stderr)
self.use_rdma = False
else:
self.use_rdma = False

# Build tokenizer config for MLX using options
tokenizer_config = {}

Expand All @@ -120,7 +180,15 @@ async def LoadModel(self, request, context):
self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
else:
self.model, self.tokenizer = load(request.Model)


# Shard model across RDMA workers if enabled
if self.use_rdma:
try:
self.model = self.model.shard(self.world_size)
print(f"Model sharded across {self.world_size} workers", file=sys.stderr)
except Exception as e:
print(f"Failed to shard model: {e}", file=sys.stderr)

# Initialize thread-safe LRU prompt cache for efficient generation
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
Expand Down Expand Up @@ -321,10 +389,6 @@ def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
return tokens.tolist()
return list(tokens)





def _build_generation_params(self, request, default_max_tokens=200):
"""
Build generation parameters from request attributes and options.
Expand Down Expand Up @@ -447,4 +511,4 @@ async def serve(address):
)
args = parser.parse_args()

asyncio.run(serve(args.addr))
asyncio.run(serve(args.addr))
10 changes: 9 additions & 1 deletion core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
}),
// Add MLX RDMA support if enabled via options
config.WithTunnelCallback(func(tunnels []string) {
if os.Getenv("MLX_RDMA_ENABLED") == "true" {
tunnelEnvVar := strings.Join(tunnels, ",")
os.Setenv("MLX_GRPC_SERVERS", tunnelEnvVar)
xlog.Debug("setting MLX_GRPC_SERVERS", "value", tunnelEnvVar)
}
}),
}

if r.DisableMetricsEndpoint {
Expand Down Expand Up @@ -305,4 +313,4 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
})

return appHTTP.Start(r.Address)
}
}
Loading