From 06c17eb506df1294d19ad69abe0aa465aa98e329 Mon Sep 17 00:00:00 2001
From: localai-bot <139863280+localai-bot@users.noreply.github.com>
Date: Sun, 22 Feb 2026 20:11:29 +0000
Subject: [PATCH] feat: add RDMA support to MLX backend via mlx-jaccl-cluster
 integration

---
 backend/python/mlx/backend.py | 76 ++++++++++++++++++++++++++++++++---
 core/cli/run.go               | 10 ++++-
 2 files changed, 79 insertions(+), 7 deletions(-)

diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py
index aaa0d6f347f8..63417ab37756 100644
--- a/backend/python/mlx/backend.py
+++ b/backend/python/mlx/backend.py
@@ -41,6 +41,42 @@ def is_int(s):
     except ValueError:
         return False
 
+def parse_rdma_options(options_list):
+    """
+    Parse RDMA-specific options from the options list.
+
+    Expected format: "mlx_rdma.enabled:true", "mlx_rdma.ctrl_host:0.0.0.0", etc.
+    Returns a dict if RDMA is enabled, None otherwise.
+    """
+    rdma_config = {}
+    rdma_enabled = False
+
+    for opt in options_list:
+        if ":" not in opt:
+            continue
+        key, value = opt.split(":", 1)
+
+        # Check if this is an RDMA option
+        if key.startswith("mlx_rdma."):
+            # Extract the actual option name (e.g., "mlx_rdma.enabled" -> "enabled")
+            option_name = key[len("mlx_rdma."):]
+
+            # Convert value to appropriate type
+            if is_float(value):
+                value = float(value)
+            elif is_int(value):
+                value = int(value)
+            elif value.lower() in ["true", "false"]:
+                value = value.lower() == "true"
+
+            rdma_config[option_name] = value
+
+            # Track if RDMA is explicitly enabled
+            if option_name == "enabled" and value:
+                rdma_enabled = True
+
+    return rdma_config if rdma_enabled else None
+
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     """
@@ -98,6 +134,30 @@ async def LoadModel(self, request, context):
 
             print(f"Options: {self.options}", file=sys.stderr)
 
+            # Parse RDMA options and initialize distributed training if enabled
+            rdma_config = parse_rdma_options(options)
+            if rdma_config:
+                print(f"Initializing MLX distributed training with RDMA: {rdma_config}", file=sys.stderr)
+
+                # Set environment variables for mlx-jaccl-cluster if provided
+                if "ctrl_host" in rdma_config:
+                    os.environ["JACCL_CTRL_HOST"] = str(rdma_config["ctrl_host"])
+                if "ctrl_port" in rdma_config:
+                    os.environ["JACCL_CTRL_PORT"] = str(rdma_config["ctrl_port"])
+
+                # Initialize MLX distributed backend (JACCL for RDMA)
+                try:
+                    mx.distributed.init(backend="jaccl")
+                    print(f"MLX distributed training initialized. Rank: {mx.distributed.rank()}, World size: {mx.distributed.world_size()}", file=sys.stderr)
+                    self.use_rdma = True
+                    self.rank = mx.distributed.rank()
+                    self.world_size = mx.distributed.world_size()
+                except Exception as e:
+                    print(f"Failed to initialize MLX distributed training: {e}", file=sys.stderr)
+                    self.use_rdma = False
+            else:
+                self.use_rdma = False
+
             # Build tokenizer config for MLX using options
             tokenizer_config = {}
 
@@ -120,7 +180,15 @@ async def LoadModel(self, request, context):
                 self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
             else:
                 self.model, self.tokenizer = load(request.Model)
-
+
+            # Shard model across RDMA workers if enabled
+            if self.use_rdma:
+                try:
+                    self.model = self.model.shard(self.world_size)
+                    print(f"Model sharded across {self.world_size} workers", file=sys.stderr)
+                except Exception as e:
+                    print(f"Failed to shard model: {e}", file=sys.stderr)
+
             # Initialize thread-safe LRU prompt cache for efficient generation
             max_cache_entries = self.options.get("max_cache_entries", 10)
             self.max_kv_size = self.options.get("max_kv_size", None)
@@ -321,10 +389,6 @@ def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
             return tokens.tolist()
         return list(tokens)
-
-
-
-
 
     def _build_generation_params(self, request, default_max_tokens=200):
         """
         Build generation parameters from request attributes and options.
@@ -447,4 +511,4 @@ async def serve(address):
     )
 
     args = parser.parse_args()
-    asyncio.run(serve(args.addr))
+    asyncio.run(serve(args.addr))
\ No newline at end of file
diff --git a/core/cli/run.go b/core/cli/run.go
index a67b35fadc41..cfc3632f1e73 100644
--- a/core/cli/run.go
+++ b/core/cli/run.go
@@ -145,6 +145,14 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 			os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
 			xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
 		}),
+		// Add MLX RDMA support if enabled via options
+		config.WithTunnelCallback(func(tunnels []string) {
+			if os.Getenv("MLX_RDMA_ENABLED") == "true" {
+				tunnelEnvVar := strings.Join(tunnels, ",")
+				os.Setenv("MLX_GRPC_SERVERS", tunnelEnvVar)
+				xlog.Debug("setting MLX_GRPC_SERVERS", "value", tunnelEnvVar)
+			}
+		}),
 	}
 
 	if r.DisableMetricsEndpoint {
@@ -305,4 +313,4 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 	})
 
 	return appHTTP.Start(r.Address)
-}
+}
\ No newline at end of file