From a4f7cbd394ff6018dd7fc96502e1e377793655d3 Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 11:26:52 -0500 Subject: [PATCH 01/21] Add torchrun DDP support for trainer pipelines --- comlrl/trainers/actor_critic/ac_base.py | 43 +++++++++++- comlrl/trainers/actor_critic/iac.py | 23 +++++-- comlrl/trainers/actor_critic/maac.py | 25 +++++-- comlrl/trainers/reinforce/magrpo.py | 38 +++++++++-- comlrl/utils/__init__.py | 16 +++++ comlrl/utils/distributed.py | 90 +++++++++++++++++++++++++ 6 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 comlrl/utils/distributed.py diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index 49eea91..2db57b8 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -5,6 +5,12 @@ import wandb from tqdm import tqdm # type: ignore from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +try: + from datasets import IterableDataset as HFIterableDataset +except Exception: # pragma: no cover + HFIterableDataset = None class ActorCriticTrainerBase: @@ -139,7 +145,9 @@ def _summarize_rollout_metrics(self, rollouts: List[Any]) -> Dict[str, float]: return metrics def _iter_dataloader(self, dataloader, epoch: int, total_epochs: int): - if getattr(self, "verbose", True): + dist_env = getattr(self, "dist_env", None) + is_main = bool(getattr(dist_env, "is_main", True)) + if getattr(self, "verbose", True) and is_main: return enumerate( tqdm( dataloader, @@ -165,7 +173,12 @@ def _on_epoch_end( epoch_metrics: Dict[str, List[float]], ) -> None: summary = self._summarize_epoch_metrics(epoch_metrics) - if summary and getattr(self, "verbose", True): + dist_env = getattr(self, "dist_env", None) + if ( + summary + and getattr(self, "verbose", True) + and getattr(dist_env, "is_main", True) + ): print(f"Epoch {epoch + 1}/{total_epochs} metrics: {summary}") def _tag_metrics( @@ -177,6 +190,9 @@ def _tag_metrics( def _log_metrics(self, metrics: Dict[str, float]) -> None: if not metrics: return + dist_env = getattr(self, "dist_env", None) + if dist_env is not None and dist_env.enabled and not dist_env.is_main: + return if self.wandb_initialized and wandb is not None: wandb.log(metrics, step=self.env_step) @@ -246,10 +262,27 @@ def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None: def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Training requires a dataset.") + dist_env = getattr(self, "dist_env", None) + sampler = None + if ( + dist_env is not None + and dist_env.enabled + and ( + HFIterableDataset is None + or not isinstance(self.train_dataset, HFIterableDataset) + ) + ): + sampler = DistributedSampler( + self.train_dataset, + num_replicas=dist_env.world_size, + rank=dist_env.rank, + shuffle=False, + ) return DataLoader( self.train_dataset, batch_size=1, shuffle=False, + sampler=sampler, collate_fn=lambda batch: batch, ) @@ -264,6 +297,9 @@ def get_eval_dataloader(self) -> Optional[DataLoader]: ) def evaluate(self) -> Dict[str, float]: + dist_env = getattr(self, "dist_env", None) + if dist_env is not None and dist_env.enabled and not dist_env.is_main: + return {} if self.eval_dataset is None: return {} @@ -304,6 +340,9 @@ def train(self) -> None: for epoch in range(total_epochs): epoch_metrics = defaultdict(list) dataloader = self.get_train_dataloader() + sampler = getattr(dataloader, "sampler", None) + if isinstance(sampler, DistributedSampler): + 
sampler.set_epoch(epoch) it = self._iter_dataloader(dataloader, epoch, total_epochs) for batch_idx, batch in it: if ( diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 60d3e8c..1d6db9f 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -10,6 +10,7 @@ from datasets import Dataset, IterableDataset from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase from comlrl.models.actor_critic import CausalLMWithValueHead +from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -161,8 +162,8 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dist_env = init_distributed() + self.device = self.dist_env.device self.agents: List[CausalLMWithValueHead] = [] self.critics: List[CausalLMWithValueHead] = [] @@ -274,6 +275,11 @@ def __init__( apply_tokenizer_specials(tok, models) else: apply_tokenizer_specials(self.tokenizer, [*self.agents, *self.critics]) + + if self.dist_env.enabled: + self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] + self.critics = [wrap_ddp(critic, self.dist_env) for critic in self.critics] + self.agent_optimizers = [] self.critic_optimizers = [] @@ -297,7 +303,7 @@ def __init__( self.wandb_config = wandb_config self.wandb_initialized = False self._last_train_log_step = -1 - if wandb_config is not None: + if wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.verbose = True if isinstance(self.wandb_config, dict): @@ -308,6 +314,8 @@ def __init__( self.verbose = bool(out.get("verbose")) def _init_wandb(self) -> None: + if not self.dist_env.is_main: + return if self.wandb_initialized: return if wandb is None: @@ -1025,16 +1033,18 @@ def _update( return averaged def save_model(self, output_dir: str) -> None: + if self.dist_env.enabled and not self.dist_env.is_main: + return os.makedirs(output_dir, exist_ok=True) if self.args.num_agents == 1: - actor = self.agents[0] + actor = unwrap_model(self.agents[0]) actor.model.save_pretrained(output_dir) if actor.value_head is not None: torch.save( actor.value_head.state_dict(), os.path.join(output_dir, "value_head.pt"), ) - critic = self.critics[0] if self.critics else None + critic = unwrap_model(self.critics[0]) if self.critics else None if critic is not None: critic_dir = os.path.join(output_dir, "critic") os.makedirs(critic_dir, exist_ok=True) @@ -1046,6 +1056,7 @@ def save_model(self, output_dir: str) -> None: ) else: for agent_idx, actor in enumerate(self.agents): + actor = unwrap_model(actor) agent_dir = os.path.join(output_dir, f"agent_{agent_idx}") os.makedirs(agent_dir, exist_ok=True) actor.model.save_pretrained(agent_dir) @@ -1056,7 +1067,7 @@ def save_model(self, output_dir: str) -> None: ) if not self.critics or agent_idx >= len(self.critics): continue - critic = self.critics[agent_idx] + critic = unwrap_model(self.critics[agent_idx]) critic_dir = os.path.join(agent_dir, "critic") os.makedirs(critic_dir, exist_ok=True) critic.model.save_pretrained(critic_dir) diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 380c5b3..a904633 
100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -11,6 +11,7 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase import wandb from comlrl.models.actor_critic import CausalLMWithValueHead +from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -135,7 +136,8 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dist_env = init_distributed() + self.device = self.dist_env.device tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) if isinstance(tokenizers, list): @@ -234,6 +236,11 @@ def __init__( apply_tokenizer_specials(self.tokenizers[0], [self.critics[0]]) else: apply_tokenizer_specials(self.tokenizer, [*self.agents, self.critics[0]]) + + if self.dist_env.enabled: + self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] + self.critics = [wrap_ddp(self.critics[0], self.dist_env)] + self.formatters = build_formatters(formatters, self.args.num_agents) try: self._reward_signature = inspect.signature(reward_func) @@ -258,7 +265,7 @@ def __init__( self.wandb_initialized = False self.env_step = 0 self._last_train_log_step = -1 - if wandb_config is not None: + if wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.verbose = True if isinstance(self.wandb_config, dict): @@ -269,6 +276,8 @@ def __init__( self.verbose = bool(out.get("verbose")) def _init_wandb(self) -> None: + if not self.dist_env.is_main: + return if self.wandb_config is None: self.wandb_config = {} wandb_project = self.wandb_config.get("project", "comlrl") @@ -948,22 +957,26 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: self._log_metrics(epoch_log) summary = self._summarize_epoch_metrics(epoch_metrics) - if summary and getattr(self, "verbose", True): + if summary and getattr(self, "verbose", True) and self.dist_env.is_main: to_print = epoch_log if epoch_log else summary print(f"Epoch {epoch + 1}/{total_epochs} metrics: {to_print}") def save_model(self, output_dir: str) -> None: + if self.dist_env.enabled and not self.dist_env.is_main: + return os.makedirs(output_dir, exist_ok=True) for agent_idx, actor in enumerate(self.agents): + actor = unwrap_model(actor) agent_dir = os.path.join(output_dir, f"agent_{agent_idx}") os.makedirs(agent_dir, exist_ok=True) actor.model.save_pretrained(agent_dir) critic_dir = os.path.join(output_dir, "critic") os.makedirs(critic_dir, exist_ok=True) - self.critics[0].model.save_pretrained(critic_dir) - if self.critics[0].value_head is not None: + critic = unwrap_model(self.critics[0]) + critic.model.save_pretrained(critic_dir) + if critic.value_head is not None: torch.save( - self.critics[0].value_head.state_dict(), + critic.value_head.state_dict(), os.path.join(critic_dir, "value_head.pt"), ) if self.tokenizers: diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index d6289b4..64382a8 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -10,9 +10,11 @@ import wandb from datasets import Dataset, IterableDataset from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm # 
type: ignore from transformers import PreTrainedModel, PreTrainedTokenizerBase +from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import infer_model_name, resolve_model_sources from comlrl.utils.reward_utils import call_reward_function @@ -144,7 +146,8 @@ def __init__( eval_aggregator: Optional[Callable] = None, args: Optional[MAGRPOConfig] = None, ): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.dist_env = init_distributed() + self.device = self.dist_env.device self.args = args if args is not None else self.default_config_cls() if agent_model is None and agents is None: @@ -249,6 +252,9 @@ def __init__( self.eval_aggregator = eval_aggregator self.external_transition = external_transition + if self.dist_env.enabled: + self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] + self.optimizers = [ torch.optim.AdamW( agent.parameters(), @@ -259,7 +265,7 @@ def __init__( self.wandb_config = wandb_config self.wandb_initialized = False - if self.wandb_config is not None: + if self.wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.dataset_type = dataset_type or None @@ -285,6 +291,8 @@ def __init__( def _init_wandb(self): """Initialize Weights & Biases for tracking with multi-turn config.""" + if not self.dist_env.is_main: + return if not self.wandb_initialized: if self.wandb_config is None: self.wandb_config = {} @@ -398,12 +406,23 @@ def get_train_dataloader(self) -> DataLoader: """Returns the training DataLoader.""" if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") + sampler = None + if self.dist_env.enabled and not isinstance( + self.train_dataset, IterableDataset + ): + sampler = DistributedSampler( + self.train_dataset, + num_replicas=self.dist_env.world_size, + rank=self.dist_env.rank, + shuffle=False, + ) return DataLoader( self.train_dataset, batch_size=1, collate_fn=lambda examples: examples, shuffle=False, + sampler=sampler, drop_last=False, num_workers=0, ) @@ -432,6 +451,8 @@ def evaluate(self, num_eval_samples: int = 4) -> Dict[str, float]: Returns: Dictionary containing evaluation metrics """ + if self.dist_env.enabled and not self.dist_env.is_main: + return {} if self.eval_dataset is None: return {} @@ -667,7 +688,10 @@ def train(self, **kwargs): ] # immediate rewards epoch_turn_returns = [[] for _ in range(self.args.num_turns)] # returns dl = self.get_train_dataloader() - if getattr(self, "verbose", True): + sampler = getattr(dl, "sampler", None) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(epoch) + if getattr(self, "verbose", True) and self.dist_env.is_main: it = enumerate( tqdm( dl, @@ -1041,7 +1065,8 @@ def _generate_completions( Returns: Dict: A dictionary containing generated completions and associated data """ - device = agent.device + agent_module = unwrap_model(agent) + device = next(agent_module.parameters()).device # Apply the appropriate formatter to create prompts from batch items if prompts_override is not None: @@ -1070,7 +1095,7 @@ def _generate_completions( if self.tokenizer is None: raise ValueError("Tokenizer is required for generating completions") tokenizer = self.tokenizers[agent_idx] - apply_tokenizer_specials(tokenizer, [agent]) + apply_tokenizer_specials(tokenizer, [agent_module]) pad_id = tokenizer.pad_token_id prompt_encodings = tokenizer( @@ -1499,9 +1524,12 @@ def save_model(self, output_dir): Args: 
output_dir: Directory to save the models to """ + if self.dist_env.enabled and not self.dist_env.is_main: + return os.makedirs(output_dir, exist_ok=True) for agent_idx, agent in enumerate(self.agents): + agent = unwrap_model(agent) agent_dir = f"{output_dir}/agent_{agent_idx}" os.makedirs(agent_dir, exist_ok=True) diff --git a/comlrl/utils/__init__.py b/comlrl/utils/__init__.py index dcc9bf6..4494d4a 100644 --- a/comlrl/utils/__init__.py +++ b/comlrl/utils/__init__.py @@ -2,6 +2,15 @@ from .model_loading import infer_model_name, resolve_model_sources from .reward_processor import RewardProcessors from .reward_utils import call_reward_function, normalize_reward_lengths +from .distributed import ( + DistributedContext, + all_gather_objects, + barrier, + init_distributed, + is_main_process, + unwrap_model, + wrap_ddp, +) from .tokenizer_utils import ( apply_tokenizer_specials, ensure_pad_token, @@ -22,4 +31,11 @@ "RewardProcessors", "call_reward_function", "normalize_reward_lengths", + "DistributedContext", + "init_distributed", + "wrap_ddp", + "unwrap_model", + "is_main_process", + "barrier", + "all_gather_objects", ] diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py new file mode 100644 index 0000000..4ac62aa --- /dev/null +++ b/comlrl/utils/distributed.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, List, Optional + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +@dataclass(frozen=True) +class DistributedContext: + enabled: bool + rank: int + world_size: int + local_rank: int + is_main: bool + device: torch.device + + +def init_distributed(backend: Optional[str] = None) -> DistributedContext: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", str(rank))) + enabled = world_size > 1 + + if torch.cuda.is_available(): + if enabled: + device_count = max(1, torch.cuda.device_count()) + local_rank = local_rank % device_count + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}" if enabled else "cuda") + else: + device = torch.device("cpu") + + if enabled and not dist.is_initialized(): + backend_name = backend or ("nccl" if device.type == "cuda" else "gloo") + dist.init_process_group(backend=backend_name, rank=rank, world_size=world_size) + + return DistributedContext( + enabled=enabled, + rank=rank, + world_size=world_size, + local_rank=local_rank, + is_main=(rank == 0), + device=device, + ) + + +def wrap_ddp( + model: torch.nn.Module, + ctx: DistributedContext, + *, + find_unused_parameters: bool = False, +) -> torch.nn.Module: + if not ctx.enabled: + return model + if isinstance(model, DDP): + return model + kwargs = { + "find_unused_parameters": find_unused_parameters, + } + if ctx.device.type == "cuda": + kwargs["device_ids"] = [ctx.local_rank] + kwargs["output_device"] = ctx.local_rank + return DDP(model, **kwargs) + + +def unwrap_model(model: Any) -> Any: + return model.module if isinstance(model, DDP) else model + + +def is_main_process(ctx: Optional[DistributedContext]) -> bool: + if ctx is None: + return True + return bool(ctx.is_main) + + +def barrier(ctx: Optional[DistributedContext]) -> None: + if ctx is not None and ctx.enabled and dist.is_initialized(): + dist.barrier() + + +def all_gather_objects(obj: Any, ctx: Optional[DistributedContext]) -> List[Any]: + if ctx is None or not ctx.enabled: + 
return [obj] + gathered: List[Any] = [None for _ in range(ctx.world_size)] + dist.all_gather_object(gathered, obj) + return gathered From c48a92e14ecdd1fbc117cae9c388e170d7e1dc31 Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 13:39:13 -0500 Subject: [PATCH 02/21] support dual parallel modes with torchrun and device schedulers --- comlrl/schedulers/__init__.py | 4 + comlrl/schedulers/device_scheduler.py | 102 ++++++++++++++++++++++ comlrl/schedulers/torchrun_scheduler.py | 43 ++++++++++ comlrl/trainers/actor_critic/ac_base.py | 6 +- comlrl/trainers/actor_critic/iac.py | 107 ++++++++++++++++++------ comlrl/trainers/actor_critic/maac.py | 98 +++++++++++++++++----- comlrl/trainers/reinforce/magrpo.py | 49 +++++++++-- comlrl/utils/__init__.py | 6 ++ comlrl/utils/distributed.py | 30 +++++++ 9 files changed, 388 insertions(+), 57 deletions(-) create mode 100644 comlrl/schedulers/__init__.py create mode 100644 comlrl/schedulers/device_scheduler.py create mode 100644 comlrl/schedulers/torchrun_scheduler.py diff --git a/comlrl/schedulers/__init__.py b/comlrl/schedulers/__init__.py new file mode 100644 index 0000000..69af71e --- /dev/null +++ b/comlrl/schedulers/__init__.py @@ -0,0 +1,4 @@ +from .device_scheduler import DeviceScheduler +from .torchrun_scheduler import TorchrunScheduler + +__all__ = ["DeviceScheduler", "TorchrunScheduler"] diff --git a/comlrl/schedulers/device_scheduler.py b/comlrl/schedulers/device_scheduler.py new file mode 100644 index 0000000..bab08f3 --- /dev/null +++ b/comlrl/schedulers/device_scheduler.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Iterable, List, Optional, Sequence, Tuple, Union + +import torch + +DeviceSpec = Union[str, Sequence[str]] + + +class DeviceScheduler: + @staticmethod + def assign_devices( + num_agents: int, + agent_devices: Optional[DeviceSpec], + critic_devices: Optional[DeviceSpec], + *, + use_separate_critic: bool, + ) -> Tuple[List[torch.device], List[torch.device]]: + agent_list = DeviceScheduler.resolve_devices( + agent_devices, num_agents, kind="agent_devices" + ) + if use_separate_critic: + critic_spec = ( + critic_devices if critic_devices is not None else agent_devices + ) + critic_list = DeviceScheduler.resolve_devices( + critic_spec, num_agents, kind="critic_devices" + ) + else: + critic_list = list(agent_list) + return agent_list, critic_list + + @staticmethod + def assign_shared_critic_device( + agent_devices: Sequence[torch.device], + critic_devices: Optional[DeviceSpec], + ) -> torch.device: + if critic_devices is None: + return agent_devices[0] + return DeviceScheduler.resolve_devices( + critic_devices, 1, kind="critic_devices" + )[0] + + @staticmethod + def resolve_devices( + spec: Optional[DeviceSpec], + num_devices: int, + *, + kind: str = "devices", + ) -> List[torch.device]: + if num_devices < 1: + raise ValueError(f"{kind} count must be >= 1.") + + if spec is None or (isinstance(spec, str) and spec.lower() == "auto"): + return DeviceScheduler._auto_devices(num_devices) + + if isinstance(spec, str): + return [torch.device(spec)] * num_devices + + if isinstance(spec, Sequence): + if len(spec) == 0: + raise ValueError(f"{kind} must be a non-empty list or 'auto'.") + if len(spec) == 1: + return [torch.device(spec[0])] * num_devices + if len(spec) != num_devices: + raise ValueError( + f"{kind} length ({len(spec)}) must be 1 or {num_devices}." 
+ ) + return [torch.device(s) for s in spec] + + raise ValueError(f"Unsupported {kind} spec: {spec!r}.") + + @staticmethod + def devices_disjoint(device_groups: Iterable[Sequence[torch.device]]) -> bool: + seen = set() + for group in device_groups: + for device in group: + key = DeviceScheduler._device_key(device) + if key in seen: + return False + seen.add(key) + return True + + @staticmethod + def _device_key(device: torch.device) -> str: + if device.type == "cuda": + return f"cuda:{device.index}" + return device.type + + @staticmethod + def _auto_devices(num_devices: int) -> List[torch.device]: + if not torch.cuda.is_available(): + return [torch.device("cpu")] * num_devices + + count = int(torch.cuda.device_count()) + if count < 1: + return [torch.device("cpu")] * num_devices + if count < num_devices: + return [torch.device("cuda:0")] * num_devices + + indices = list(range(count)) + return [torch.device(f"cuda:{idx}") for idx in indices[:num_devices]] diff --git a/comlrl/schedulers/torchrun_scheduler.py b/comlrl/schedulers/torchrun_scheduler.py new file mode 100644 index 0000000..f938136 --- /dev/null +++ b/comlrl/schedulers/torchrun_scheduler.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os +from typing import Optional + +import torch + +from comlrl.utils.distributed import DistributedContext, init_distributed, local_context + + +class TorchrunScheduler: + """Resolve and initialize process-level parallel execution mode.""" + + _VALID_MODES = {"auto", "ddp", "scheduler"} + + @staticmethod + def world_size_from_env() -> int: + return int(os.environ.get("WORLD_SIZE", "1")) + + @classmethod + def resolve_mode(cls, requested_mode: Optional[str]) -> str: + mode = str(requested_mode or "auto").strip().lower() + if mode not in cls._VALID_MODES: + raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") + + world_size = cls.world_size_from_env() + if mode == "auto": + return "ddp" if world_size > 1 else "scheduler" + if mode == "scheduler" and world_size > 1: + raise ValueError( + "parallel_mode='scheduler' requires WORLD_SIZE=1 (single process)." 
+ ) + return mode + + @staticmethod + def ddp_context() -> DistributedContext: + return init_distributed() + + @staticmethod + def scheduler_context( + device: Optional[torch.device] = None, + ) -> DistributedContext: + return local_context(device) diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index 2db57b8..cf181d6 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -70,6 +70,7 @@ def _encode_prompt( prompt: str, agent_idx: Optional[int] = None, tokenizer: Optional[Any] = None, + device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: tokenizer = tokenizer or self._get_tokenizer(agent_idx) encoded = tokenizer( @@ -77,9 +78,10 @@ def _encode_prompt( return_tensors="pt", truncation=True, ) + target_device = device or self.device return { - "input_ids": encoded["input_ids"].to(self.device), - "attention_mask": encoded["attention_mask"].to(self.device), + "input_ids": encoded["input_ids"].to(target_device), + "attention_mask": encoded["attention_mask"].to(target_device), } def _prepare_advantages(self, rollouts: List[Any]) -> None: diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 1d6db9f..b88b4a6 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -10,7 +10,11 @@ from datasets import Dataset, IterableDataset from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase from comlrl.models.actor_critic import CausalLMWithValueHead -from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp +from comlrl.schedulers import DeviceScheduler, TorchrunScheduler +from comlrl.utils.distributed import ( + unwrap_model, + wrap_ddp, +) from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -46,6 +50,9 @@ class IACConfig: value_head_hidden_dim: Optional[int] = None num_agents: int = 2 num_turns: int = 2 + parallel_mode: str = "auto" + agent_devices: Optional[Union[str, Sequence[str]]] = None + critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False discount: float = 0.9 num_generations: int = 1 @@ -83,6 +90,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") + mode = str(self.parallel_mode or "auto").lower() + if mode not in {"auto", "ddp", "scheduler"}: + raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") @dataclass @@ -162,8 +172,30 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() - self.dist_env = init_distributed() - self.device = self.dist_env.device + self.parallel_mode = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_mode", "auto") + ) + if self.parallel_mode == "ddp": + if ( + getattr(self.args, "agent_devices", None) is not None + or getattr(self.args, "critic_devices", None) is not None + ): + raise ValueError( + "agent_devices/critic_devices are only valid in parallel_mode='scheduler'." 
+ ) + self.dist_env = TorchrunScheduler.ddp_context() + self.device = self.dist_env.device + self.agent_devices = [self.device] * self.args.num_agents + self.critic_devices = [self.device] * self.args.num_agents + else: + self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( + self.args.num_agents, + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + use_separate_critic=self.args.use_separate_critic, + ) + self.device = self.agent_devices[0] + self.dist_env = TorchrunScheduler.scheduler_context(self.device) self.agents: List[CausalLMWithValueHead] = [] self.critics: List[CausalLMWithValueHead] = [] @@ -189,7 +221,7 @@ def __init__( expected_count=self.args.num_agents, model_label="agent_model", ) - for actor_source in actor_sources: + for idx, actor_source in enumerate(actor_sources): if actor_source is None: raise ValueError("A policy model identifier or instance is required.") if isinstance(actor_source, CausalLMWithValueHead): @@ -220,7 +252,7 @@ def __init__( value_head_hidden_dim=self.args.value_head_hidden_dim, attach_value_head=attach_value, ) - agent_model.to(self.device) + agent_model.to(self.agent_devices[idx]) self.agents.append(agent_model) if self.args.use_separate_critic: @@ -235,7 +267,7 @@ def __init__( expected_count=self.args.num_agents, model_label="critic_model", ) - for critic_source in critic_sources: + for idx, critic_source in enumerate(critic_sources): if isinstance(critic_source, CausalLMWithValueHead): critic_model = critic_source elif isinstance(critic_source, PreTrainedModel): @@ -262,7 +294,7 @@ def __init__( value_head_hidden_dim=self.args.critic_value_head_hidden_dim, attach_value_head=True, ) - critic_model.to(self.device) + critic_model.to(self.critic_devices[idx]) self.critics.append(critic_model) else: self.critics = [] @@ -313,6 +345,14 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _agent_device(self, agent_idx: int) -> torch.device: + return self.agent_devices[agent_idx] + + def _critic_device(self, agent_idx: int) -> torch.device: + if self.args.use_separate_critic: + return self.critic_devices[agent_idx] + return self.agent_devices[agent_idx] + def _init_wandb(self) -> None: if not self.dist_env.is_main: return @@ -424,7 +464,10 @@ def _generate_rollout( agent_idx: int, num_ret: int, ) -> Dict[str, Any]: - encoded_prompt = self._encode_prompt(prompt, agent_idx=agent_idx) + agent_device = self._agent_device(agent_idx) + encoded_prompt = self._encode_prompt( + prompt, agent_idx=agent_idx, device=agent_device + ) prompt_input_ids = encoded_prompt["input_ids"] prompt_attention_mask = encoded_prompt["attention_mask"] prompt_len = prompt_input_ids.size(1) @@ -461,17 +504,20 @@ def _generate_rollout( tokenizer.decode(seq[:resp_len], skip_special_tokens=True) ) - full_attention_mask = torch.ones_like(sequences, device=self.device) + full_attention_mask = torch.ones_like(sequences, device=agent_device) with torch.no_grad(): if self.args.use_separate_critic: critic_model = self.critics[agent_idx] if critic_model is None: raise RuntimeError("Critic model missing for agent.") + critic_device = self._critic_device(agent_idx) + critic_sequences = sequences.to(critic_device) + critic_mask = full_attention_mask.to(critic_device) value = self._value_for_critic_type( critic_model, - sequences, - full_attention_mask, + critic_sequences, + critic_mask, prompt_len, response_lens, ) @@ -564,7 +610,7 @@ def _collect_rollouts(self, item: Dict[str, 
Any]) -> List[RolloutSample]: value = data["values"][i] reward = float(rewards_matrix[agent_idx][i]) reward_tensor = torch.tensor( - [reward], device=self.device, dtype=torch.float32 + [reward], device=self._agent_device(agent_idx), dtype=torch.float32 ) returns = reward_tensor.clone() advantage = returns - value @@ -675,7 +721,9 @@ def _collect_rollouts_multi_turn( value = data["values"][0] reward_val = float(rewards_matrix[agent_idx][0]) reward_tensor = torch.tensor( - [reward_val], device=self.device, dtype=torch.float32 + [reward_val], + device=self._agent_device(agent_idx), + dtype=torch.float32, ) completion_text = data["completions"][0] @@ -727,7 +775,9 @@ def _collect_rollouts_multi_turn( immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future sample.returns = ( - torch.tensor([future], device=self.device).detach().cpu() + torch.tensor([future], device=self._agent_device(agent_idx)) + .detach() + .cpu() ) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None @@ -887,10 +937,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa actor_losses: List[torch.Tensor] = [] value_losses: List[torch.Tensor] = [] + agent_device = self._agent_device(agent_idx) + critic_device = self._critic_device(agent_idx) for sample in batch: - sequences = sample.full_input_ids.to(self.device).unsqueeze(0) - attention_mask = sample.attention_mask.to(self.device).unsqueeze(0) + sequences = sample.full_input_ids.to(agent_device).unsqueeze(0) + attention_mask = sample.attention_mask.to(agent_device).unsqueeze(0) # Policy log-prob uses full sequence; value uses prompt-only baseline. logprob, _ = self._policy_eval( @@ -904,10 +956,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa if self.args.use_separate_critic: if critic_model is None: raise RuntimeError("Critic model not initialised.") + critic_sequences = sequences.to(critic_device) + critic_attention_mask = attention_mask.to(critic_device) value = self._critic_eval( critic_model, - sequences, - attention_mask, + critic_sequences, + critic_attention_mask, sample.prompt_len, sample.response_len, ) @@ -920,10 +974,13 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa [sample.response_len], ) - old_value = sample.old_value.to(self.device, dtype=value.dtype) - old_logprob = sample.old_logprob.to(self.device) - advantage = sample.normalized_advantage.to(self.device, dtype=value.dtype) - returns = sample.returns.to(self.device, dtype=value.dtype) + value_device = value.device + old_value = sample.old_value.to(value_device, dtype=value.dtype) + old_logprob = sample.old_logprob.to(agent_device) + policy_advantage = sample.normalized_advantage.to( + agent_device, dtype=logprob.dtype + ) + returns = sample.returns.to(value_device, dtype=value.dtype) if ( not torch.isfinite(logprob).all() @@ -932,17 +989,17 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa raise FloatingPointError( "Encountered non-finite logprob during AC step." 
) - if not torch.isfinite(advantage).all(): + if not torch.isfinite(policy_advantage).all(): raise FloatingPointError("Advantage contains non-finite values.") if not torch.isfinite(returns).all(): raise FloatingPointError("Returns contain non-finite values.") - policy_loss = -(logprob * advantage) + policy_loss = -(logprob * policy_advantage) value_target = sample.metadata.get("value_target") if value_target is None: raise RuntimeError("value_target missing for critic update.") - value_target = value_target.to(self.device, dtype=value.dtype) + value_target = value_target.to(value_device, dtype=value.dtype) if ( self.args.value_clip_range is not None and not self.args.use_separate_critic diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index a904633..cadc340 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -11,7 +11,11 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase import wandb from comlrl.models.actor_critic import CausalLMWithValueHead -from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp +from comlrl.schedulers import DeviceScheduler, TorchrunScheduler +from comlrl.utils.distributed import ( + unwrap_model, + wrap_ddp, +) from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -42,6 +46,9 @@ class MAACConfig: num_agents: int = 2 num_generations: int = 1 num_turns: int = 2 + parallel_mode: str = "auto" + agent_devices: Optional[Union[str, Sequence[str]]] = None + critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False discount: float = 0.9 critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) @@ -77,6 +84,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") + mode = str(self.parallel_mode or "auto").lower() + if mode not in {"auto", "ddp", "scheduler"}: + raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") class MAACTrainer(ActorCriticTrainerBase): @@ -135,9 +145,32 @@ def __init__( self.eval_dataset = eval_dataset self.metrics_callback = metrics_callback self.model_config = model_config or {} - - self.dist_env = init_distributed() - self.device = self.dist_env.device + self.parallel_mode = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_mode", "auto") + ) + if self.parallel_mode == "ddp": + if ( + getattr(self.args, "agent_devices", None) is not None + or getattr(self.args, "critic_devices", None) is not None + ): + raise ValueError( + "agent_devices/critic_devices are only valid in parallel_mode='scheduler'." 
+ ) + self.dist_env = TorchrunScheduler.ddp_context() + self.device = self.dist_env.device + self.agent_devices = [self.device] * self.args.num_agents + self.critic_device = self.device + else: + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.args.num_agents, + kind="agent_devices", + ) + self.critic_device = DeviceScheduler.assign_shared_critic_device( + self.agent_devices, getattr(self.args, "critic_devices", None) + ) + self.device = self.agent_devices[0] + self.dist_env = TorchrunScheduler.scheduler_context(self.device) tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) if isinstance(tokenizers, list): @@ -156,7 +189,7 @@ def __init__( expected_count=self.args.num_agents, model_label="agent_model", ) - for actor_source in actor_sources: + for idx, actor_source in enumerate(actor_sources): if actor_source is None: raise ValueError("agent_model must be provided for MAAC.") if isinstance(actor_source, CausalLMWithValueHead): @@ -185,7 +218,7 @@ def __init__( attach_value_head=False, value_head_hidden_dim=None, ) - agent_model.to(self.device) + agent_model.to(self.agent_devices[idx]) self.agents.append(agent_model) critic_sources, _critic_name = resolve_model_sources( @@ -227,7 +260,7 @@ def __init__( "critic_value_head_hidden_dim" ), ) - critic_model_instance.to(self.device) + critic_model_instance.to(self.critic_device) self.critics: List[CausalLMWithValueHead] = [critic_model_instance] if self.tokenizers and len(self.tokenizers) == len(self.agents): @@ -275,6 +308,9 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _agent_device(self, agent_idx: int) -> torch.device: + return self.agent_devices[agent_idx] + def _init_wandb(self) -> None: if not self.dist_env.is_main: return @@ -386,7 +422,11 @@ def _build_critic_input( return base + "\n\n" + "\n\n".join(action_lines) def _critic_value_from_text(self, critic_input: str) -> Dict[str, Any]: - encoded = self._encode_prompt(critic_input, tokenizer=self._get_tokenizer(0)) + encoded = self._encode_prompt( + critic_input, + tokenizer=self._get_tokenizer(0), + device=self.critic_device, + ) ids = encoded["input_ids"] mask = encoded["attention_mask"] prompt_len = ids.size(1) @@ -400,7 +440,10 @@ def _critic_value_from_text(self, critic_input: str) -> Dict[str, Any]: } def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: - encoded_prompt = self._encode_prompt(prompt, agent_idx=agent_idx) + agent_device = self._agent_device(agent_idx) + encoded_prompt = self._encode_prompt( + prompt, agent_idx=agent_idx, device=agent_device + ) prompt_input_ids = encoded_prompt["input_ids"] prompt_attention_mask = encoded_prompt["attention_mask"] prompt_len = prompt_input_ids.size(1) @@ -442,7 +485,7 @@ def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: "prompt": prompt, "prompt_len": prompt_len, "sequences": sequences, - "attention_mask": torch.ones_like(sequences, device=self.device), + "attention_mask": torch.ones_like(sequences, device=agent_device), "response_lens": response_lens, "completions": completion_texts, } @@ -525,7 +568,9 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: attn = data["attention_mask"][i] resp_len = data["response_lens"][i] reward = float(rewards_matrix[agent_idx][i]) - reward_tensor = torch.tensor([reward], device=self.device) + reward_tensor = torch.tensor( + [reward], device=self._agent_device(agent_idx) + ) logprob, _ 
= self._policy_eval( self.agents[agent_idx], @@ -680,7 +725,9 @@ def _collect_rollouts_multi_turn( attn = data["attention_mask"][0] resp_len = data["response_lens"][0] reward_val = float(rewards_matrix[agent_idx][0]) - reward_tensor = torch.tensor([reward_val], device=self.device) + reward_tensor = torch.tensor( + [reward_val], device=self._agent_device(agent_idx) + ) logprob, _ = self._policy_eval( self.agents[agent_idx], @@ -743,7 +790,9 @@ def _collect_rollouts_multi_turn( immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future sample.returns = ( - torch.tensor([future], device=self.device).detach().cpu() + torch.tensor([future], device=self._agent_device(agent_idx)) + .detach() + .cpu() ) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None @@ -844,10 +893,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa actor_losses: List[torch.Tensor] = [] value_losses: List[torch.Tensor] = [] + agent_device = self._agent_device(agent_idx) + critic_device = self.critic_device for sample in batch: - sequences = sample.full_input_ids.to(self.device).unsqueeze(0) - attention_mask = sample.attention_mask.to(self.device).unsqueeze(0) + sequences = sample.full_input_ids.to(agent_device).unsqueeze(0) + attention_mask = sample.attention_mask.to(agent_device).unsqueeze(0) logprob, _ = self._policy_eval( agent_model, @@ -858,30 +909,33 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa output_values=False, ) - joint_ids = sample.metadata["joint_input_ids"].to(self.device) - joint_mask = sample.metadata["joint_attention_mask"].to(self.device) + joint_ids = sample.metadata["joint_input_ids"].to(critic_device) + joint_mask = sample.metadata["joint_attention_mask"].to(critic_device) joint_len = int(sample.metadata["joint_prompt_len"]) value = self._value_on_prompt_only( self.critics[0], joint_ids, joint_mask, joint_len ) - old_value = sample.old_value.to(self.device, dtype=value.dtype) - advantage = sample.normalized_advantage.to(self.device, dtype=value.dtype) + value_device = value.device + old_value = sample.old_value.to(value_device, dtype=value.dtype) + policy_advantage = sample.normalized_advantage.to( + agent_device, dtype=logprob.dtype + ) value_target = sample.metadata.get("value_target") if value_target is None: raise RuntimeError("value_target missing for critic update.") - returns = value_target.to(self.device, dtype=value.dtype) + returns = value_target.to(value_device, dtype=value.dtype) if not torch.isfinite(logprob).all(): raise FloatingPointError( "Encountered non-finite logprob during AC step." 
) - if not torch.isfinite(advantage).all(): + if not torch.isfinite(policy_advantage).all(): raise FloatingPointError("Advantage contains non-finite values.") if not torch.isfinite(returns).all(): raise FloatingPointError("Returns contain non-finite values.") - policy_loss = -(logprob * advantage) + policy_loss = -(logprob * policy_advantage) value_error = (returns - value) ** 2 actor_losses.append(policy_loss) diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 64382a8..d222068 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -14,7 +14,11 @@ from tqdm import tqdm # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizerBase -from comlrl.utils.distributed import init_distributed, unwrap_model, wrap_ddp +from comlrl.schedulers import DeviceScheduler, TorchrunScheduler +from comlrl.utils.distributed import ( + unwrap_model, + wrap_ddp, +) from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import infer_model_name, resolve_model_sources from comlrl.utils.reward_utils import call_reward_function @@ -33,6 +37,8 @@ class MAGRPOConfig: agent_learning_rate: float = 5.0e-6 logging_steps: int = 50 num_agents: int = 2 + parallel_mode: str = "auto" + agent_devices: Optional[Union[str, Sequence[str]]] = None # Sampling/generation num_generations: int = 4 @@ -81,6 +87,9 @@ def __post_init__(self) -> None: self.train_batch_size = self.rollout_buffer_size if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") + mode = str(self.parallel_mode or "auto").lower() + if mode not in {"auto", "ddp", "scheduler"}: + raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") @dataclass @@ -146,10 +155,21 @@ def __init__( eval_aggregator: Optional[Callable] = None, args: Optional[MAGRPOConfig] = None, ): - self.dist_env = init_distributed() - self.device = self.dist_env.device - self.args = args if args is not None else self.default_config_cls() + self.parallel_mode = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_mode", "auto") + ) + if self.parallel_mode == "ddp": + if getattr(self.args, "agent_devices", None) is not None: + raise ValueError( + "agent_devices is only valid in parallel_mode='scheduler'." 
+ ) + self.dist_env = TorchrunScheduler.ddp_context() + self.device = self.dist_env.device + else: + self.dist_env = TorchrunScheduler.scheduler_context() + self.device = self.dist_env.device + if agent_model is None and agents is None: raise ValueError("Either agent_model or agents must be provided.") if ( @@ -194,6 +214,16 @@ def __init__( self.num_agents = expected_count self.model_name = model_name + if self.parallel_mode == "ddp": + self.agent_devices = [self.device] * self.num_agents + else: + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.num_agents, + kind="agent_devices", + ) + self.device = self.agent_devices[0] + self.dist_env = TorchrunScheduler.scheduler_context(self.device) if actor_sources and all(isinstance(src, str) for src in actor_sources): from transformers import AutoModelForCausalLM @@ -213,6 +243,9 @@ def __init__( self.agents = list(actor_sources) self.critics = [] + for idx, agent in enumerate(self.agents): + agent.to(self.agent_devices[idx]) + tokenizers = resolve_tokenizers(agent_model, tokenizer, actor_sources) if isinstance(tokenizers, list): self.tokenizers = tokenizers @@ -673,9 +706,8 @@ def train(self, **kwargs): if self.wandb_config is not None and not self.wandb_initialized: self._init_wandb() - device = self.device - for agent in self.agents: - agent.to(device) + for agent_idx, agent in enumerate(self.agents): + agent.to(self.agent_devices[agent_idx]) agent.train() # Create the data pipeline for generating examples @@ -1429,7 +1461,8 @@ def _compute_loss_with_gradients(self, agent, completions_data, returns): Returns: torch.Tensor: The computed loss with gradients attached """ - device = agent.device + agent_module = unwrap_model(agent) + device = next(agent_module.parameters()).device # Make sure we have the correct number of rewards if len(returns) == 0: diff --git a/comlrl/utils/__init__.py b/comlrl/utils/__init__.py index 4494d4a..6272ff5 100644 --- a/comlrl/utils/__init__.py +++ b/comlrl/utils/__init__.py @@ -8,7 +8,10 @@ barrier, init_distributed, is_main_process, + local_context, + resolve_parallel_mode, unwrap_model, + world_size_from_env, wrap_ddp, ) from .tokenizer_utils import ( @@ -33,6 +36,9 @@ "normalize_reward_lengths", "DistributedContext", "init_distributed", + "resolve_parallel_mode", + "world_size_from_env", + "local_context", "wrap_ddp", "unwrap_model", "is_main_process", diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py index 4ac62aa..bc5f53b 100644 --- a/comlrl/utils/distributed.py +++ b/comlrl/utils/distributed.py @@ -19,6 +19,36 @@ class DistributedContext: device: torch.device +def world_size_from_env() -> int: + # Backward-compatible shim; prefer comlrl.schedulers.TorchrunScheduler. + from comlrl.schedulers import TorchrunScheduler + + return TorchrunScheduler.world_size_from_env() + + +def resolve_parallel_mode(requested_mode: Optional[str]) -> str: + # Backward-compatible shim; prefer comlrl.schedulers.TorchrunScheduler. 
+    from comlrl.schedulers import TorchrunScheduler
+
+    return TorchrunScheduler.resolve_mode(requested_mode)
+
+
+def local_context(device: Optional[torch.device] = None) -> DistributedContext:
+    if device is None:
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+    return DistributedContext(
+        enabled=False,
+        rank=0,
+        world_size=1,
+        local_rank=0,
+        is_main=True,
+        device=device,
+    )
+
+
 def init_distributed(backend: Optional[str] = None) -> DistributedContext:
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
     rank = int(os.environ.get("RANK", "0"))

From 3e3f64801c2a713a9767dd50735134b04d3f4993 Mon Sep 17 00:00:00 2001
From: N!no
Date: Sun, 15 Feb 2026 14:02:18 -0500
Subject: [PATCH 03/21] rename parallel_mode to parallel_training and remove
 legacy parallel wrappers

---
 comlrl/schedulers/torchrun_scheduler.py |  4 ++--
 comlrl/trainers/actor_critic/iac.py     | 14 +++++++-------
 comlrl/trainers/actor_critic/maac.py    | 14 +++++++-------
 comlrl/trainers/reinforce/magrpo.py     | 16 ++++++++--------
 comlrl/utils/__init__.py                |  4 ----
 comlrl/utils/distributed.py             | 14 --------------
 6 files changed, 24 insertions(+), 42 deletions(-)

diff --git a/comlrl/schedulers/torchrun_scheduler.py b/comlrl/schedulers/torchrun_scheduler.py
index f938136..940ca5f 100644
--- a/comlrl/schedulers/torchrun_scheduler.py
+++ b/comlrl/schedulers/torchrun_scheduler.py
@@ -21,14 +21,14 @@ def world_size_from_env() -> int:
     def resolve_mode(cls, requested_mode: Optional[str]) -> str:
         mode = str(requested_mode or "auto").strip().lower()
         if mode not in cls._VALID_MODES:
-            raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.")
+            raise ValueError("parallel_training must be one of: auto, ddp, scheduler.")
 
         world_size = cls.world_size_from_env()
         if mode == "auto":
             return "ddp" if world_size > 1 else "scheduler"
         if mode == "scheduler" and world_size > 1:
             raise ValueError(
-                "parallel_mode='scheduler' requires WORLD_SIZE=1 (single process)."
+                "parallel_training='scheduler' requires WORLD_SIZE=1 (single process)."
) return mode diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index b88b4a6..1ade28f 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -50,7 +50,7 @@ class IACConfig: value_head_hidden_dim: Optional[int] = None num_agents: int = 2 num_turns: int = 2 - parallel_mode: str = "auto" + parallel_training: str = "auto" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -90,9 +90,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_mode or "auto").lower() + mode = str(self.parallel_training or "auto").lower() if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") + raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") @dataclass @@ -172,16 +172,16 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() - self.parallel_mode = TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_mode", "auto") + self.parallel_training = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_training", "auto") ) - if self.parallel_mode == "ddp": + if self.parallel_training == "ddp": if ( getattr(self.args, "agent_devices", None) is not None or getattr(self.args, "critic_devices", None) is not None ): raise ValueError( - "agent_devices/critic_devices are only valid in parallel_mode='scheduler'." + "agent_devices/critic_devices are only valid in parallel_training='scheduler'." 
) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index cadc340..bf0b47c 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -46,7 +46,7 @@ class MAACConfig: num_agents: int = 2 num_generations: int = 1 num_turns: int = 2 - parallel_mode: str = "auto" + parallel_training: str = "auto" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -84,9 +84,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_mode or "auto").lower() + mode = str(self.parallel_training or "auto").lower() if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") + raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") class MAACTrainer(ActorCriticTrainerBase): @@ -145,16 +145,16 @@ def __init__( self.eval_dataset = eval_dataset self.metrics_callback = metrics_callback self.model_config = model_config or {} - self.parallel_mode = TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_mode", "auto") + self.parallel_training = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_training", "auto") ) - if self.parallel_mode == "ddp": + if self.parallel_training == "ddp": if ( getattr(self.args, "agent_devices", None) is not None or getattr(self.args, "critic_devices", None) is not None ): raise ValueError( - "agent_devices/critic_devices are only valid in parallel_mode='scheduler'." + "agent_devices/critic_devices are only valid in parallel_training='scheduler'." ) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index d222068..01dd175 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -37,7 +37,7 @@ class MAGRPOConfig: agent_learning_rate: float = 5.0e-6 logging_steps: int = 50 num_agents: int = 2 - parallel_mode: str = "auto" + parallel_training: str = "auto" agent_devices: Optional[Union[str, Sequence[str]]] = None # Sampling/generation @@ -87,9 +87,9 @@ def __post_init__(self) -> None: self.train_batch_size = self.rollout_buffer_size if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") - mode = str(self.parallel_mode or "auto").lower() + mode = str(self.parallel_training or "auto").lower() if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_mode must be one of: auto, ddp, scheduler.") + raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") @dataclass @@ -156,13 +156,13 @@ def __init__( args: Optional[MAGRPOConfig] = None, ): self.args = args if args is not None else self.default_config_cls() - self.parallel_mode = TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_mode", "auto") + self.parallel_training = TorchrunScheduler.resolve_mode( + getattr(self.args, "parallel_training", "auto") ) - if self.parallel_mode == "ddp": + if self.parallel_training == "ddp": if getattr(self.args, "agent_devices", None) is not None: raise ValueError( - "agent_devices is only valid in parallel_mode='scheduler'." + "agent_devices is only valid in parallel_training='scheduler'." 
) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device @@ -214,7 +214,7 @@ def __init__( self.num_agents = expected_count self.model_name = model_name - if self.parallel_mode == "ddp": + if self.parallel_training == "ddp": self.agent_devices = [self.device] * self.num_agents else: self.agent_devices = DeviceScheduler.resolve_devices( diff --git a/comlrl/utils/__init__.py b/comlrl/utils/__init__.py index 6272ff5..b303b2d 100644 --- a/comlrl/utils/__init__.py +++ b/comlrl/utils/__init__.py @@ -9,9 +9,7 @@ init_distributed, is_main_process, local_context, - resolve_parallel_mode, unwrap_model, - world_size_from_env, wrap_ddp, ) from .tokenizer_utils import ( @@ -36,8 +34,6 @@ "normalize_reward_lengths", "DistributedContext", "init_distributed", - "resolve_parallel_mode", - "world_size_from_env", "local_context", "wrap_ddp", "unwrap_model", diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py index bc5f53b..0e06717 100644 --- a/comlrl/utils/distributed.py +++ b/comlrl/utils/distributed.py @@ -19,20 +19,6 @@ class DistributedContext: device: torch.device -def world_size_from_env() -> int: - # Backward-compatible shim; prefer comlrl.schedulers.TorchrunScheduler. - from comlrl.schedulers import TorchrunScheduler - - return TorchrunScheduler.world_size_from_env() - - -def resolve_parallel_mode(requested_mode: Optional[str]) -> str: - # Backward-compatible shim; prefer comlrl.schedulers.TorchrunScheduler. - from comlrl.schedulers import TorchrunScheduler - - return TorchrunScheduler.resolve_mode(requested_mode) - - def local_context(device: Optional[torch.device] = None) -> DistributedContext: if device is None: if torch.cuda.is_available(): From ff98e590f177ee5797f239756385b81c49f1fb0e Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 14:50:32 -0500 Subject: [PATCH 04/21] add user guide for parallel training modes --- .../docs/user-guide/parallel_training.md | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 docs/content/docs/user-guide/parallel_training.md diff --git a/docs/content/docs/user-guide/parallel_training.md b/docs/content/docs/user-guide/parallel_training.md new file mode 100644 index 0000000..d22f717 --- /dev/null +++ b/docs/content/docs/user-guide/parallel_training.md @@ -0,0 +1,72 @@ +--- +title: Parallel Training +weight: 6 +--- + +CoMLRL supports two mutually-exclusive execution schedulers for training/inference: + +- `ddp`: Distributed Data Parallel via `torchrun` (`TorchrunScheduler`) +- `scheduler`: Single-process device placement for agents/critics (`DeviceScheduler`) + +## Config Fields + +Use these fields in `iac` / `maac` / `magrpo` sections: + +- `parallel_training`: `auto | ddp | scheduler` +- `agent_devices`: optional device spec (e.g., `"cuda:0"` or `["cuda:0", "cuda:1"]`) +- `critic_devices`: optional device spec for critic(s) (IAC/MAAC) + +{{% hint info %}} +`parallel_training=ddp` and explicit `agent_devices` / `critic_devices` are mutually exclusive. +{{% /hint %}} + +## How `auto` Is Resolved + +`auto` is resolved by `WORLD_SIZE`: + +- `WORLD_SIZE > 1` -> `ddp` +- `WORLD_SIZE = 1` -> `scheduler` + +This decision does **not** depend on how many GPUs are visible. + +## `CUDA_VISIBLE_DEVICES` vs `WORLD_SIZE` + +- `CUDA_VISIBLE_DEVICES`: which GPUs are visible to each process +- `WORLD_SIZE`: how many processes participate in distributed training + +Examples: + +1. `CUDA_VISIBLE_DEVICES=0,1 python train.py ...` +- one process (`WORLD_SIZE=1`) +- `auto` -> `scheduler` + +2. 
`CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py ...` +- two processes (`WORLD_SIZE=2`) +- `auto` -> `ddp` + +## Usage Examples + +### Device Scheduler (single process, model-level placement) + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \ + --config configs/iac_xxx.yaml \ + --override \ + iac.parallel_training=scheduler \ + iac.agent_devices='["cuda:0","cuda:1"]' \ + iac.critic_devices='["cuda:2","cuda:3"]' +``` + +### DDP (multi-process data parallel) + +```bash +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \ + --config configs/iac_xxx.yaml \ + --override iac.parallel_training=ddp +``` + +For DDP, do not set `agent_devices`/`critic_devices`. + +## Logging Note + +In DDP mode, trainer metrics and model save are rank-0 only. W&B `system/*` metrics reflect the rank-0 process view by default. From 44f533128c8374aa543e621b75f9e3e30733be9e Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 14:52:45 -0500 Subject: [PATCH 05/21] rename guide to multi-gpu training and clarify two modes --- .../docs/user-guide/multi-gpu-training.md | 73 +++++++++++++++++++ .../docs/user-guide/parallel_training.md | 72 ------------------ 2 files changed, 73 insertions(+), 72 deletions(-) create mode 100644 docs/content/docs/user-guide/multi-gpu-training.md delete mode 100644 docs/content/docs/user-guide/parallel_training.md diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md new file mode 100644 index 0000000..2b82ced --- /dev/null +++ b/docs/content/docs/user-guide/multi-gpu-training.md @@ -0,0 +1,73 @@ +--- +title: Multi-GPU Training +weight: 6 +--- + +When multiple GPUs are available, CoMLRL can train more efficiently (higher throughput and better wall-clock speed) with two mutually-exclusive modes. + +## Mode 1: DDP (`ddp`) + +Distributed Data Parallel uses multiple processes (typically one process per GPU) and synchronizes gradients. + +- Best for scaling training throughput with larger effective batch processing +- Requires `torchrun` (multi-process launch) +- Does not reduce per-GPU model memory by itself + +Example: + +```bash +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \ + --config configs/iac_xxx.yaml \ + --override iac.parallel_training=ddp +``` + +## Mode 2: Device Scheduler (`scheduler`) + +Single-process model placement that assigns different agents/critics to different GPUs. + +- Best for multi-agent layouts or heterogeneous model placement +- No gradient all-reduce across processes +- Configure with `agent_devices` / `critic_devices` + +Example: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \ + --config configs/iac_xxx.yaml \ + --override \ + iac.parallel_training=scheduler \ + iac.agent_devices='["cuda:0","cuda:1"]' \ + iac.critic_devices='["cuda:2","cuda:3"]' +``` + +## `auto` Selection Rule + +`parallel_training=auto` is resolved by `WORLD_SIZE`: + +- `WORLD_SIZE > 1` -> `ddp` +- `WORLD_SIZE = 1` -> `scheduler` + +So if you run plain `python` (single process), `auto` selects `scheduler` even when multiple GPUs are visible. + +## `CUDA_VISIBLE_DEVICES` vs `WORLD_SIZE` + +- `CUDA_VISIBLE_DEVICES`: which GPUs each process can see +- `WORLD_SIZE`: how many processes participate in distributed training + +These are related but not the same variable. 
+ +## Config Fields + +Use these fields in `iac` / `maac` / `magrpo`: + +- `parallel_training`: `auto | ddp | scheduler` +- `agent_devices`: optional device spec (string or list) +- `critic_devices`: optional device spec for IAC/MAAC + +{{% hint info %}} +`parallel_training=ddp` and explicit `agent_devices` / `critic_devices` are mutually exclusive. +{{% /hint %}} + +## Logging Note + +In DDP mode, trainer logging/checkpointing is rank-0 only. W&B `system/*` metrics reflect the rank-0 process by default. diff --git a/docs/content/docs/user-guide/parallel_training.md b/docs/content/docs/user-guide/parallel_training.md deleted file mode 100644 index d22f717..0000000 --- a/docs/content/docs/user-guide/parallel_training.md +++ /dev/null @@ -1,72 +0,0 @@ ---- -title: Parallel Training -weight: 6 ---- - -CoMLRL supports two mutually-exclusive execution schedulers for training/inference: - -- `ddp`: Distributed Data Parallel via `torchrun` (`TorchrunScheduler`) -- `scheduler`: Single-process device placement for agents/critics (`DeviceScheduler`) - -## Config Fields - -Use these fields in `iac` / `maac` / `magrpo` sections: - -- `parallel_training`: `auto | ddp | scheduler` -- `agent_devices`: optional device spec (e.g., `"cuda:0"` or `["cuda:0", "cuda:1"]`) -- `critic_devices`: optional device spec for critic(s) (IAC/MAAC) - -{{% hint info %}} -`parallel_training=ddp` and explicit `agent_devices` / `critic_devices` are mutually exclusive. -{{% /hint %}} - -## How `auto` Is Resolved - -`auto` is resolved by `WORLD_SIZE`: - -- `WORLD_SIZE > 1` -> `ddp` -- `WORLD_SIZE = 1` -> `scheduler` - -This decision does **not** depend on how many GPUs are visible. - -## `CUDA_VISIBLE_DEVICES` vs `WORLD_SIZE` - -- `CUDA_VISIBLE_DEVICES`: which GPUs are visible to each process -- `WORLD_SIZE`: how many processes participate in distributed training - -Examples: - -1. `CUDA_VISIBLE_DEVICES=0,1 python train.py ...` -- one process (`WORLD_SIZE=1`) -- `auto` -> `scheduler` - -2. `CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py ...` -- two processes (`WORLD_SIZE=2`) -- `auto` -> `ddp` - -## Usage Examples - -### Device Scheduler (single process, model-level placement) - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \ - --config configs/iac_xxx.yaml \ - --override \ - iac.parallel_training=scheduler \ - iac.agent_devices='["cuda:0","cuda:1"]' \ - iac.critic_devices='["cuda:2","cuda:3"]' -``` - -### DDP (multi-process data parallel) - -```bash -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \ - --config configs/iac_xxx.yaml \ - --override iac.parallel_training=ddp -``` - -For DDP, do not set `agent_devices`/`critic_devices`. - -## Logging Note - -In DDP mode, trainer metrics and model save are rank-0 only. W&B `system/*` metrics reflect the rank-0 process view by default. 
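The `auto` rule documented in PATCH 05 is small enough to sketch outside the trainer. The snippet below is a minimal standalone mirror of the documented behavior at this point in the series (modes `ddp` / `scheduler`), assuming only the `WORLD_SIZE` convention described above; the function name is hypothetical, and the shipped logic lives in `TorchrunScheduler.resolve_mode`:

```python
import os


def resolve_parallel_mode(requested: str = "auto") -> str:
    """Sketch of the documented rule: WORLD_SIZE decides what `auto` becomes."""
    mode = (requested or "auto").strip().lower()
    if mode != "auto":
        return mode
    # torchrun exports WORLD_SIZE to every process it launches; a plain
    # `python` run leaves it unset, so the default of "1" selects the
    # single-process device scheduler.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return "ddp" if world_size > 1 else "scheduler"


# Plain `python train.py ...`            -> "scheduler"
# `torchrun --nproc_per_node=2 train.py` -> "ddp"
print(resolve_parallel_mode())
```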
From 3e4816ec6871d5bd498874c2236d738ec45fa5e7 Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 15:16:25 -0500 Subject: [PATCH 06/21] add strict local rank to visible GPU validation for ddp --- comlrl/utils/distributed.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py index 0e06717..7ed6106 100644 --- a/comlrl/utils/distributed.py +++ b/comlrl/utils/distributed.py @@ -43,8 +43,17 @@ def init_distributed(backend: Optional[str] = None) -> DistributedContext: if torch.cuda.is_available(): if enabled: - device_count = max(1, torch.cuda.device_count()) - local_rank = local_rank % device_count + device_count = torch.cuda.device_count() + if device_count < 1: + raise RuntimeError( + "DDP requested but no CUDA devices are visible to this process." + ) + if local_rank < 0 or local_rank >= device_count: + raise ValueError( + "Invalid distributed GPU mapping: " + f"LOCAL_RANK={local_rank}, visible_cuda_devices={device_count}. " + "Make sure nproc_per_node does not exceed visible GPUs." + ) torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}" if enabled else "cuda") else: From e9f4b74e3c6991fd907fb650d543c4e8694f173a Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 15:28:45 -0500 Subject: [PATCH 07/21] ud --- comlrl/schedulers/torchrun_scheduler.py | 12 ++--- comlrl/trainers/actor_critic/iac.py | 8 ++-- comlrl/trainers/actor_critic/maac.py | 8 ++-- comlrl/trainers/reinforce/magrpo.py | 10 ++-- docs/content/docs/env/code-completion.md | 5 +- docs/content/docs/env/coding.md | 9 +--- docs/content/docs/env/minecraft.md | 7 +-- docs/content/docs/env/writing.md | 8 +--- .../docs/user-guide/multi-gpu-training.md | 48 +++++++++---------- 9 files changed, 46 insertions(+), 69 deletions(-) diff --git a/comlrl/schedulers/torchrun_scheduler.py b/comlrl/schedulers/torchrun_scheduler.py index 940ca5f..4eaa5f1 100644 --- a/comlrl/schedulers/torchrun_scheduler.py +++ b/comlrl/schedulers/torchrun_scheduler.py @@ -11,7 +11,7 @@ class TorchrunScheduler: """Resolve and initialize process-level parallel execution mode.""" - _VALID_MODES = {"auto", "ddp", "scheduler"} + _VALID_MODES = {"auto", "ddp", "mp"} @staticmethod def world_size_from_env() -> int: @@ -21,14 +21,14 @@ def world_size_from_env() -> int: def resolve_mode(cls, requested_mode: Optional[str]) -> str: mode = str(requested_mode or "auto").strip().lower() if mode not in cls._VALID_MODES: - raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") + raise ValueError("parallel_training must be one of: auto, ddp, mp.") world_size = cls.world_size_from_env() if mode == "auto": - return "ddp" if world_size > 1 else "scheduler" - if mode == "scheduler" and world_size > 1: + return "ddp" if world_size > 1 else "mp" + if mode == "mp" and world_size > 1: raise ValueError( - "parallel_training='scheduler' requires WORLD_SIZE=1 (single process)." + "parallel_training='mp' requires WORLD_SIZE=1 (single process)." 
) return mode @@ -37,7 +37,7 @@ def ddp_context() -> DistributedContext: return init_distributed() @staticmethod - def scheduler_context( + def mp_context( device: Optional[torch.device] = None, ) -> DistributedContext: return local_context(device) diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 1ade28f..8c0da14 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -91,8 +91,8 @@ def __post_init__(self) -> None: if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") + if mode not in {"auto", "ddp", "mp"}: + raise ValueError("parallel_training must be one of: auto, ddp, mp.") @dataclass @@ -181,7 +181,7 @@ def __init__( or getattr(self.args, "critic_devices", None) is not None ): raise ValueError( - "agent_devices/critic_devices are only valid in parallel_training='scheduler'." + "agent_devices/critic_devices are only valid in parallel_training='mp'." ) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device @@ -195,7 +195,7 @@ def __init__( use_separate_critic=self.args.use_separate_critic, ) self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.scheduler_context(self.device) + self.dist_env = TorchrunScheduler.mp_context(self.device) self.agents: List[CausalLMWithValueHead] = [] self.critics: List[CausalLMWithValueHead] = [] diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index bf0b47c..e415d01 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -85,8 +85,8 @@ def __post_init__(self) -> None: if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") + if mode not in {"auto", "ddp", "mp"}: + raise ValueError("parallel_training must be one of: auto, ddp, mp.") class MAACTrainer(ActorCriticTrainerBase): @@ -154,7 +154,7 @@ def __init__( or getattr(self.args, "critic_devices", None) is not None ): raise ValueError( - "agent_devices/critic_devices are only valid in parallel_training='scheduler'." + "agent_devices/critic_devices are only valid in parallel_training='mp'." 
) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device @@ -170,7 +170,7 @@ def __init__( self.agent_devices, getattr(self.args, "critic_devices", None) ) self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.scheduler_context(self.device) + self.dist_env = TorchrunScheduler.mp_context(self.device) tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) if isinstance(tokenizers, list): diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 01dd175..265c5d9 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -88,8 +88,8 @@ def __post_init__(self) -> None: if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "scheduler"}: - raise ValueError("parallel_training must be one of: auto, ddp, scheduler.") + if mode not in {"auto", "ddp", "mp"}: + raise ValueError("parallel_training must be one of: auto, ddp, mp.") @dataclass @@ -162,12 +162,12 @@ def __init__( if self.parallel_training == "ddp": if getattr(self.args, "agent_devices", None) is not None: raise ValueError( - "agent_devices is only valid in parallel_training='scheduler'." + "agent_devices is only valid in parallel_training='mp'." ) self.dist_env = TorchrunScheduler.ddp_context() self.device = self.dist_env.device else: - self.dist_env = TorchrunScheduler.scheduler_context() + self.dist_env = TorchrunScheduler.mp_context() self.device = self.dist_env.device if agent_model is None and agents is None: @@ -223,7 +223,7 @@ def __init__( kind="agent_devices", ) self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.scheduler_context(self.device) + self.dist_env = TorchrunScheduler.mp_context(self.device) if actor_sources and all(isinstance(src, str) for src in actor_sources): from transformers import AutoModelForCausalLM diff --git a/docs/content/docs/env/code-completion.md b/docs/content/docs/env/code-completion.md index 0f99d09..2a8a3e7 100644 --- a/docs/content/docs/env/code-completion.md +++ b/docs/content/docs/env/code-completion.md @@ -5,7 +5,4 @@ bookHref: https://github.com/OpenMLRL/LLM_Collab_Code_Completion bookHidden: true --- -Multi-agent autocompletion tasks where each model fills in part of a codebase. -The [LLM_Collab_Code_Completion](https://github.com/OpenMLRL/LLM_Collab_Code_Completion) -project currently focuses on **ClassEval**, which asks teams of LLMs to finish -class skeletons based on docstrings and partially implemented methods. + diff --git a/docs/content/docs/env/coding.md b/docs/content/docs/env/coding.md index 8a139a5..341d6bc 100644 --- a/docs/content/docs/env/coding.md +++ b/docs/content/docs/env/coding.md @@ -4,11 +4,4 @@ weight: 2 bookHref: https://github.com/OpenMLRL/LLM_Collab_Code_Generation --- -A suite of cooperative programming benchmarks where agents propose, critique, and -refine solutions. The environments shipped in -[LLM_Collab_Code_Generation](https://github.com/OpenMLRL/LLM_Collab_Code_Generation) -cover: - -- **MBPP** – mostly basic Python problems for rapid iteration. -- **HumanEval** – handwritten tasks from OpenAI for exact-match grading. -- **CoopHumanEval** – HumanEval variants that explicitly require collaboration. 
+ diff --git a/docs/content/docs/env/minecraft.md b/docs/content/docs/env/minecraft.md index b0a468b..a523014 100644 --- a/docs/content/docs/env/minecraft.md +++ b/docs/content/docs/env/minecraft.md @@ -4,9 +4,4 @@ weight: 4 bookHref: https://github.com/OpenMLRL/LLM_Collab_Minecraft --- -Multi-agent building environments in Minecraft. The -[LLM_Collab_Minecraft](https://github.com/OpenMLRL/LLM_Collab_Minecraft) -repository includes: - -- **[StrBuild](https://github.com/OpenMLRL/LLM_Collab_Minecraft/tree/main/str_build)** – agents collaboratively build structured designs from text. -- **[HouseBuild](https://github.com/OpenMLRL/LLM_Collab_Minecraft/tree/main/house_build)** – agents coordinate materials and steps to construct houses. + diff --git a/docs/content/docs/env/writing.md b/docs/content/docs/env/writing.md index 303fabd..9689122 100644 --- a/docs/content/docs/env/writing.md +++ b/docs/content/docs/env/writing.md @@ -4,10 +4,4 @@ weight: 1 bookHref: https://github.com/OpenMLRL/LLM_Collab_Writing --- -Collaborative summarization and expansion tasks for pairs (or teams) of LLMs. -The reference implementation lives in the -[LLM_Collab_Writing](https://github.com/OpenMLRL/LLM_Collab_Writing) repository -and includes: - -- **TLDR** – distills Reddit threads into concise summaries. -- **ArXiv Introductions** – grows short abstracts into multi-paragraph drafts. + diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md index 2b82ced..f048064 100644 --- a/docs/content/docs/user-guide/multi-gpu-training.md +++ b/docs/content/docs/user-guide/multi-gpu-training.md @@ -3,25 +3,11 @@ title: Multi-GPU Training weight: 6 --- -When multiple GPUs are available, CoMLRL can train more efficiently (higher throughput and better wall-clock speed) with two mutually-exclusive modes. +When multiple GPUs are available, CoMLRL can improve training throughput and reduce training time. -## Mode 1: DDP (`ddp`) +CoMLRL supports two ways to leverage multiple GPUs: Model Parallel scheduling (**MP**) for agent/critic placement in a single process and PyTorch Distributed Data Parallel (**DDP**) across multiple processes -Distributed Data Parallel uses multiple processes (typically one process per GPU) and synchronizes gradients. - -- Best for scaling training throughput with larger effective batch processing -- Requires `torchrun` (multi-process launch) -- Does not reduce per-GPU model memory by itself - -Example: - -```bash -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \ - --config configs/iac_xxx.yaml \ - --override iac.parallel_training=ddp -``` - -## Mode 2: Device Scheduler (`scheduler`) +## Model Parallel Scheduler Single-process model placement that assigns different agents/critics to different GPUs. @@ -35,19 +21,35 @@ Example: CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \ --config configs/iac_xxx.yaml \ --override \ - iac.parallel_training=scheduler \ + iac.parallel_training=mp \ iac.agent_devices='["cuda:0","cuda:1"]' \ iac.critic_devices='["cuda:2","cuda:3"]' ``` +## Distributed Data Parallel + +Distributed Data Parallel uses multiple processes (typically one process per GPU) and synchronizes gradients. 
+
+- Best for scaling data-parallel throughput
+- Requires `torchrun` (multi-process launch)
+- Does not reduce per-GPU model memory by itself
+
+Example:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \
+  --config configs/iac_xxx.yaml \
+  --override iac.parallel_training=ddp
+```
+
 ## `auto` Selection Rule
 
 `parallel_training=auto` is resolved by `WORLD_SIZE`:
 
 - `WORLD_SIZE > 1` -> `ddp`
-- `WORLD_SIZE = 1` -> `scheduler`
+- `WORLD_SIZE = 1` -> `mp`
 
-So if you run plain `python` (single process), `auto` selects `scheduler` even when multiple GPUs are visible.
+So plain `python` (single process) selects `mp`, even when multiple GPUs are visible.
 
 ## `CUDA_VISIBLE_DEVICES` vs `WORLD_SIZE`
 
 - `CUDA_VISIBLE_DEVICES`: which GPUs each process can see
 - `WORLD_SIZE`: how many processes participate in distributed training
 
 These are related but not the same variable.
 
 ## Config Fields
 
 Use these fields in `iac` / `maac` / `magrpo`:
 
-- `parallel_training`: `auto | ddp | scheduler`
+- `parallel_training`: `auto | ddp | mp`
 - `agent_devices`: optional device spec (string or list)
 - `critic_devices`: optional device spec for IAC/MAAC
 
 {{% hint info %}}
 `parallel_training=ddp` and explicit `agent_devices` / `critic_devices` are mutually exclusive.
 {{% /hint %}}
-
-## Logging Note
-
-In DDP mode, trainer logging/checkpointing is rank-0 only. W&B `system/*` metrics reflect the rank-0 process by default.
From c4b25a854aae08269bc0591218c4f3dc23b8432d Mon Sep 17 00:00:00 2001
From: N!no
Date: Sun, 15 Feb 2026 15:57:16 -0500
Subject: [PATCH 08/21] Update multi-gpu-training.md

---
 .../docs/user-guide/multi-gpu-training.md     | 64 ++++++-------------
 1 file changed, 18 insertions(+), 46 deletions(-)

diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md
index f048064..5c77957 100644
--- a/docs/content/docs/user-guide/multi-gpu-training.md
+++ b/docs/content/docs/user-guide/multi-gpu-training.md
@@ -5,36 +5,28 @@ weight: 6
 
 When multiple GPUs are available, CoMLRL can improve training throughput and reduce training time.
 
-CoMLRL supports two ways to leverage multiple GPUs: Model Parallel scheduling (**MP**) for agent/critic placement in a single process and PyTorch Distributed Data Parallel (**DDP**) across multiple processes
+CoMLRL supports two ways to leverage multiple GPUs: Model Parallel scheduling (**MP**) for agent/critic deployment and PyTorch Distributed Data Parallel (**DDP**) across multiple processes.
 
 ## Model Parallel Scheduler
 
-Single-process model placement that assigns different agents/critics to different GPUs.
-
-- Best for multi-agent layouts or heterogeneous model placement
-- No gradient all-reduce across processes
-- Configure with `agent_devices` / `critic_devices`
-
-Example:
+When `parallel_training=mp`, CoMLRL deploys the agents and critics across the specified devices via `agent_devices` / `critic_devices`.
+Training and inference for each model (agent/critic) run separately on its assigned device.
+The responses are aggregated on the CPU and passed to the reward function. The reward is then broadcast back to all devices for training.
+MP supports training more and larger models than a single GPU can hold, but training throughput is limited by the slowest model.
 ```bash
 CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
   --config configs/iac_xxx.yaml \
   --override \
   iac.parallel_training=mp \
   iac.agent_devices='["cuda:0","cuda:1"]' \
   iac.critic_devices='["cuda:2","cuda:3"]'
 ```
 
 ## Distributed Data Parallel
 
-Distributed Data Parallel uses multiple processes (typically one process per GPU) and synchronizes gradients.
-
-- Best for scaling data-parallel throughput
-- Requires `torchrun` (multi-process launch)
-- Does not reduce per-GPU model memory by itself
-
-Example:
+When `parallel_training=ddp`, CoMLRL launches multiple processes (one per GPU) and synchronizes gradients across them. Each process runs the full training loop across multiple models, but only on its assigned GPU. The model parameters are kept in sync across processes using PyTorch's DDP.
+DDP improves the training throughput, but requires more GPU memory since each process holds a full copy of the models. DDP also requires more careful setup (e.g., environment variables, process launching) and may not be compatible with all reward functions.
 
 ```bash
 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \
   --config configs/iac_xxx.yaml \
   --override iac.parallel_training=ddp
 ```
 
-## `auto` Selection Rule
-
-`parallel_training=auto` is resolved by `WORLD_SIZE`:
-
-- `WORLD_SIZE > 1` -> `ddp`
-- `WORLD_SIZE = 1` -> `mp`
-
-So plain `python` (single process) selects `mp`, even when multiple GPUs are visible.
-
-## `CUDA_VISIBLE_DEVICES` vs `WORLD_SIZE`
-
-- `CUDA_VISIBLE_DEVICES`: which GPUs each process can see
-- `WORLD_SIZE`: how many processes participate in distributed training
-
-These are related but not the same variable.
-
-## Config Fields
-
-Use these fields in `iac` / `maac` / `magrpo`:
-
-- `parallel_training`: `auto | ddp | mp`
-- `agent_devices`: optional device spec (string or list)
-- `critic_devices`: optional device spec for IAC/MAAC
+## Auto Parallelization
 
-{{% hint info %}}
-`parallel_training=ddp` and explicit `agent_devices` / `critic_devices` are mutually exclusive.
-{{% /hint %}}
+The `parallel_training` field is set to `auto` by default.
+When users have `WORLD_SIZE=1` and `CUDA_VISIBLE_DEVICES=0`, CoMLRL trainers fall back to single-GPU training on `cuda:0` without launching multiple processes.
+When users have multiple GPUs available and `WORLD_SIZE=1`, CoMLRL trainers use MP to deploy models across the visible GPUs.
+When users have multiple GPUs and `WORLD_SIZE > 1`, CoMLRL trainers use DDP to synchronize training across processes.
+These two modes are mutually exclusive.
From 6109f4b7b2a5f5fa94d325fd87f28fc4369d40a3 Mon Sep 17 00:00:00 2001
From: N!no
Date: Sun, 15 Feb 2026 16:10:28 -0500
Subject: [PATCH 09/21] ud

---
 docs/content/docs/dev/changelog.md            |  5 +++++
 docs/content/docs/user-guide/model-loading.md | 22 ++++++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md
index 91aa575..908927f 100644
--- a/docs/content/docs/dev/changelog.md
+++ b/docs/content/docs/dev/changelog.md
@@ -5,6 +5,11 @@ weight: 3
 
 
+## Version 1.3.7
+
+- Remove redundant sampling hyperparameters from the algorithms and change the sampling logic.
+- Allow multi-GPU training with MP and DDP.
+ ## Version 1.3.6 - Fixed critical bug of loading heterogeneous models and reform the model loading logics diff --git a/docs/content/docs/user-guide/model-loading.md b/docs/content/docs/user-guide/model-loading.md index a290b33..de887ba 100644 --- a/docs/content/docs/user-guide/model-loading.md +++ b/docs/content/docs/user-guide/model-loading.md @@ -18,7 +18,7 @@ For example, to load 3 _Qwen/Qwen2.5-1.5B_ agents: trainer = MAGRPOTrainer( agent_model="Qwen/Qwen2.5-1.5B", agents=None, - num_agents=3, + args=MAGRPOConfig(num_agents=3, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -34,7 +34,7 @@ For example, to load a _Qwen/Qwen2.5-Coder-3B_ and a _Qwen/Qwen2.5-Coder-7B_: trainer = MAGRPOTrainer( agent_model=None, agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], - num_agents=2, + args=MAGRPOConfig(num_agents=2, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -51,7 +51,7 @@ trainer = MAACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-1.5B"], critic_model="Qwen/Qwen2.5-Coder-7B", critics=None, - num_agents=2, + args=MAACConfig(num_agents=2, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -67,7 +67,13 @@ trainer = IACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], critic_model=None, critics=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], - num_agents=2, + args=IACConfig( + num_agents=2, + use_separate_critic=True, + temperature=0.7, + top_p=0.9, + top_k=None, + ), ) ``` @@ -79,7 +85,13 @@ trainer = IACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], critic_model=None, critics=None, - num_agents=2, + args=IACConfig( + num_agents=2, + use_separate_critic=False, + temperature=0.7, + top_p=0.9, + top_k=None, + ), ) ``` From 18bf5063068b1e64c682f628e3bac349a75e9da1 Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 16:33:34 -0500 Subject: [PATCH 10/21] ud --- comlrl/schedulers/torchrun_scheduler.py | 28 ++++- comlrl/trainers/actor_critic/ac_base.py | 29 ++++- comlrl/trainers/reinforce/magrpo.py | 54 +++++--- comlrl/utils/distributed.py | 47 ++++++- .../docs/user-guide/multi-gpu-training.md | 4 +- tests/test_config_constraints.py | 3 + tests/test_distributed_metrics.py | 116 ++++++++++++++++++ tests/test_torchrun_scheduler.py | 75 +++++++++++ 8 files changed, 330 insertions(+), 26 deletions(-) create mode 100644 tests/test_distributed_metrics.py create mode 100644 tests/test_torchrun_scheduler.py diff --git a/comlrl/schedulers/torchrun_scheduler.py b/comlrl/schedulers/torchrun_scheduler.py index 4eaa5f1..9b5e600 100644 --- a/comlrl/schedulers/torchrun_scheduler.py +++ b/comlrl/schedulers/torchrun_scheduler.py @@ -15,7 +15,15 @@ class TorchrunScheduler: @staticmethod def world_size_from_env() -> int: - return int(os.environ.get("WORLD_SIZE", "1")) + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except (TypeError, ValueError): + return 1 + + @staticmethod + def _missing_ddp_env_vars() -> list[str]: + required = ("WORLD_SIZE", "RANK", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT") + return [k for k in required if str(os.environ.get(k, "")).strip() == ""] @classmethod def resolve_mode(cls, requested_mode: Optional[str]) -> str: @@ -25,7 +33,23 @@ def resolve_mode(cls, requested_mode: Optional[str]) -> str: world_size = cls.world_size_from_env() if mode == "auto": - return "ddp" if world_size > 1 else "mp" + if world_size <= 1: + return "mp" + # In shared cluster environments WORLD_SIZE may be exported globally. + # Only switch to DDP when torchrun-style variables are complete. 
+ return "ddp" if not cls._missing_ddp_env_vars() else "mp" + if mode == "ddp": + if world_size <= 1: + raise ValueError( + "parallel_training='ddp' requires WORLD_SIZE>1. " + "Use torchrun --nproc_per_node=... to launch." + ) + missing = cls._missing_ddp_env_vars() + if missing: + raise ValueError( + "parallel_training='ddp' requires torchrun environment variables. " + f"Missing: {', '.join(missing)}." + ) if mode == "mp" and world_size > 1: raise ValueError( "parallel_training='mp' requires WORLD_SIZE=1 (single process)." diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index cf181d6..e5ca9f9 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -6,6 +6,8 @@ from tqdm import tqdm # type: ignore from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from comlrl.utils.distributed import barrier as dist_barrier +from comlrl.utils.distributed import reduce_metrics_dict try: from datasets import IterableDataset as HFIterableDataset @@ -189,14 +191,23 @@ def _tag_metrics( prefix = f"turn_{turn_idx + 1}/" return {prefix + key: value for key, value in metrics.items()} - def _log_metrics(self, metrics: Dict[str, float]) -> None: + def _log_metrics( + self, metrics: Dict[str, float], *, synchronize: bool = True + ) -> None: if not metrics: return dist_env = getattr(self, "dist_env", None) - if dist_env is not None and dist_env.enabled and not dist_env.is_main: + metrics_to_log = metrics + if dist_env is not None and dist_env.enabled and synchronize: + metrics_to_log = reduce_metrics_dict(metrics, dist_env) + if not dist_env.is_main: + return + elif dist_env is not None and dist_env.enabled and not dist_env.is_main: + return + if not metrics_to_log: return if self.wandb_initialized and wandb is not None: - wandb.log(metrics, step=self.env_step) + wandb.log(metrics_to_log, step=self.env_step) def _should_log_train(self) -> bool: interval = int(getattr(self.args, "logging_steps", 1)) @@ -333,7 +344,7 @@ def evaluate(self) -> Dict[str, float]: eval_log[f"eval/turn_{turn_idx + 1}/{key}"] = value if eval_log: - self._log_metrics(eval_log) + self._log_metrics(eval_log, synchronize=False) return eval_log def train(self) -> None: @@ -352,7 +363,15 @@ def train(self) -> None: and self.args.eval_interval > 0 and batch_idx % int(self.args.eval_interval) == 0 ): - self.evaluate() + dist_env = getattr(self, "dist_env", None) + if dist_env is not None and dist_env.enabled: + # Keep all ranks step-aligned to avoid DDP hangs during eval windows. 
+ dist_barrier(dist_env) + if dist_env.is_main: + self.evaluate() + dist_barrier(dist_env) + else: + self.evaluate() self._run_batch(batch, epoch_metrics) self._flush_buffers(epoch_metrics) diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 265c5d9..e7fe486 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -16,6 +16,8 @@ from comlrl.schedulers import DeviceScheduler, TorchrunScheduler from comlrl.utils.distributed import ( + barrier as dist_barrier, + reduce_metrics_dict, unwrap_model, wrap_ddp, ) @@ -738,8 +740,19 @@ def train(self, **kwargs): if int(self.args.eval_interval) > 0 and ( batch_idx % int(self.args.eval_interval) == 0 ): - # evaluate() already logs its metrics; avoid duplicate logging here - _ = self.evaluate(num_eval_samples=int(self.args.eval_num_samples)) + if self.dist_env.enabled: + # Keep all ranks synchronized around evaluation windows. + dist_barrier(self.dist_env) + if self.dist_env.is_main: + # evaluate() already logs its metrics. + _ = self.evaluate( + num_eval_samples=int(self.args.eval_num_samples) + ) + dist_barrier(self.dist_env) + else: + _ = self.evaluate( + num_eval_samples=int(self.args.eval_num_samples) + ) # Process single batch item (batch_size=1 enforced) batch_item = batch[0] @@ -756,20 +769,29 @@ def train(self, **kwargs): self._process_buffer(agent_idx, buffer) # Log per-turn epoch averages inline (avoid custom system/* metrics) - if self.wandb_initialized and wandb.run is not None: - epoch_log: Dict[str, Any] = {} - n_turns = max(1, int(self.args.num_turns)) - for turn_idx in range(n_turns): - if epoch_turn_rewards and epoch_turn_rewards[turn_idx]: - epoch_log[f"turn_{turn_idx + 1}/epoch_reward_mean"] = float( - np.mean(epoch_turn_rewards[turn_idx]) - ) - if epoch_turn_returns and epoch_turn_returns[turn_idx]: - epoch_log[f"turn_{turn_idx + 1}/epoch_avg_return"] = float( - np.mean(epoch_turn_returns[turn_idx]) - ) - if epoch_log: - wandb.log(epoch_log, step=self.env_step) + epoch_log: Dict[str, Any] = {} + n_turns = max(1, int(self.args.num_turns)) + for turn_idx in range(n_turns): + if epoch_turn_rewards and epoch_turn_rewards[turn_idx]: + epoch_log[f"turn_{turn_idx + 1}/epoch_reward_mean"] = float( + np.mean(epoch_turn_rewards[turn_idx]) + ) + if epoch_turn_returns and epoch_turn_returns[turn_idx]: + epoch_log[f"turn_{turn_idx + 1}/epoch_avg_return"] = float( + np.mean(epoch_turn_returns[turn_idx]) + ) + + if self.dist_env.enabled: + reduced_epoch_log = reduce_metrics_dict(epoch_log, self.dist_env) + if ( + self.dist_env.is_main + and reduced_epoch_log + and self.wandb_initialized + and wandb.run is not None + ): + wandb.log(reduced_epoch_log, step=self.env_step) + elif epoch_log and self.wandb_initialized and wandb.run is not None: + wandb.log(epoch_log, step=self.env_step) def _train_step_returns( self, diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py index 7ed6106..68767a0 100644 --- a/comlrl/utils/distributed.py +++ b/comlrl/utils/distributed.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.distributed as dist @@ -113,3 +113,48 @@ def all_gather_objects(obj: Any, ctx: Optional[DistributedContext]) -> List[Any] gathered: List[Any] = [None for _ in range(ctx.world_size)] dist.all_gather_object(gathered, obj) return gathered + + +def reduce_metrics_dict( + metrics: Dict[str, float], + ctx: 
Optional[DistributedContext],
+) -> Dict[str, float]:
+    """Average scalar metrics across distributed ranks.
+
+    This helper must be called by all ranks in the same order.
+    """
+    if ctx is None or not ctx.enabled:
+        return dict(metrics)
+    if not metrics:
+        return {}
+
+    keys = sorted(metrics.keys())
+    gathered_keys = all_gather_objects(keys, ctx)
+    same_keyset = all(k == keys for k in gathered_keys)
+
+    if same_keyset:
+        values = torch.tensor(
+            [float(metrics[k]) for k in keys], device=ctx.device, dtype=torch.float64
+        )
+        dist.all_reduce(values, op=dist.ReduceOp.SUM)
+        values /= float(ctx.world_size)
+        reduced = {k: float(values[i].item()) for i, k in enumerate(keys)}
+        return reduced if ctx.is_main else {}
+
+    union_keys = sorted({k for key_list in gathered_keys for k in key_list})
+    value_tensor = torch.tensor(
+        [float(metrics.get(k, 0.0)) for k in union_keys],
+        device=ctx.device,
+        dtype=torch.float64,
+    )
+    count_tensor = torch.tensor(
+        [1.0 if k in metrics else 0.0 for k in union_keys],
+        device=ctx.device,
+        dtype=torch.float64,
+    )
+    dist.all_reduce(value_tensor, op=dist.ReduceOp.SUM)
+    dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
+    count_tensor = torch.clamp(count_tensor, min=1.0)
+    averaged = value_tensor / count_tensor
+    reduced = {k: float(averaged[i].item()) for i, k in enumerate(union_keys)}
+    return reduced if ctx.is_main else {}
diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md
index 5c77957..1723570 100644
--- a/docs/content/docs/user-guide/multi-gpu-training.md
+++ b/docs/content/docs/user-guide/multi-gpu-training.md
@@ -23,7 +23,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
   iac.critic_devices='["cuda:2","cuda:3"]'
 ```
 
-## Distributed Data Parallel
+## Distributed Data Parallel Scheduler
 
 When `parallel_training=ddp`, CoMLRL launches multiple processes (one per GPU) and synchronizes gradients across them. Each process runs the full training loop across multiple models, but only on its assigned GPU. The model parameters are kept in sync across processes using PyTorch's DDP.
 DDP improves the training throughput, but requires more GPU memory since each process holds a full copy of the models. DDP also requires more careful setup (e.g., environment variables, process launching) and may not be compatible with all reward functions.
@@ -39,5 +39,5 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \
 The `parallel_training` field is set to `auto` by default.
 When users have `WORLD_SIZE=1` and `CUDA_VISIBLE_DEVICES=0`, CoMLRL trainers fall back to single-GPU training on `cuda:0` without launching multiple processes.
 When users have multiple GPUs available and `WORLD_SIZE=1`, CoMLRL trainers use MP to deploy models across the visible GPUs.
-When users have multiple GPUs and `WORLD_SIZE > 1`, CoMLRL trainers use DDP to synchronize training across processes.
+When users have multiple GPUs and the complete torchrun distributed environment variables (`WORLD_SIZE/RANK/LOCAL_RANK/MASTER_ADDR/MASTER_PORT`), CoMLRL trainers use DDP to synchronize training across processes.
 These two modes are mutually exclusive.
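As a reading aid for `reduce_metrics_dict` above: outside DDP (`ctx.enabled` is `False`) the helper simply copies its input, so single-process code paths are unaffected. A minimal sketch of that local path, constructing `DistributedContext` the same way the new tests do (the metric names here are arbitrary):

```python
import torch

from comlrl.utils.distributed import DistributedContext, reduce_metrics_dict

# With ctx.enabled False the helper returns a plain copy; the all-reduce
# branches only run under torchrun once torch.distributed is initialized.
ctx = DistributedContext(
    enabled=False,
    rank=0,
    world_size=1,
    local_rank=0,
    is_main=True,
    device=torch.device("cpu"),
)
print(reduce_metrics_dict({"loss": 0.5, "reward": 1.0}, ctx))
# -> {'loss': 0.5, 'reward': 1.0}
```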
diff --git a/tests/test_config_constraints.py b/tests/test_config_constraints.py index 8d0ee88..ea5822b 100644 --- a/tests/test_config_constraints.py +++ b/tests/test_config_constraints.py @@ -31,6 +31,7 @@ def test_iac_config_constraints(): _assert_invalid_fields(IACConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(IACConfig, "num_generations", 0) _assert_invalid(IACConfig, "critic_type", "x") + _assert_invalid(IACConfig, "parallel_training", "invalid") with pytest.raises(ValueError, match="num_generations"): IACConfig(num_turns=2, num_generations=2) @@ -55,6 +56,7 @@ def test_maac_config_constraints(): ) _assert_invalid_fields(MAACConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAACConfig, "critic_type", "x") + _assert_invalid(MAACConfig, "parallel_training", "invalid") with pytest.raises(ValueError, match="num_generations"): MAACConfig(num_turns=2, num_generations=2) @@ -79,6 +81,7 @@ def test_magrpo_config_constraints(): ) _assert_invalid_fields(MAGRPOConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAGRPOConfig, "num_generations", 1) + _assert_invalid(MAGRPOConfig, "parallel_training", "invalid") MAGRPOConfig() MAGRPOConfig(num_generations=2) diff --git a/tests/test_distributed_metrics.py b/tests/test_distributed_metrics.py new file mode 100644 index 0000000..c081662 --- /dev/null +++ b/tests/test_distributed_metrics.py @@ -0,0 +1,116 @@ +from types import SimpleNamespace + +import torch + +import comlrl.utils.distributed as dist_utils +from comlrl.trainers.actor_critic.ac_base import ActorCriticTrainerBase +from comlrl.utils.distributed import DistributedContext + + +def _ctx(*, enabled: bool, is_main: bool, rank: int = 0, world_size: int = 1): + return DistributedContext( + enabled=enabled, + rank=rank, + world_size=world_size, + local_rank=rank, + is_main=is_main, + device=torch.device("cpu"), + ) + + +def test_reduce_metrics_dict_local_returns_input(): + ctx = _ctx(enabled=False, is_main=True) + metrics = {"loss": 1.5, "reward": 2.5} + assert dist_utils.reduce_metrics_dict(metrics, ctx) == metrics + + +def test_reduce_metrics_dict_distributed_main_averages(monkeypatch): + ctx = _ctx(enabled=True, is_main=True, world_size=2) + metrics = {"a": 1.0, "b": 3.0} + + monkeypatch.setattr( + dist_utils, + "all_gather_objects", + lambda obj, _ctx: [obj, obj], + ) + + def _fake_all_reduce(tensor, op=None): # noqa: ARG001 + tensor += torch.tensor([3.0, 1.0], dtype=tensor.dtype, device=tensor.device) + + monkeypatch.setattr(dist_utils.dist, "all_reduce", _fake_all_reduce) + + reduced = dist_utils.reduce_metrics_dict(metrics, ctx) + assert reduced == {"a": 2.0, "b": 2.0} + + +def test_reduce_metrics_dict_distributed_non_main_returns_empty(monkeypatch): + ctx = _ctx(enabled=True, is_main=False, rank=1, world_size=2) + metrics = {"a": 1.0} + + monkeypatch.setattr( + dist_utils, + "all_gather_objects", + lambda obj, _ctx: [obj, obj], + ) + + def _fake_all_reduce(tensor, op=None): # noqa: ARG001 + tensor += torch.tensor([1.0], dtype=tensor.dtype, device=tensor.device) + + monkeypatch.setattr(dist_utils.dist, "all_reduce", _fake_all_reduce) + + assert dist_utils.reduce_metrics_dict(metrics, ctx) == {} + + +def test_ac_base_log_metrics_skips_reduction_when_unsynchronized(monkeypatch): + trainer = ActorCriticTrainerBase() + trainer.dist_env = _ctx(enabled=True, is_main=True, world_size=2) + trainer.wandb_initialized = True + trainer.env_step = 5 + trainer.args = SimpleNamespace(logging_steps=1) + trainer._last_train_log_step = -1 + + called = 
{"reduce": 0, "log": []} + + def _fake_reduce(metrics, _ctx): # noqa: ARG001 + called["reduce"] += 1 + return {"loss": 99.0} + + def _fake_log(metrics, step): # noqa: ARG001 + called["log"].append(dict(metrics)) + + monkeypatch.setattr( + "comlrl.trainers.actor_critic.ac_base.reduce_metrics_dict", _fake_reduce + ) + monkeypatch.setattr("wandb.log", _fake_log) + + trainer._log_metrics({"loss": 1.0}, synchronize=False) + assert called["reduce"] == 0 + assert called["log"] == [{"loss": 1.0}] + + +def test_ac_base_log_metrics_reduces_when_synchronized(monkeypatch): + trainer = ActorCriticTrainerBase() + trainer.dist_env = _ctx(enabled=True, is_main=True, world_size=2) + trainer.wandb_initialized = True + trainer.env_step = 6 + trainer.args = SimpleNamespace(logging_steps=1) + trainer._last_train_log_step = -1 + + called = {"reduce": 0, "log": []} + + def _fake_reduce(metrics, _ctx): # noqa: ARG001 + called["reduce"] += 1 + assert metrics == {"loss": 1.0} + return {"loss": 2.0} + + def _fake_log(metrics, step): # noqa: ARG001 + called["log"].append(dict(metrics)) + + monkeypatch.setattr( + "comlrl.trainers.actor_critic.ac_base.reduce_metrics_dict", _fake_reduce + ) + monkeypatch.setattr("wandb.log", _fake_log) + + trainer._log_metrics({"loss": 1.0}, synchronize=True) + assert called["reduce"] == 1 + assert called["log"] == [{"loss": 2.0}] diff --git a/tests/test_torchrun_scheduler.py b/tests/test_torchrun_scheduler.py new file mode 100644 index 0000000..e801f5d --- /dev/null +++ b/tests/test_torchrun_scheduler.py @@ -0,0 +1,75 @@ +import os +from contextlib import contextmanager + +import pytest + +from comlrl.schedulers.torchrun_scheduler import TorchrunScheduler + + +@contextmanager +def _set_env(**updates): + old = {k: os.environ.get(k) for k in updates} + for key, value in updates.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = str(value) + try: + yield + finally: + for key, value in old.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +def test_auto_uses_mp_when_world_size_is_one(): + with _set_env( + WORLD_SIZE="1", + RANK=None, + LOCAL_RANK=None, + MASTER_ADDR=None, + MASTER_PORT=None, + ): + assert TorchrunScheduler.resolve_mode("auto") == "mp" + + +def test_auto_uses_mp_when_torchrun_env_is_incomplete(): + with _set_env( + WORLD_SIZE="2", + RANK=None, + LOCAL_RANK=None, + MASTER_ADDR=None, + MASTER_PORT=None, + ): + assert TorchrunScheduler.resolve_mode("auto") == "mp" + + +def test_auto_uses_ddp_when_torchrun_env_is_complete(): + with _set_env( + WORLD_SIZE="2", + RANK="0", + LOCAL_RANK="0", + MASTER_ADDR="127.0.0.1", + MASTER_PORT="29500", + ): + assert TorchrunScheduler.resolve_mode("auto") == "ddp" + + +def test_ddp_requires_complete_torchrun_env(): + with _set_env( + WORLD_SIZE="2", + RANK=None, + LOCAL_RANK=None, + MASTER_ADDR=None, + MASTER_PORT=None, + ): + with pytest.raises(ValueError, match="Missing"): + TorchrunScheduler.resolve_mode("ddp") + + +def test_mp_rejects_world_size_greater_than_one(): + with _set_env(WORLD_SIZE="2"): + with pytest.raises(ValueError, match="WORLD_SIZE=1"): + TorchrunScheduler.resolve_mode("mp") From 3cdfec8f4cd6c9c6dd41a0c494ae639e769da1a2 Mon Sep 17 00:00:00 2001 From: N!no Date: Sun, 15 Feb 2026 16:51:29 -0500 Subject: [PATCH 11/21] ud --- docs/content/docs/dev/changelog.md | 12 +++++----- .../docs/user-guide/multi-gpu-training.md | 22 ++++++++++++++----- .../docs/user-guide/multi-turn-training.md | 4 ++-- 3 files changed, 24 insertions(+), 14 deletions(-) diff 
--git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md
index 908927f..12ab69b 100644
--- a/docs/content/docs/dev/changelog.md
+++ b/docs/content/docs/dev/changelog.md
@@ -12,8 +12,8 @@ weight: 3
 
 ## Version 1.3.6
 
-- Fixed critical bug of loading heterogeneous models and reform the model loading logics
-- Polish the docs
+- Fixed critical bug of loading heterogeneous models and reformed the model loading logic.
+- Polish the docs.
 
 ## Version 1.3.5
 
@@ -38,13 +38,13 @@ weight: 3
 
 ## Version 1.3.2
 
-- Fix wandb logging issue in MAGRPOTrainer
+- Fix wandb logging issue in MAGRPOTrainer.
 
 ## Version 1.3.1
 
-- Allow batch training in MAGRPOTrainer, IACTrainer and MAACTrainer
-- Allow multi-turn training in IACTrainer and MAACTrainer
-- Change the x-axis from data_step to env_step
+- Allow batch training in MAGRPOTrainer, IACTrainer and MAACTrainer.
+- Allow multi-turn training in IACTrainer and MAACTrainer.
+- Change the x-axis from data_step to env_step.
 
 ## Version 1.3.0
 
diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md
index 1723570..432ec4c 100644
--- a/docs/content/docs/user-guide/multi-gpu-training.md
+++ b/docs/content/docs/user-guide/multi-gpu-training.md
@@ -1,13 +1,23 @@
 ---
-title: Multi-GPU Training
+title: Training Parallelization
+linkTitle: Training Parallelization
 weight: 6
 ---
 
 When multiple GPUs are available, CoMLRL can improve training throughput and reduce training time.
 
-CoMLRL supports two ways to leverage multiple GPUs: Model Parallel scheduling (**MP**) for agent/critic deployment and PyTorch Distributed Data Parallel (**DDP**) across multiple processes.
+CoMLRL supports two schedulers for leveraging multiple GPUs: Model Parallelization (**MP**) for agent/critic deployment and PyTorch Distributed Data Parallelization (**DDP**) across multiple processes.
 
-## Model Parallel Scheduler
+## Concepts
+
+- `CUDA_VISIBLE_DEVICES`: The GPUs visible to the current process.
+- `WORLD_SIZE`: Total number of distributed processes participating in one training job.
+- `RANK`: Global process index in `[0, WORLD_SIZE-1]`.
+- `LOCAL_RANK`: Process index on the current node; used to select the node-local GPU.
+- `MASTER_ADDR`: Address of the process-group rendezvous host (usually rank 0 node).
+- `MASTER_PORT`: Port on `MASTER_ADDR` used to initialize distributed communication.
+
+## Model Parallelization
 
 When `parallel_training=mp`, CoMLRL deploys the agents and critics across the specified devices via `agent_devices` / `critic_devices`.
 Training and inference for each model (agent/critic) run separately on its assigned device.
 The responses are aggregated on the CPU and passed to the reward function. The reward is then broadcast back to all devices for training.
 MP supports training more and larger models than a single GPU can hold, but training throughput is limited by the slowest model.
 
-## Distributed Data Parallel Scheduler
+## Distributed Data Parallelization
 
 When `parallel_training=ddp`, CoMLRL launches multiple processes (one per GPU) and synchronizes gradients across them. Each process runs the full training loop across multiple models, but only on its assigned GPU. The model parameters are kept in sync across processes using PyTorch's DDP.
 DDP improves the training throughput, but requires more GPU memory since each process holds a full copy of the models. DDP also requires more careful setup (e.g., environment variables, process launching) and may not be compatible with all reward functions.
 ```bash
 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py \
   --config configs/iac_xxx.yaml \
   --override iac.parallel_training=ddp
 ```
 
diff --git a/docs/content/docs/user-guide/multi-turn-training.md b/docs/content/docs/user-guide/multi-turn-training.md
index b53c0fe..dbf7061 100644
--- a/docs/content/docs/user-guide/multi-turn-training.md
+++ b/docs/content/docs/user-guide/multi-turn-training.md
@@ -1,6 +1,6 @@
 ---
-title: Multi-Turn Training
-linkTitle: Multi-Turn Training
+title: Multi-Turn Environment
+linkTitle: Multi-Turn Environment
 weight: 5
 math: true
 ---
From 352076d48013436069765b970966577784d41043 Mon Sep 17 00:00:00 2001
From: N!no
Date: Sun, 15 Feb 2026 17:05:26 -0500
Subject: [PATCH 12/21] fix

---
 comlrl/trainers/actor_critic/iac.py  | 3 ++-
 comlrl/trainers/actor_critic/maac.py | 3 ++-
 comlrl/trainers/reinforce/magrpo.py  | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py
index 8c0da14..1e2e422 100644
--- a/comlrl/trainers/actor_critic/iac.py
+++ b/comlrl/trainers/actor_critic/iac.py
@@ -485,7 +485,8 @@ def _generate_rollout(
         if self.args.top_k is not None:
             generation_kwargs["top_k"] = self.args.top_k
 
-        sequences = agent_model.generate(**generation_kwargs)
+        generation_model = unwrap_model(agent_model)
+        sequences = generation_model.generate(**generation_kwargs)
 
         if sequences.size(1) <= prompt_len:
             raise RuntimeError("Model produced an empty completion during rollout.")
diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py
index e415d01..5d865cc 100644
--- a/comlrl/trainers/actor_critic/maac.py
+++ b/comlrl/trainers/actor_critic/maac.py
@@ -462,7 +462,8 @@ def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]:
         if self.args.top_k is not None:
             generation_kwargs["top_k"] = self.args.top_k
 
-        sequences = agent_model.generate(**generation_kwargs)
+        generation_model = unwrap_model(agent_model)
+        sequences = generation_model.generate(**generation_kwargs)
 
         if sequences.size(1) <= prompt_len:
             raise RuntimeError("Model produced an empty completion during rollout.")
diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py
index e7fe486..923c6ef 100644
--- a/comlrl/trainers/reinforce/magrpo.py
+++ b/comlrl/trainers/reinforce/magrpo.py
@@ -1193,7 +1193,7 @@ def _generate_completions(
                 kwargs.pop("do_sample", None)
                 generation_kwargs.update(kwargs)
 
-                generation_output = agent.generate(**generation_kwargs)
+                generation_output = agent_module.generate(**generation_kwargs)
 
             except Exception as e:
                 raise ValueError(f"Generation failed: {str(e)}")
From 2b2564f9081ddd6c4bf7ab175b381b4328d48844 Mon Sep 17 00:00:00 2001
From: N!no
Date: Sun, 15 Feb 2026 23:29:22 -0500
Subject: [PATCH 13/21] ud

---
 comlrl/trainers/actor_critic/ac_base.py            | 102 +++++++++++-
 comlrl/trainers/actor_critic/iac.py                |  36 ++--
 comlrl/trainers/actor_critic/maac.py               |  74 +++++----
 comlrl/trainers/reinforce/magrpo.py                | 157 +++++++++++++++---
 docs/content/docs/dev/changelog.md                 |   2 +-
 .../docs/user-guide/multi-turn-training.md         |   4 ++--
 tests/test_parallel_agent_tasks.py                 |  86 ++++++++++
 7 files changed, 377 insertions(+), 84 deletions(-)
 create mode 100644 tests/test_parallel_agent_tasks.py

diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py
index e5ca9f9..c175b31 100644
--- 
a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -1,5 +1,6 @@ from collections import defaultdict from typing import Any, Dict, List, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed import torch import wandb @@ -18,6 +19,51 @@ class ActorCriticTrainerBase: """Shared training utilities for actor-critic style trainers.""" + def _parallel_agent_mode_enabled(self) -> bool: + if str(getattr(self, "parallel_training", "")).lower() != "mp": + return False + num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0) + if num_agents <= 1: + return False + devices = getattr(self, "agent_devices", None) + if not devices: + return True + unique = {str(device) for device in devices} + return len(unique) > 1 + + def _run_agent_tasks( + self, + fn, + *, + agent_indices: Optional[List[int]] = None, + parallel: Optional[bool] = None, + ) -> List[Any]: + num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0) + indices = ( + list(agent_indices) + if agent_indices is not None + else list(range(max(num_agents, 0))) + ) + if not indices: + return [] + + use_parallel = ( + self._parallel_agent_mode_enabled() if parallel is None else bool(parallel) + ) + if not use_parallel or len(indices) <= 1: + return [fn(agent_idx) for agent_idx in indices] + + results: Dict[int, Any] = {} + max_workers = min(len(indices), max(1, len(indices))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(fn, agent_idx): agent_idx for agent_idx in indices + } + for future in as_completed(futures): + agent_idx = futures[future] + results[agent_idx] = future.result() + return [results[agent_idx] for agent_idx in indices] + def _filter_model_kwargs(self, cfg: Optional[Dict[str, Any]]) -> Dict[str, Any]: torch_dtype = None if isinstance(cfg, dict): @@ -226,10 +272,9 @@ def _process_buffer( self, agent_idx: int, buffer: List[Any], - epoch_metrics: Dict[str, List[float]], - ) -> None: + ) -> Dict[str, Any]: if not buffer: - return + return {"metric_values": {}, "log_metrics": {}} has_turn_idx = any( "turn_idx" in (getattr(s, "metadata", {}) or {}) for s in buffer @@ -242,13 +287,49 @@ def _process_buffer( buffer.clear() combined_log: Dict[str, float] = {} + metric_values: Dict[str, List[float]] = {} for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] metrics = self._update(agent_idx, samples) tagged = self._tag_metrics(metrics, agent_idx, turn_idx=t_idx) combined_log.update(tagged) for key, value in tagged.items(): - epoch_metrics[key].append(value) + metric_values.setdefault(key, []).append(value) + return {"metric_values": metric_values, "log_metrics": combined_log} + + def _drain_ready_agent_buffers( + self, + ready_agents: List[int], + epoch_metrics: Dict[str, List[float]], + ) -> None: + if not ready_agents: + return + + unique_ready = sorted({int(idx) for idx in ready_agents}) + run_parallel = bool( + getattr( + self, + "_parallel_update_enabled", + self._parallel_agent_mode_enabled(), + ) + ) + + def _process(agent_idx: int) -> Dict[str, Any]: + return self._process_buffer(agent_idx, self.rollout_buffers[agent_idx]) + + results = self._run_agent_tasks( + _process, + agent_indices=unique_ready, + parallel=run_parallel, + ) + + combined_log: Dict[str, float] = {} + for result in results: + metric_values = result.get("metric_values", {}) + for key, values in metric_values.items(): + for value in values: + epoch_metrics[key].append(value) + 
combined_log.update(result.get("log_metrics", {})) if combined_log and self._should_log_train(): self._log_metrics(combined_log) @@ -256,21 +337,24 @@ def _process_buffer( def _run_batch(self, batch, epoch_metrics: Dict[str, List[float]]) -> None: for item in batch: rollouts = self._collect_rollouts(item) + ready_agents: List[int] = [] for sample in rollouts: agent_idx = sample.agent_idx buffer = self.rollout_buffers[agent_idx] buffer.append(sample) if len(buffer) >= self.args.rollout_buffer_size: - self._process_buffer(agent_idx, buffer, epoch_metrics) + ready_agents.append(agent_idx) + if ready_agents: + self._drain_ready_agent_buffers(ready_agents, epoch_metrics) if self.args.num_agents > 0: # Count joint-action reward evaluations (one per agent group). self.env_step += len(rollouts) // self.args.num_agents def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None: - for agent_idx, buffer in enumerate(self.rollout_buffers): - if not buffer: - continue - self._process_buffer(agent_idx, buffer, epoch_metrics) + ready_agents = [ + agent_idx for agent_idx, buffer in enumerate(self.rollout_buffers) if buffer + ] + self._drain_ready_agent_buffers(ready_agents, epoch_metrics) def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 1e2e422..c93de4a 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -572,17 +572,23 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: if num_turns > 1: return self._collect_rollouts_multi_turn(item, num_turns) - prompts: List[str] = [] - completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] num_ret = int(getattr(self.args, "num_generations", 1)) + turn_prompts = [ + self._resolve_turn_prompt(item, agent_idx) + for agent_idx in range(self.args.num_agents) + ] - for agent_idx, agent_model in enumerate(self.agents): - prompt = self._resolve_turn_prompt(item, agent_idx) + def _generate_agent(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] + prompt = turn_prompts[agent_idx] gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret) - completions_per_agent.append(gen["completions"]) - rollout_data.append({"agent_idx": agent_idx, **gen}) - prompts.append(prompt) + return {"agent_idx": agent_idx, **gen} + + rollout_data = self._run_agent_tasks(_generate_agent) + prompts: List[str] = [entry["prompt"] for entry in rollout_data] + completions_per_agent: List[List[str]] = [ + entry["completions"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, @@ -689,13 +695,17 @@ def _collect_rollouts_multi_turn( ] completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] - for agent_idx, agent_model in enumerate(self.agents): + for agent_idx, prompt in enumerate(turn_prompts): + prompt_history[agent_idx].append(prompt) + + def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] prompt = turn_prompts[agent_idx] gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret=1) - completions_per_agent.append(gen["completions"]) - rollout_data.append({"agent_idx": agent_idx, **gen}) - prompt_history[agent_idx].append(prompt) + return {"agent_idx": agent_idx, **gen} + + rollout_data = self._run_agent_tasks(_generate_agent_turn) + completions_per_agent = [entry["completions"] for entry in rollout_data] rewards = 
call_reward_function( self.reward_func, diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 5d865cc..216b408 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -171,6 +171,7 @@ def __init__( ) self.device = self.agent_devices[0] self.dist_env = TorchrunScheduler.mp_context(self.device) + self._parallel_update_enabled = False tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) if isinstance(tokenizers, list): @@ -496,26 +497,31 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: if num_turns > 1: return self._collect_rollouts_multi_turn(item, num_turns) - prompts: List[str] = [] - completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] num_ret = int(self.args.num_generations) + turn_prompts = [ + self._resolve_turn_prompt(item, agent_idx) + for agent_idx in range(self.args.num_agents) + ] - for agent_idx, agent_model in enumerate(self.agents): - prompt = self._resolve_turn_prompt(item, agent_idx) + def _generate_agent(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] + prompt = turn_prompts[agent_idx] gen = self._generate(agent_model, prompt, agent_idx) - prompts.append(prompt) - completions_per_agent.append(gen["completions"]) - rollout_data.append( - { - "agent_idx": agent_idx, - "prompt": prompt, - "prompt_len": gen["prompt_len"], - "sequences": gen["sequences"], - "attention_mask": gen["attention_mask"], - "response_lens": gen["response_lens"], - } - ) + return { + "agent_idx": agent_idx, + "prompt": prompt, + "prompt_len": gen["prompt_len"], + "sequences": gen["sequences"], + "attention_mask": gen["attention_mask"], + "response_lens": gen["response_lens"], + "completion_texts": gen["completions"], + } + + rollout_data = self._run_agent_tasks(_generate_agent) + prompts: List[str] = [entry["prompt"] for entry in rollout_data] + completions_per_agent: List[List[str]] = [ + entry["completion_texts"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, @@ -676,23 +682,27 @@ def _collect_rollouts_multi_turn( ] completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] - for agent_idx, agent_model in enumerate(self.agents): + for agent_idx, prompt in enumerate(turn_prompts): + prompt_history[agent_idx].append(prompt) + + def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] prompt = turn_prompts[agent_idx] gen = self._generate(agent_model, prompt, agent_idx) - completions_per_agent.append(gen["completions"]) - rollout_data.append( - { - "agent_idx": agent_idx, - "prompt": prompt, - "prompt_len": gen["prompt_len"], - "sequences": gen["sequences"], - "attention_mask": gen["attention_mask"], - "response_lens": gen["response_lens"], - "completion_texts": gen["completions"], - } - ) - prompt_history[agent_idx].append(prompt) + return { + "agent_idx": agent_idx, + "prompt": prompt, + "prompt_len": gen["prompt_len"], + "sequences": gen["sequences"], + "attention_mask": gen["attention_mask"], + "response_lens": gen["response_lens"], + "completion_texts": gen["completions"], + } + + rollout_data = self._run_agent_tasks(_generate_agent_turn) + completions_per_agent = [ + entry["completion_texts"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 923c6ef..cb641f8 100644 --- 
a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -3,6 +3,7 @@ import random from dataclasses import dataclass import itertools +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Type, Sequence import numpy as np @@ -324,6 +325,48 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _parallel_agent_mode_enabled(self) -> bool: + if str(getattr(self, "parallel_training", "")).lower() != "mp": + return False + if int(getattr(self, "num_agents", 0) or 0) <= 1: + return False + devices = getattr(self, "agent_devices", None) + if not devices: + return True + unique = {str(device) for device in devices} + return len(unique) > 1 + + def _run_agent_tasks( + self, + fn, + *, + agent_indices: Optional[List[int]] = None, + parallel: Optional[bool] = None, + ) -> List[Any]: + indices = ( + list(agent_indices) + if agent_indices is not None + else list(range(self.num_agents)) + ) + if not indices: + return [] + use_parallel = ( + self._parallel_agent_mode_enabled() if parallel is None else bool(parallel) + ) + if not use_parallel or len(indices) <= 1: + return [fn(agent_idx) for agent_idx in indices] + + results: Dict[int, Any] = {} + max_workers = min(len(indices), max(1, len(indices))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(fn, agent_idx): agent_idx for agent_idx in indices + } + for future in as_completed(futures): + agent_idx = futures[future] + results[agent_idx] = future.result() + return [results[agent_idx] for agent_idx in indices] + def _init_wandb(self): """Initialize Weights & Biases for tracking with multi-turn config.""" if not self.dist_env.is_main: @@ -617,8 +660,10 @@ def _evaluate_sample( "External transition must return a list or tuple of external prompts for each agent" ) - # Generate and extract one completion from each agent for evaluation - for agent_idx in range(self.num_agents): + # Generate and extract one completion from each agent for evaluation. + # In MP mode this executes generation concurrently, while keeping + # synchronous consumption order by agent index. 
+ def _generate_eval_agent(agent_idx: int) -> Dict[str, Any]: agent_completions = self._generate_completions_with_external_prompts( self.agents[agent_idx], [batch_item], @@ -627,12 +672,17 @@ def _evaluate_sample( max_new_tokens=self.args.max_new_tokens, external_prompts=agent_external_prompts[agent_idx], ) - # Extract the completion directly - completion = agent_completions["completions"][0][0] - # Record prompt used this turn - used_prompt = agent_completions["prompts"][0] - eval_prompt_history[agent_idx].append(used_prompt) - agent_sample_completions[agent_idx].append(completion) + return { + "agent_idx": agent_idx, + "completion": agent_completions["completions"][0][0], + "used_prompt": agent_completions["prompts"][0], + } + + eval_outputs = self._run_agent_tasks(_generate_eval_agent) + for output in eval_outputs: + agent_idx = int(output["agent_idx"]) + eval_prompt_history[agent_idx].append(output["used_prompt"]) + agent_sample_completions[agent_idx].append(output["completion"]) # Compute immediate reward at this turn (single joint sample) agent_completions_for_reward = [ @@ -764,9 +814,12 @@ def train(self, **kwargs): **kwargs, ) - for agent_idx, buffer in enumerate(self.rollout_buffers): - if buffer: - self._process_buffer(agent_idx, buffer) + ready_agents = [ + agent_idx + for agent_idx, buffer in enumerate(self.rollout_buffers) + if buffer + ] + self._drain_ready_buffers(ready_agents) # Log per-turn epoch averages inline (avoid custom system/* metrics) epoch_log: Dict[str, Any] = {} @@ -824,9 +877,8 @@ def build_node( prompt_history_per_agent: Optional[List[List[str]]] = None, response_history_per_agent: Optional[List[List[str]]] = None, ): - comps_per_agent = [] - for agent_idx in range(self.num_agents): - comps = self._generate_completions_with_external_prompts( + def _generate_agent_node(agent_idx: int) -> Dict[str, Any]: + return self._generate_completions_with_external_prompts( self.agents[agent_idx], [batch_item], agent_idx=agent_idx, @@ -837,7 +889,8 @@ def build_node( ), **kwargs, ) - comps_per_agent.append(comps) + + comps_per_agent = self._run_agent_tasks(_generate_agent_node) agent_completions_list = [ comps_per_agent[i]["completions"][0] for i in range(self.num_agents) @@ -1073,10 +1126,22 @@ def post_order_update(node): post_order_update(root) + grouped_pending: Dict[int, List[Tuple[int, NodeSample]]] = {} for agent_idx, samples in enumerate(pending_samples): samples.sort(key=lambda s: s.node_env_step) for sample in samples: - self._append_to_buffer(agent_idx, sample) + step = int(sample.node_env_step) + grouped_pending.setdefault(step, []).append((agent_idx, sample)) + + for step in sorted(grouped_pending.keys()): + ready_agents: List[int] = [] + step_samples = sorted(grouped_pending[step], key=lambda x: x[0]) + for agent_idx, sample in step_samples: + buffer = self.rollout_buffers[agent_idx] + buffer.append(sample) + if len(buffer) >= int(self.args.rollout_buffer_size): + ready_agents.append(agent_idx) + self._drain_ready_buffers(ready_agents) # Build per-turn batch summary batch_loss = float(np.mean(np.abs(root.get("returns") or [0.0]))) @@ -1392,12 +1457,6 @@ def _pack_completions_for_buffer( "completion_input_ids": packed_completion_ids, } - def _append_to_buffer(self, agent_idx: int, sample: NodeSample) -> None: - buffer = self.rollout_buffers[agent_idx] - buffer.append(sample) - if len(buffer) >= int(self.args.rollout_buffer_size): - self._process_buffer(agent_idx, buffer) - def _should_log_train(self, step: int) -> bool: interval = int(getattr(self.args, 
"logging_steps", 1)) if interval <= 1: @@ -1411,18 +1470,22 @@ def _should_log_train(self, step: int) -> bool: return True return False - def _process_buffer(self, agent_idx: int, buffer: List[NodeSample]) -> None: + def _process_buffer( + self, agent_idx: int, buffer: List[NodeSample] + ) -> Dict[str, Any]: if not buffer: - return + return {"log_entries": []} turn_groups: Dict[int, List[NodeSample]] = {} for sample in buffer: t_idx = int(sample.turn_idx) turn_groups.setdefault(t_idx, []).append(sample) buffer.clear() + + log_entries: List[Dict[str, Any]] = [] for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] self._update_from_samples(agent_idx, samples) - if self.wandb_initialized and wandb.run is not None and samples: + if samples: batch_log: Dict[str, Any] = {} prefix = f"turn_{t_idx + 1}/" batch_log[prefix + "reward_mean"] = float( @@ -1432,8 +1495,48 @@ def _process_buffer(self, agent_idx: int, buffer: List[NodeSample]) -> None: np.mean([s.node_mean_return for s in samples]) ) step = max(s.node_env_step for s in samples) - if self._should_log_train(step): - wandb.log(batch_log, step=step) + log_entries.append( + { + "agent_idx": int(agent_idx), + "step": int(step), + "metrics": batch_log, + } + ) + return {"log_entries": log_entries} + + def _drain_ready_buffers(self, ready_agents: List[int]) -> None: + if not ready_agents: + return + unique_ready = sorted({int(idx) for idx in ready_agents}) + run_parallel = self._parallel_agent_mode_enabled() + + def _process(agent_idx: int) -> Dict[str, Any]: + return self._process_buffer(agent_idx, self.rollout_buffers[agent_idx]) + + results = self._run_agent_tasks( + _process, + agent_indices=unique_ready, + parallel=run_parallel, + ) + + if not (self.wandb_initialized and wandb.run is not None): + return + + all_log_entries: List[Dict[str, Any]] = [] + for result in results: + all_log_entries.extend(result.get("log_entries", [])) + + all_log_entries.sort( + key=lambda entry: ( + int(entry.get("step", 0)), + int(entry.get("agent_idx", 0)), + ) + ) + for entry in all_log_entries: + step = int(entry.get("step", self.env_step)) + metrics = entry.get("metrics") or {} + if metrics and self._should_log_train(step): + wandb.log(metrics, step=step) def _update_from_samples(self, agent_idx: int, samples: List[NodeSample]) -> None: if not samples: diff --git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md index 12ab69b..5c4d82f 100644 --- a/docs/content/docs/dev/changelog.md +++ b/docs/content/docs/dev/changelog.md @@ -7,7 +7,7 @@ weight: 3 ## Version 1.3.7 -- Remove the redundant sampling hyperparameters in algorithms, change the sampling logics. +- Remove the redundant sampling hyperparameters in algorithms. - Allow multi-gpu training with MP and DDP. 
## Version 1.3.6 diff --git a/docs/content/docs/user-guide/multi-turn-training.md b/docs/content/docs/user-guide/multi-turn-training.md index dbf7061..73520f8 100644 --- a/docs/content/docs/user-guide/multi-turn-training.md +++ b/docs/content/docs/user-guide/multi-turn-training.md @@ -1,6 +1,6 @@ --- -title: Multi-Turn Environment -linkTitle: Multi-Turn Environment +title: Multi-Turn Interaction +linkTitle: Multi-Turn Interaction weight: 5 math: true --- diff --git a/tests/test_parallel_agent_tasks.py b/tests/test_parallel_agent_tasks.py new file mode 100644 index 0000000..b20459d --- /dev/null +++ b/tests/test_parallel_agent_tasks.py @@ -0,0 +1,86 @@ +import time +from types import SimpleNamespace + +from comlrl.trainers.actor_critic.ac_base import ActorCriticTrainerBase +from comlrl.trainers.actor_critic.iac import IACTrainer +from comlrl.trainers.actor_critic.maac import MAACTrainer +from comlrl.trainers.reinforce.magrpo import MAGRPOTrainer + + +class _DummyACTrainer(ActorCriticTrainerBase): + def __init__(self, *, parallel_training: str, num_agents: int = 2): + self.parallel_training = parallel_training + self.args = SimpleNamespace(num_agents=num_agents) + self.agent_devices = [f"cuda:{idx}" for idx in range(num_agents)] + + +def test_ac_run_agent_tasks_keeps_index_order_in_mp(): + trainer = _DummyACTrainer(parallel_training="mp", num_agents=2) + completion_order = [] + + def _task(agent_idx: int) -> str: + if agent_idx == 0: + time.sleep(0.05) + completion_order.append(agent_idx) + return f"agent-{agent_idx}" + + outputs = trainer._run_agent_tasks(_task) + assert outputs == ["agent-0", "agent-1"] + assert completion_order == [1, 0] + + +def test_ac_run_agent_tasks_is_sequential_when_mp_disabled(): + trainer = _DummyACTrainer(parallel_training="ddp", num_agents=2) + completion_order = [] + + def _task(agent_idx: int) -> int: + completion_order.append(agent_idx) + return agent_idx + + outputs = trainer._run_agent_tasks(_task) + assert outputs == [0, 1] + assert completion_order == [0, 1] + + +def test_iac_parallel_updates_enabled_only_for_mp_mode(): + trainer = IACTrainer.__new__(IACTrainer) + trainer.args = SimpleNamespace(num_agents=2) + trainer.agent_devices = ["cuda:0", "cuda:1"] + trainer.parallel_training = "ddp" + assert trainer._parallel_agent_mode_enabled() is False + trainer.parallel_training = "mp" + assert trainer._parallel_agent_mode_enabled() is True + + +def test_magrpo_run_agent_tasks_keeps_index_order_in_mp(): + trainer = MAGRPOTrainer.__new__(MAGRPOTrainer) + trainer.parallel_training = "mp" + trainer.num_agents = 2 + trainer.agent_devices = ["cuda:0", "cuda:1"] + completion_order = [] + + def _task(agent_idx: int) -> str: + if agent_idx == 0: + time.sleep(0.05) + completion_order.append(agent_idx) + return f"agent-{agent_idx}" + + outputs = trainer._run_agent_tasks(_task) + assert outputs == ["agent-0", "agent-1"] + assert completion_order == [1, 0] + + +def test_maac_parallel_updates_are_always_serialized(): + trainer = MAACTrainer.__new__(MAACTrainer) + trainer._parallel_update_enabled = False + trainer.parallel_training = "mp" + trainer.args = SimpleNamespace(num_agents=2) + trainer.agent_devices = ["cuda:0", "cuda:1"] + run_parallel = bool( + getattr( + trainer, + "_parallel_update_enabled", + trainer._parallel_agent_mode_enabled(), + ) + ) + assert run_parallel is False From c68231da40c3cb28eeffa6b0b630527ae37e4b69 Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 00:00:59 -0500 Subject: [PATCH 14/21] ud --- comlrl/schedulers/__init__.py | 3 +- 
comlrl/schedulers/torchrun_scheduler.py | 67 --------- comlrl/trainers/actor_critic/ac_base.py | 61 +------- comlrl/trainers/actor_critic/iac.py | 56 +++----- comlrl/trainers/actor_critic/maac.py | 60 +++----- comlrl/trainers/reinforce/magrpo.py | 98 +++---------- comlrl/utils/__init__.py | 4 - comlrl/utils/distributed.py | 130 ++---------------- docs/content/docs/dev/changelog.md | 2 +- .../docs/user-guide/multi-gpu-training.md | 44 ++---- tests/test_config_constraints.py | 3 + tests/test_distributed_metrics.py | 115 ++++------------ tests/test_parallel_agent_tasks.py | 4 +- tests/test_torchrun_scheduler.py | 75 ---------- 14 files changed, 128 insertions(+), 594 deletions(-) delete mode 100644 comlrl/schedulers/torchrun_scheduler.py delete mode 100644 tests/test_torchrun_scheduler.py diff --git a/comlrl/schedulers/__init__.py b/comlrl/schedulers/__init__.py index 69af71e..fcd442b 100644 --- a/comlrl/schedulers/__init__.py +++ b/comlrl/schedulers/__init__.py @@ -1,4 +1,3 @@ from .device_scheduler import DeviceScheduler -from .torchrun_scheduler import TorchrunScheduler -__all__ = ["DeviceScheduler", "TorchrunScheduler"] +__all__ = ["DeviceScheduler"] diff --git a/comlrl/schedulers/torchrun_scheduler.py b/comlrl/schedulers/torchrun_scheduler.py deleted file mode 100644 index 9b5e600..0000000 --- a/comlrl/schedulers/torchrun_scheduler.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import os -from typing import Optional - -import torch - -from comlrl.utils.distributed import DistributedContext, init_distributed, local_context - - -class TorchrunScheduler: - """Resolve and initialize process-level parallel execution mode.""" - - _VALID_MODES = {"auto", "ddp", "mp"} - - @staticmethod - def world_size_from_env() -> int: - try: - return int(os.environ.get("WORLD_SIZE", "1")) - except (TypeError, ValueError): - return 1 - - @staticmethod - def _missing_ddp_env_vars() -> list[str]: - required = ("WORLD_SIZE", "RANK", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT") - return [k for k in required if str(os.environ.get(k, "")).strip() == ""] - - @classmethod - def resolve_mode(cls, requested_mode: Optional[str]) -> str: - mode = str(requested_mode or "auto").strip().lower() - if mode not in cls._VALID_MODES: - raise ValueError("parallel_training must be one of: auto, ddp, mp.") - - world_size = cls.world_size_from_env() - if mode == "auto": - if world_size <= 1: - return "mp" - # In shared cluster environments WORLD_SIZE may be exported globally. - # Only switch to DDP when torchrun-style variables are complete. - return "ddp" if not cls._missing_ddp_env_vars() else "mp" - if mode == "ddp": - if world_size <= 1: - raise ValueError( - "parallel_training='ddp' requires WORLD_SIZE>1. " - "Use torchrun --nproc_per_node=... to launch." - ) - missing = cls._missing_ddp_env_vars() - if missing: - raise ValueError( - "parallel_training='ddp' requires torchrun environment variables. " - f"Missing: {', '.join(missing)}." - ) - if mode == "mp" and world_size > 1: - raise ValueError( - "parallel_training='mp' requires WORLD_SIZE=1 (single process)." 
- ) - return mode - - @staticmethod - def ddp_context() -> DistributedContext: - return init_distributed() - - @staticmethod - def mp_context( - device: Optional[torch.device] = None, - ) -> DistributedContext: - return local_context(device) diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index c175b31..f9363ac 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -6,14 +6,6 @@ import wandb from tqdm import tqdm # type: ignore from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from comlrl.utils.distributed import barrier as dist_barrier -from comlrl.utils.distributed import reduce_metrics_dict - -try: - from datasets import IterableDataset as HFIterableDataset -except Exception: # pragma: no cover - HFIterableDataset = None class ActorCriticTrainerBase: @@ -54,7 +46,7 @@ def _run_agent_tasks( return [fn(agent_idx) for agent_idx in indices] results: Dict[int, Any] = {} - max_workers = min(len(indices), max(1, len(indices))) + max_workers = len(indices) with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(fn, agent_idx): agent_idx for agent_idx in indices @@ -237,23 +229,11 @@ def _tag_metrics( prefix = f"turn_{turn_idx + 1}/" return {prefix + key: value for key, value in metrics.items()} - def _log_metrics( - self, metrics: Dict[str, float], *, synchronize: bool = True - ) -> None: + def _log_metrics(self, metrics: Dict[str, float]) -> None: if not metrics: return - dist_env = getattr(self, "dist_env", None) - metrics_to_log = metrics - if dist_env is not None and dist_env.enabled and synchronize: - metrics_to_log = reduce_metrics_dict(metrics, dist_env) - if not dist_env.is_main: - return - elif dist_env is not None and dist_env.enabled and not dist_env.is_main: - return - if not metrics_to_log: - return if self.wandb_initialized and wandb is not None: - wandb.log(metrics_to_log, step=self.env_step) + wandb.log(metrics, step=self.env_step) def _should_log_train(self) -> bool: interval = int(getattr(self.args, "logging_steps", 1)) @@ -359,27 +339,10 @@ def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None: def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Training requires a dataset.") - dist_env = getattr(self, "dist_env", None) - sampler = None - if ( - dist_env is not None - and dist_env.enabled - and ( - HFIterableDataset is None - or not isinstance(self.train_dataset, HFIterableDataset) - ) - ): - sampler = DistributedSampler( - self.train_dataset, - num_replicas=dist_env.world_size, - rank=dist_env.rank, - shuffle=False, - ) return DataLoader( self.train_dataset, batch_size=1, shuffle=False, - sampler=sampler, collate_fn=lambda batch: batch, ) @@ -394,9 +357,6 @@ def get_eval_dataloader(self) -> Optional[DataLoader]: ) def evaluate(self) -> Dict[str, float]: - dist_env = getattr(self, "dist_env", None) - if dist_env is not None and dist_env.enabled and not dist_env.is_main: - return {} if self.eval_dataset is None: return {} @@ -428,7 +388,7 @@ def evaluate(self) -> Dict[str, float]: eval_log[f"eval/turn_{turn_idx + 1}/{key}"] = value if eval_log: - self._log_metrics(eval_log, synchronize=False) + self._log_metrics(eval_log) return eval_log def train(self) -> None: @@ -437,9 +397,6 @@ def train(self) -> None: for epoch in range(total_epochs): epoch_metrics = defaultdict(list) dataloader = self.get_train_dataloader() - sampler = 
getattr(dataloader, "sampler", None) - if isinstance(sampler, DistributedSampler): - sampler.set_epoch(epoch) it = self._iter_dataloader(dataloader, epoch, total_epochs) for batch_idx, batch in it: if ( @@ -447,15 +404,7 @@ def train(self) -> None: and self.args.eval_interval > 0 and batch_idx % int(self.args.eval_interval) == 0 ): - dist_env = getattr(self, "dist_env", None) - if dist_env is not None and dist_env.enabled: - # Keep all ranks step-aligned to avoid DDP hangs during eval windows. - dist_barrier(dist_env) - if dist_env.is_main: - self.evaluate() - dist_barrier(dist_env) - else: - self.evaluate() + self.evaluate() self._run_batch(batch, epoch_metrics) self._flush_buffers(epoch_metrics) diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index c93de4a..4b93fe1 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -10,11 +10,8 @@ from datasets import Dataset, IterableDataset from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase from comlrl.models.actor_critic import CausalLMWithValueHead -from comlrl.schedulers import DeviceScheduler, TorchrunScheduler -from comlrl.utils.distributed import ( - unwrap_model, - wrap_ddp, -) +from comlrl.schedulers import DeviceScheduler +from comlrl.utils.distributed import local_context, unwrap_model from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -50,7 +47,7 @@ class IACConfig: value_head_hidden_dim: Optional[int] = None num_agents: int = 2 num_turns: int = 2 - parallel_training: str = "auto" + parallel_training: str = "mp" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -90,9 +87,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "mp"}: - raise ValueError("parallel_training must be one of: auto, ddp, mp.") + mode = str(self.parallel_training or "mp").lower() + if mode != "mp": + raise ValueError("parallel_training only supports: mp.") @dataclass @@ -172,30 +169,19 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() - self.parallel_training = TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_training", "auto") + self.parallel_training = ( + str(getattr(self.args, "parallel_training", "mp")).strip().lower() ) - if self.parallel_training == "ddp": - if ( - getattr(self.args, "agent_devices", None) is not None - or getattr(self.args, "critic_devices", None) is not None - ): - raise ValueError( - "agent_devices/critic_devices are only valid in parallel_training='mp'." 
- ) - self.dist_env = TorchrunScheduler.ddp_context() - self.device = self.dist_env.device - self.agent_devices = [self.device] * self.args.num_agents - self.critic_devices = [self.device] * self.args.num_agents - else: - self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( - self.args.num_agents, - getattr(self.args, "agent_devices", None), - getattr(self.args, "critic_devices", None), - use_separate_critic=self.args.use_separate_critic, - ) - self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.mp_context(self.device) + if self.parallel_training != "mp": + raise ValueError("parallel_training only supports: mp.") + self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( + self.args.num_agents, + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + use_separate_critic=self.args.use_separate_critic, + ) + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) self.agents: List[CausalLMWithValueHead] = [] self.critics: List[CausalLMWithValueHead] = [] @@ -308,10 +294,6 @@ def __init__( else: apply_tokenizer_specials(self.tokenizer, [*self.agents, *self.critics]) - if self.dist_env.enabled: - self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] - self.critics = [wrap_ddp(critic, self.dist_env) for critic in self.critics] - self.agent_optimizers = [] self.critic_optimizers = [] @@ -1101,7 +1083,7 @@ def _update( return averaged def save_model(self, output_dir: str) -> None: - if self.dist_env.enabled and not self.dist_env.is_main: + if not self.dist_env.is_main: return os.makedirs(output_dir, exist_ok=True) if self.args.num_agents == 1: diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 216b408..83d6fb9 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -11,11 +11,8 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase import wandb from comlrl.models.actor_critic import CausalLMWithValueHead -from comlrl.schedulers import DeviceScheduler, TorchrunScheduler -from comlrl.utils.distributed import ( - unwrap_model, - wrap_ddp, -) +from comlrl.schedulers import DeviceScheduler +from comlrl.utils.distributed import local_context, unwrap_model from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -46,7 +43,7 @@ class MAACConfig: num_agents: int = 2 num_generations: int = 1 num_turns: int = 2 - parallel_training: str = "auto" + parallel_training: str = "mp" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -84,9 +81,9 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "mp"}: - raise ValueError("parallel_training must be one of: auto, ddp, mp.") + mode = str(self.parallel_training or "mp").lower() + if mode != "mp": + raise ValueError("parallel_training only supports: mp.") class MAACTrainer(ActorCriticTrainerBase): @@ -145,32 +142,21 @@ def __init__( self.eval_dataset = eval_dataset self.metrics_callback = metrics_callback self.model_config = model_config or {} - self.parallel_training = 
TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_training", "auto") + self.parallel_training = ( + str(getattr(self.args, "parallel_training", "mp")).strip().lower() ) - if self.parallel_training == "ddp": - if ( - getattr(self.args, "agent_devices", None) is not None - or getattr(self.args, "critic_devices", None) is not None - ): - raise ValueError( - "agent_devices/critic_devices are only valid in parallel_training='mp'." - ) - self.dist_env = TorchrunScheduler.ddp_context() - self.device = self.dist_env.device - self.agent_devices = [self.device] * self.args.num_agents - self.critic_device = self.device - else: - self.agent_devices = DeviceScheduler.resolve_devices( - getattr(self.args, "agent_devices", None), - self.args.num_agents, - kind="agent_devices", - ) - self.critic_device = DeviceScheduler.assign_shared_critic_device( - self.agent_devices, getattr(self.args, "critic_devices", None) - ) - self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.mp_context(self.device) + if self.parallel_training != "mp": + raise ValueError("parallel_training only supports: mp.") + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.args.num_agents, + kind="agent_devices", + ) + self.critic_device = DeviceScheduler.assign_shared_critic_device( + self.agent_devices, getattr(self.args, "critic_devices", None) + ) + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) self._parallel_update_enabled = False tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) @@ -271,10 +257,6 @@ def __init__( else: apply_tokenizer_specials(self.tokenizer, [*self.agents, self.critics[0]]) - if self.dist_env.enabled: - self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] - self.critics = [wrap_ddp(self.critics[0], self.dist_env)] - self.formatters = build_formatters(formatters, self.args.num_agents) try: self._reward_signature = inspect.signature(reward_func) @@ -1027,7 +1009,7 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: print(f"Epoch {epoch + 1}/{total_epochs} metrics: {to_print}") def save_model(self, output_dir: str) -> None: - if self.dist_env.enabled and not self.dist_env.is_main: + if not self.dist_env.is_main: return os.makedirs(output_dir, exist_ok=True) for agent_idx, actor in enumerate(self.agents): diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index cb641f8..6593b57 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -11,16 +11,13 @@ import wandb from datasets import Dataset, IterableDataset from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizerBase -from comlrl.schedulers import DeviceScheduler, TorchrunScheduler +from comlrl.schedulers import DeviceScheduler from comlrl.utils.distributed import ( - barrier as dist_barrier, - reduce_metrics_dict, + local_context, unwrap_model, - wrap_ddp, ) from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import infer_model_name, resolve_model_sources @@ -40,7 +37,7 @@ class MAGRPOConfig: agent_learning_rate: float = 5.0e-6 logging_steps: int = 50 num_agents: int = 2 - parallel_training: str = "auto" + parallel_training: str = "mp" agent_devices: Optional[Union[str, Sequence[str]]] = None # Sampling/generation @@ -90,9 +87,9 @@ def __post_init__(self) -> 
None: self.train_batch_size = self.rollout_buffer_size if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") - mode = str(self.parallel_training or "auto").lower() - if mode not in {"auto", "ddp", "mp"}: - raise ValueError("parallel_training must be one of: auto, ddp, mp.") + mode = str(self.parallel_training or "mp").lower() + if mode != "mp": + raise ValueError("parallel_training only supports: mp.") @dataclass @@ -159,19 +156,13 @@ def __init__( args: Optional[MAGRPOConfig] = None, ): self.args = args if args is not None else self.default_config_cls() - self.parallel_training = TorchrunScheduler.resolve_mode( - getattr(self.args, "parallel_training", "auto") + self.parallel_training = ( + str(getattr(self.args, "parallel_training", "mp")).strip().lower() ) - if self.parallel_training == "ddp": - if getattr(self.args, "agent_devices", None) is not None: - raise ValueError( - "agent_devices is only valid in parallel_training='mp'." - ) - self.dist_env = TorchrunScheduler.ddp_context() - self.device = self.dist_env.device - else: - self.dist_env = TorchrunScheduler.mp_context() - self.device = self.dist_env.device + if self.parallel_training != "mp": + raise ValueError("parallel_training only supports: mp.") + self.dist_env = local_context() + self.device = self.dist_env.device if agent_model is None and agents is None: raise ValueError("Either agent_model or agents must be provided.") @@ -217,16 +208,13 @@ def __init__( self.num_agents = expected_count self.model_name = model_name - if self.parallel_training == "ddp": - self.agent_devices = [self.device] * self.num_agents - else: - self.agent_devices = DeviceScheduler.resolve_devices( - getattr(self.args, "agent_devices", None), - self.num_agents, - kind="agent_devices", - ) - self.device = self.agent_devices[0] - self.dist_env = TorchrunScheduler.mp_context(self.device) + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.num_agents, + kind="agent_devices", + ) + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) if actor_sources and all(isinstance(src, str) for src in actor_sources): from transformers import AutoModelForCausalLM @@ -288,9 +276,6 @@ def __init__( self.eval_aggregator = eval_aggregator self.external_transition = external_transition - if self.dist_env.enabled: - self.agents = [wrap_ddp(agent, self.dist_env) for agent in self.agents] - self.optimizers = [ torch.optim.AdamW( agent.parameters(), @@ -484,23 +469,12 @@ def get_train_dataloader(self) -> DataLoader: """Returns the training DataLoader.""" if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") - sampler = None - if self.dist_env.enabled and not isinstance( - self.train_dataset, IterableDataset - ): - sampler = DistributedSampler( - self.train_dataset, - num_replicas=self.dist_env.world_size, - rank=self.dist_env.rank, - shuffle=False, - ) return DataLoader( self.train_dataset, batch_size=1, collate_fn=lambda examples: examples, shuffle=False, - sampler=sampler, drop_last=False, num_workers=0, ) @@ -529,8 +503,6 @@ def evaluate(self, num_eval_samples: int = 4) -> Dict[str, float]: Returns: Dictionary containing evaluation metrics """ - if self.dist_env.enabled and not self.dist_env.is_main: - return {} if self.eval_dataset is None: return {} @@ -772,10 +744,7 @@ def train(self, **kwargs): ] # immediate rewards epoch_turn_returns = [[] for _ in range(self.args.num_turns)] # returns dl = self.get_train_dataloader() - 
sampler = getattr(dl, "sampler", None) - if isinstance(sampler, DistributedSampler): - sampler.set_epoch(epoch) - if getattr(self, "verbose", True) and self.dist_env.is_main: + if getattr(self, "verbose", True): it = enumerate( tqdm( dl, @@ -790,19 +759,7 @@ def train(self, **kwargs): if int(self.args.eval_interval) > 0 and ( batch_idx % int(self.args.eval_interval) == 0 ): - if self.dist_env.enabled: - # Keep all ranks synchronized around evaluation windows. - dist_barrier(self.dist_env) - if self.dist_env.is_main: - # evaluate() already logs its metrics. - _ = self.evaluate( - num_eval_samples=int(self.args.eval_num_samples) - ) - dist_barrier(self.dist_env) - else: - _ = self.evaluate( - num_eval_samples=int(self.args.eval_num_samples) - ) + _ = self.evaluate(num_eval_samples=int(self.args.eval_num_samples)) # Process single batch item (batch_size=1 enforced) batch_item = batch[0] @@ -834,16 +791,7 @@ def train(self, **kwargs): np.mean(epoch_turn_returns[turn_idx]) ) - if self.dist_env.enabled: - reduced_epoch_log = reduce_metrics_dict(epoch_log, self.dist_env) - if ( - self.dist_env.is_main - and reduced_epoch_log - and self.wandb_initialized - and wandb.run is not None - ): - wandb.log(reduced_epoch_log, step=self.env_step) - elif epoch_log and self.wandb_initialized and wandb.run is not None: + if epoch_log and self.wandb_initialized and wandb.run is not None: wandb.log(epoch_log, step=self.env_step) def _train_step_returns( @@ -1682,8 +1630,6 @@ def save_model(self, output_dir): Args: output_dir: Directory to save the models to """ - if self.dist_env.enabled and not self.dist_env.is_main: - return os.makedirs(output_dir, exist_ok=True) for agent_idx, agent in enumerate(self.agents): diff --git a/comlrl/utils/__init__.py b/comlrl/utils/__init__.py index b303b2d..bca4c0d 100644 --- a/comlrl/utils/__init__.py +++ b/comlrl/utils/__init__.py @@ -6,11 +6,9 @@ DistributedContext, all_gather_objects, barrier, - init_distributed, is_main_process, local_context, unwrap_model, - wrap_ddp, ) from .tokenizer_utils import ( apply_tokenizer_specials, @@ -33,9 +31,7 @@ "call_reward_function", "normalize_reward_lengths", "DistributedContext", - "init_distributed", "local_context", - "wrap_ddp", "unwrap_model", "is_main_process", "barrier", diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py index 68767a0..b23143e 100644 --- a/comlrl/utils/distributed.py +++ b/comlrl/utils/distributed.py @@ -1,99 +1,30 @@ from __future__ import annotations -import os from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP @dataclass(frozen=True) class DistributedContext: enabled: bool - rank: int - world_size: int - local_rank: int is_main: bool device: torch.device def local_context(device: Optional[torch.device] = None) -> DistributedContext: if device is None: - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return DistributedContext( enabled=False, - rank=0, - world_size=1, - local_rank=0, is_main=True, device=device, ) -def init_distributed(backend: Optional[str] = None) -> DistributedContext: - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = int(os.environ.get("RANK", "0")) - local_rank = int(os.environ.get("LOCAL_RANK", str(rank))) - enabled = world_size > 1 - - if torch.cuda.is_available(): - if enabled: - device_count = 
torch.cuda.device_count() - if device_count < 1: - raise RuntimeError( - "DDP requested but no CUDA devices are visible to this process." - ) - if local_rank < 0 or local_rank >= device_count: - raise ValueError( - "Invalid distributed GPU mapping: " - f"LOCAL_RANK={local_rank}, visible_cuda_devices={device_count}. " - "Make sure nproc_per_node does not exceed visible GPUs." - ) - torch.cuda.set_device(local_rank) - device = torch.device(f"cuda:{local_rank}" if enabled else "cuda") - else: - device = torch.device("cpu") - - if enabled and not dist.is_initialized(): - backend_name = backend or ("nccl" if device.type == "cuda" else "gloo") - dist.init_process_group(backend=backend_name, rank=rank, world_size=world_size) - - return DistributedContext( - enabled=enabled, - rank=rank, - world_size=world_size, - local_rank=local_rank, - is_main=(rank == 0), - device=device, - ) - - -def wrap_ddp( - model: torch.nn.Module, - ctx: DistributedContext, - *, - find_unused_parameters: bool = False, -) -> torch.nn.Module: - if not ctx.enabled: - return model - if isinstance(model, DDP): - return model - kwargs = { - "find_unused_parameters": find_unused_parameters, - } - if ctx.device.type == "cuda": - kwargs["device_ids"] = [ctx.local_rank] - kwargs["output_device"] = ctx.local_rank - return DDP(model, **kwargs) - - def unwrap_model(model: Any) -> Any: - return model.module if isinstance(model, DDP) else model + return getattr(model, "module", model) def is_main_process(ctx: Optional[DistributedContext]) -> bool: @@ -102,59 +33,18 @@ def is_main_process(ctx: Optional[DistributedContext]) -> bool: return bool(ctx.is_main) -def barrier(ctx: Optional[DistributedContext]) -> None: - if ctx is not None and ctx.enabled and dist.is_initialized(): - dist.barrier() +def barrier(ctx: Optional[DistributedContext]) -> None: # noqa: ARG001 + return -def all_gather_objects(obj: Any, ctx: Optional[DistributedContext]) -> List[Any]: - if ctx is None or not ctx.enabled: - return [obj] - gathered: List[Any] = [None for _ in range(ctx.world_size)] - dist.all_gather_object(gathered, obj) - return gathered +def all_gather_objects( + obj: Any, ctx: Optional[DistributedContext] +) -> List[Any]: # noqa: ARG001 + return [obj] def reduce_metrics_dict( metrics: Dict[str, float], - ctx: Optional[DistributedContext], + ctx: Optional[DistributedContext], # noqa: ARG001 ) -> Dict[str, float]: - """Average scalar metrics across distributed ranks. - - This helper must be called by all ranks in the same order. 
- """ - if ctx is None or not ctx.enabled: - return dict(metrics) - if not metrics: - return {} - - keys = sorted(metrics.keys()) - gathered_keys = all_gather_objects(keys, ctx) - same_keyset = all(k == keys for k in gathered_keys) - - if same_keyset: - values = torch.tensor( - [float(metrics[k]) for k in keys], device=ctx.device, dtype=torch.float64 - ) - dist.all_reduce(values, op=dist.ReduceOp.SUM) - values /= float(ctx.world_size) - reduced = {k: float(values[i].item()) for i, k in enumerate(keys)} - return reduced if ctx.is_main else {} - - union_keys = sorted({k for key_list in gathered_keys for k in key_list}) - value_tensor = torch.tensor( - [float(metrics.get(k, 0.0)) for k in union_keys], - device=ctx.device, - dtype=torch.float64, - ) - count_tensor = torch.tensor( - [1.0 if k in metrics else 0.0 for k in union_keys], - device=ctx.device, - dtype=torch.float64, - ) - dist.all_reduce(value_tensor, op=dist.ReduceOp.SUM) - dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) - count_tensor = torch.clamp(count_tensor, min=1.0) - averaged = value_tensor / count_tensor - reduced = {k: float(averaged[i].item()) for i, k in enumerate(union_keys)} - return reduced if ctx.is_main else {} + return dict(metrics) diff --git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md index 5c4d82f..c908174 100644 --- a/docs/content/docs/dev/changelog.md +++ b/docs/content/docs/dev/changelog.md @@ -8,7 +8,7 @@ weight: 3 ## Version 1.3.7 - Remove the redundant sampling hyperparameters in algorithms. -- Allow multi-gpu training with MP and DDP. +- Allow multi-gpu training with MP. ## Version 1.3.6 diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/multi-gpu-training.md index 432ec4c..9ffe62b 100644 --- a/docs/content/docs/user-guide/multi-gpu-training.md +++ b/docs/content/docs/user-guide/multi-gpu-training.md @@ -4,18 +4,18 @@ linkTitle: Training Parallelization weight: 6 --- -When multiple GPUs are available, CoMLRL can improve training throughput and reduce training time. +CoMLRL supports fine-tuning multi-LLM systems with larger models and more agents when multiple GPUs are available. +Currently, CoMLRL supports two training parallelization mode: `auto` and `mp` (model parallelization). -CoMLRL supports two schedulers for leveraging multiple GPUs: Model Parallelization (**MP**) for agent/critic deployment and PyTorch Distributed Data Parallelization (**DDP**) across multiple processes. +## Auto Parallelization + +The default parallelization mode for training is `parallel_training=auto`. If only one GPU is visible, CoMLRL fallbacks to single-GPU training. +If multiple GPUs are visible, CoMLRL uses `parallel_training=mp` to deploy the agents and critics across the specified devices via `agent_devices` / `critic_devices`. -## Concepts +{{% hint success %}} +We will support more parallelization modes (e.g., [data parallelization](https://docs.pytorch.org/docs/stable/elastic/run.html), [multi-node training](ray.io)) in the future. +{{% /hint %}} -- `CUDA_VISIBLE_DEVICES`: The GPUs visible to the current process. -- `WORLD_SIZE`: Total number of distributed processes participating in one training job. -- `RANK`: Global process index in `[0, WORLD_SIZE-1]`. -- `LOCAL_RANK`: Process index on the current node; used to select the node-local GPU. -- `MASTER_ADDR`: Address of the process-group rendezvous host (usually rank 0 node). -- `MASTER_PORT`: Port on `MASTER_ADDR` used to initialize distributed communication. 
## Model Parallelization @@ -25,29 +25,15 @@ The responses are aggregated on the CPU and passed to the reward function. The reward is then broadcast back to all devices for training. MP supports training larger and more models than a single GPU can hold, but the training throughput is limited by the slowest model. ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py +CUDA_VISIBLE_DEVICES=0,1,2,3 python train_iac.py --config configs/iac_xxx.yaml --override + agent_model="model_a" + agents=None + critic_model="model_b" + critics=None + iac.use_separate_critic=true iac.parallel_training=mp iac.agent_devices='["cuda:0","cuda:1"]' iac.critic_devices='["cuda:2","cuda:3"]' ``` - -## Distributed Data Parallelization - -When `parallel_training=ddp`, CoMLRL launches multiple processes (one per GPU) and synchronizes gradients across them. Each process runs the full training loop across multiple models, but only on its assigned GPU. The model parameters are kept in sync across processes using PyTorch's DDP. -DDP improves the training throughput, but requires more GPU memory since each process holds a full copy of the models. DDP also requires more careful setup (e.g., environment variables, process launching) and may not be compatible with all reward functions. - ```bash -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py - --config configs/iac_xxx.yaml - --override iac.parallel_training=ddp -``` - -## Auto Parallelization - -The `parallel_training` field is set to `auto` by default. -When users have `WORLD_SIZE=1` and `CUDA_VISIBLE_DEVICES=0`, CoMLRL trainers fall back to single-gpu training on `cuda:0` without launching multiple processes. -When users have multiple GPUs available, and `WORLD_SIZE=1`, CoMLRL trainers use MP to deploy models across the visible GPUs. -When users have multiple GPUs and complete torchrun distributed env vars (`WORLD_SIZE/RANK/LOCAL_RANK/MASTER_ADDR/MASTER_PORT`), CoMLRL trainers use DDP to synchronize training across processes. -These two modes are mutually exclusive. 
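As a reading aid for the MP setup these docs describe, here is a minimal Python sketch of the equivalent programmatic configuration. It is not part of the patch: the device strings and agent count are placeholder assumptions, while the field names (`parallel_training`, `agent_devices`, `critic_devices`, `use_separate_critic`) follow the `IACConfig` changes in this series.

```python
# Minimal sketch (assumed usage, not part of the patch): MP-mode training
# configuration with explicit per-model device placement, mirroring the
# CLI overrides in the docs above.
from comlrl.trainers.actor_critic.iac import IACConfig

config = IACConfig(
    num_agents=2,
    use_separate_critic=True,
    parallel_training="mp",               # the only mode this series keeps
    agent_devices=["cuda:0", "cuda:1"],   # one device per agent
    critic_devices=["cuda:2", "cuda:3"],  # one device per separate critic
)
```

The same `agent_devices` shape applies to `MAACConfig` (which assigns a shared critic device) and to `MAGRPOConfig` (agents only, no critics).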
diff --git a/tests/test_config_constraints.py b/tests/test_config_constraints.py index ea5822b..6f736cf 100644 --- a/tests/test_config_constraints.py +++ b/tests/test_config_constraints.py @@ -32,6 +32,7 @@ def test_iac_config_constraints(): _assert_invalid(IACConfig, "num_generations", 0) _assert_invalid(IACConfig, "critic_type", "x") _assert_invalid(IACConfig, "parallel_training", "invalid") + _assert_invalid(IACConfig, "parallel_training", "auto") with pytest.raises(ValueError, match="num_generations"): IACConfig(num_turns=2, num_generations=2) @@ -57,6 +58,7 @@ def test_maac_config_constraints(): _assert_invalid_fields(MAACConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAACConfig, "critic_type", "x") _assert_invalid(MAACConfig, "parallel_training", "invalid") + _assert_invalid(MAACConfig, "parallel_training", "auto") with pytest.raises(ValueError, match="num_generations"): MAACConfig(num_turns=2, num_generations=2) @@ -82,6 +84,7 @@ def test_magrpo_config_constraints(): _assert_invalid_fields(MAGRPOConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAGRPOConfig, "num_generations", 1) _assert_invalid(MAGRPOConfig, "parallel_training", "invalid") + _assert_invalid(MAGRPOConfig, "parallel_training", "auto") MAGRPOConfig() MAGRPOConfig(num_generations=2) diff --git a/tests/test_distributed_metrics.py b/tests/test_distributed_metrics.py index c081662..5bfb2ca 100644 --- a/tests/test_distributed_metrics.py +++ b/tests/test_distributed_metrics.py @@ -2,115 +2,58 @@ import torch -import comlrl.utils.distributed as dist_utils from comlrl.trainers.actor_critic.ac_base import ActorCriticTrainerBase -from comlrl.utils.distributed import DistributedContext +from comlrl.utils.distributed import ( + DistributedContext, + all_gather_objects, + barrier, + local_context, + reduce_metrics_dict, +) -def _ctx(*, enabled: bool, is_main: bool, rank: int = 0, world_size: int = 1): +def _ctx() -> DistributedContext: return DistributedContext( - enabled=enabled, - rank=rank, - world_size=world_size, - local_rank=rank, - is_main=is_main, + enabled=False, + is_main=True, device=torch.device("cpu"), ) -def test_reduce_metrics_dict_local_returns_input(): - ctx = _ctx(enabled=False, is_main=True) - metrics = {"loss": 1.5, "reward": 2.5} - assert dist_utils.reduce_metrics_dict(metrics, ctx) == metrics - - -def test_reduce_metrics_dict_distributed_main_averages(monkeypatch): - ctx = _ctx(enabled=True, is_main=True, world_size=2) - metrics = {"a": 1.0, "b": 3.0} - - monkeypatch.setattr( - dist_utils, - "all_gather_objects", - lambda obj, _ctx: [obj, obj], - ) - - def _fake_all_reduce(tensor, op=None): # noqa: ARG001 - tensor += torch.tensor([3.0, 1.0], dtype=tensor.dtype, device=tensor.device) - - monkeypatch.setattr(dist_utils.dist, "all_reduce", _fake_all_reduce) - - reduced = dist_utils.reduce_metrics_dict(metrics, ctx) - assert reduced == {"a": 2.0, "b": 2.0} +def test_local_context_defaults_to_single_process(): + ctx = local_context(torch.device("cpu")) + assert ctx.enabled is False + assert ctx.is_main is True -def test_reduce_metrics_dict_distributed_non_main_returns_empty(monkeypatch): - ctx = _ctx(enabled=True, is_main=False, rank=1, world_size=2) - metrics = {"a": 1.0} - - monkeypatch.setattr( - dist_utils, - "all_gather_objects", - lambda obj, _ctx: [obj, obj], - ) - - def _fake_all_reduce(tensor, op=None): # noqa: ARG001 - tensor += torch.tensor([1.0], dtype=tensor.dtype, device=tensor.device) +def test_reduce_metrics_dict_passthrough(): + metrics = {"loss": 1.5, 
"reward": 2.5} + reduced = reduce_metrics_dict(metrics, _ctx()) + assert reduced == metrics + assert reduced is not metrics - monkeypatch.setattr(dist_utils.dist, "all_reduce", _fake_all_reduce) - assert dist_utils.reduce_metrics_dict(metrics, ctx) == {} +def test_barrier_and_all_gather_are_noop_in_single_process(): + ctx = _ctx() + barrier(ctx) + assert all_gather_objects({"a": 1}, ctx) == [{"a": 1}] -def test_ac_base_log_metrics_skips_reduction_when_unsynchronized(monkeypatch): +def test_ac_base_log_metrics_logs_directly(monkeypatch): trainer = ActorCriticTrainerBase() - trainer.dist_env = _ctx(enabled=True, is_main=True, world_size=2) + trainer.dist_env = _ctx() trainer.wandb_initialized = True trainer.env_step = 5 trainer.args = SimpleNamespace(logging_steps=1) trainer._last_train_log_step = -1 - called = {"reduce": 0, "log": []} - - def _fake_reduce(metrics, _ctx): # noqa: ARG001 - called["reduce"] += 1 - return {"loss": 99.0} + called = {"log": []} def _fake_log(metrics, step): # noqa: ARG001 called["log"].append(dict(metrics)) - monkeypatch.setattr( - "comlrl.trainers.actor_critic.ac_base.reduce_metrics_dict", _fake_reduce - ) - monkeypatch.setattr("wandb.log", _fake_log) - - trainer._log_metrics({"loss": 1.0}, synchronize=False) - assert called["reduce"] == 0 - assert called["log"] == [{"loss": 1.0}] - - -def test_ac_base_log_metrics_reduces_when_synchronized(monkeypatch): - trainer = ActorCriticTrainerBase() - trainer.dist_env = _ctx(enabled=True, is_main=True, world_size=2) - trainer.wandb_initialized = True - trainer.env_step = 6 - trainer.args = SimpleNamespace(logging_steps=1) - trainer._last_train_log_step = -1 - - called = {"reduce": 0, "log": []} - - def _fake_reduce(metrics, _ctx): # noqa: ARG001 - called["reduce"] += 1 - assert metrics == {"loss": 1.0} - return {"loss": 2.0} - - def _fake_log(metrics, step): # noqa: ARG001 - called["log"].append(dict(metrics)) - - monkeypatch.setattr( - "comlrl.trainers.actor_critic.ac_base.reduce_metrics_dict", _fake_reduce - ) monkeypatch.setattr("wandb.log", _fake_log) - trainer._log_metrics({"loss": 1.0}, synchronize=True) - assert called["reduce"] == 1 - assert called["log"] == [{"loss": 2.0}] + trainer._log_metrics({"loss": 1.0}) + trainer._log_metrics({"reward": 2.0}) + assert called["log"] == [{"loss": 1.0}, {"reward": 2.0}] diff --git a/tests/test_parallel_agent_tasks.py b/tests/test_parallel_agent_tasks.py index b20459d..2562758 100644 --- a/tests/test_parallel_agent_tasks.py +++ b/tests/test_parallel_agent_tasks.py @@ -30,7 +30,7 @@ def _task(agent_idx: int) -> str: def test_ac_run_agent_tasks_is_sequential_when_mp_disabled(): - trainer = _DummyACTrainer(parallel_training="ddp", num_agents=2) + trainer = _DummyACTrainer(parallel_training="none", num_agents=2) completion_order = [] def _task(agent_idx: int) -> int: @@ -46,7 +46,7 @@ def test_iac_parallel_updates_enabled_only_for_mp_mode(): trainer = IACTrainer.__new__(IACTrainer) trainer.args = SimpleNamespace(num_agents=2) trainer.agent_devices = ["cuda:0", "cuda:1"] - trainer.parallel_training = "ddp" + trainer.parallel_training = "none" assert trainer._parallel_agent_mode_enabled() is False trainer.parallel_training = "mp" assert trainer._parallel_agent_mode_enabled() is True diff --git a/tests/test_torchrun_scheduler.py b/tests/test_torchrun_scheduler.py deleted file mode 100644 index e801f5d..0000000 --- a/tests/test_torchrun_scheduler.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -from contextlib import contextmanager - -import pytest - -from 
comlrl.schedulers.torchrun_scheduler import TorchrunScheduler - - -@contextmanager -def _set_env(**updates): - old = {k: os.environ.get(k) for k in updates} - for key, value in updates.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = str(value) - try: - yield - finally: - for key, value in old.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value - - -def test_auto_uses_mp_when_world_size_is_one(): - with _set_env( - WORLD_SIZE="1", - RANK=None, - LOCAL_RANK=None, - MASTER_ADDR=None, - MASTER_PORT=None, - ): - assert TorchrunScheduler.resolve_mode("auto") == "mp" - - -def test_auto_uses_mp_when_torchrun_env_is_incomplete(): - with _set_env( - WORLD_SIZE="2", - RANK=None, - LOCAL_RANK=None, - MASTER_ADDR=None, - MASTER_PORT=None, - ): - assert TorchrunScheduler.resolve_mode("auto") == "mp" - - -def test_auto_uses_ddp_when_torchrun_env_is_complete(): - with _set_env( - WORLD_SIZE="2", - RANK="0", - LOCAL_RANK="0", - MASTER_ADDR="127.0.0.1", - MASTER_PORT="29500", - ): - assert TorchrunScheduler.resolve_mode("auto") == "ddp" - - -def test_ddp_requires_complete_torchrun_env(): - with _set_env( - WORLD_SIZE="2", - RANK=None, - LOCAL_RANK=None, - MASTER_ADDR=None, - MASTER_PORT=None, - ): - with pytest.raises(ValueError, match="Missing"): - TorchrunScheduler.resolve_mode("ddp") - - -def test_mp_rejects_world_size_greater_than_one(): - with _set_env(WORLD_SIZE="2"): - with pytest.raises(ValueError, match="WORLD_SIZE=1"): - TorchrunScheduler.resolve_mode("mp") From 7e9966fd16ebdfd6c994a1de62582fda879743f5 Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 00:02:24 -0500 Subject: [PATCH 15/21] update docs --- .../{multi-gpu-training.md => training-parallelization.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/content/docs/user-guide/{multi-gpu-training.md => training-parallelization.md} (100%) diff --git a/docs/content/docs/user-guide/multi-gpu-training.md b/docs/content/docs/user-guide/training-parallelization.md similarity index 100% rename from docs/content/docs/user-guide/multi-gpu-training.md rename to docs/content/docs/user-guide/training-parallelization.md From 21c1adbb45015826e75b531d9dc8fe7fde4d9378 Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 00:08:34 -0500 Subject: [PATCH 16/21] Update training-parallelization.md --- docs/content/docs/user-guide/training-parallelization.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/docs/content/docs/user-guide/training-parallelization.md b/docs/content/docs/user-guide/training-parallelization.md index 9ffe62b..b721828 100644 --- a/docs/content/docs/user-guide/training-parallelization.md +++ b/docs/content/docs/user-guide/training-parallelization.md @@ -5,12 +5,7 @@ weight: 6 --- CoMLRL supports fine-tuning multi-LLM systems with larger models and more agents when multiple GPUs are available. -Currently, CoMLRL supports two training parallelization modes: `auto` and `mp` (model parallelization). - -## Auto Parallelization - -The default parallelization mode for training is `parallel_training=auto`. If only one GPU is visible, CoMLRL falls back to single-GPU training. -If multiple GPUs are visible, CoMLRL uses `parallel_training=mp` to deploy the agents and critics across the specified devices via `agent_devices` / `critic_devices`. +Currently, CoMLRL supports one training parallelization mode: `mp` (model parallelization). 
{{% hint success %}} We will support more parallelization modes (e.g., [data parallelization](https://docs.pytorch.org/docs/stable/elastic/run.html), [multi-node training](ray.io)) in the future. From 77c2f1ccc5d7fe8782dbebab78a64fe515787fef Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 00:13:06 -0500 Subject: [PATCH 17/21] Require explicit devices for mp training --- comlrl/trainers/actor_critic/iac.py | 4 ++ comlrl/trainers/actor_critic/maac.py | 4 ++ comlrl/trainers/reinforce/magrpo.py | 2 + .../user-guide/training-parallelization.md | 2 +- tests/test_config_constraints.py | 50 +++++++++++--- tests/test_model_loading.py | 38 +++++++---- tests/test_trainer_constraints.py | 66 +++++++++++-------- 7 files changed, 115 insertions(+), 51 deletions(-) diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 4b93fe1..4693c8a 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -90,6 +90,10 @@ def __post_init__(self) -> None: mode = str(self.parallel_training or "mp").lower() if mode != "mp": raise ValueError("parallel_training only supports: mp.") + if self.agent_devices is None: + raise ValueError("parallel_training='mp' requires explicit agent_devices.") + if self.critic_devices is None: + raise ValueError("parallel_training='mp' requires explicit critic_devices.") @dataclass diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 83d6fb9..5696a2b 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -84,6 +84,10 @@ def __post_init__(self) -> None: mode = str(self.parallel_training or "mp").lower() if mode != "mp": raise ValueError("parallel_training only supports: mp.") + if self.agent_devices is None: + raise ValueError("parallel_training='mp' requires explicit agent_devices.") + if self.critic_devices is None: + raise ValueError("parallel_training='mp' requires explicit critic_devices.") class MAACTrainer(ActorCriticTrainerBase): diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 6593b57..33e3676 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -90,6 +90,8 @@ def __post_init__(self) -> None: mode = str(self.parallel_training or "mp").lower() if mode != "mp": raise ValueError("parallel_training only supports: mp.") + if self.agent_devices is None: + raise ValueError("parallel_training='mp' requires explicit agent_devices.") @dataclass diff --git a/docs/content/docs/user-guide/training-parallelization.md b/docs/content/docs/user-guide/training-parallelization.md index b721828..82a7f8b 100644 --- a/docs/content/docs/user-guide/training-parallelization.md +++ b/docs/content/docs/user-guide/training-parallelization.md @@ -14,7 +14,7 @@ We will support more parallelization modes (e.g., [data parallelization](https:/ ## Model Parallelization -When `parallel_training=mp`, CoMLRL deploys the agents and critics across the specified devices via `agent_devices` / `critic_devices`. +When `parallel_training=mp`, CoMLRL requires explicit `agent_devices` / `critic_devices` configuration and deploys the agents and critics accordingly. The training and inference for each model (agent/critic) run separately on its assigned device. The responses are aggregated on the CPU and passed to the reward function. The reward is then broadcast back to all devices for training. 
MP supports training larger and more models than a single GPU can hold, but the training throughput is limited by the slowest model. diff --git a/tests/test_config_constraints.py b/tests/test_config_constraints.py index 6f736cf..f3082ab 100644 --- a/tests/test_config_constraints.py +++ b/tests/test_config_constraints.py @@ -33,12 +33,26 @@ def test_iac_config_constraints(): _assert_invalid(IACConfig, "critic_type", "x") _assert_invalid(IACConfig, "parallel_training", "invalid") _assert_invalid(IACConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + IACConfig(critic_devices="cpu") + with pytest.raises(ValueError, match="critic_devices"): + IACConfig(agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): - IACConfig(num_turns=2, num_generations=2) + IACConfig( + num_turns=2, + num_generations=2, + agent_devices="cpu", + critic_devices="cpu", + ) - IACConfig() - IACConfig(num_turns=2, num_generations=1) - IACConfig(critic_type="q") + IACConfig(agent_devices="cpu", critic_devices="cpu") + IACConfig( + num_turns=2, + num_generations=1, + agent_devices="cpu", + critic_devices="cpu", + ) + IACConfig(critic_type="q", agent_devices="cpu", critic_devices="cpu") def test_maac_config_constraints(): @@ -59,12 +73,26 @@ def test_maac_config_constraints(): _assert_invalid(MAACConfig, "critic_type", "x") _assert_invalid(MAACConfig, "parallel_training", "invalid") _assert_invalid(MAACConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + MAACConfig(critic_devices="cpu") + with pytest.raises(ValueError, match="critic_devices"): + MAACConfig(agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): - MAACConfig(num_turns=2, num_generations=2) + MAACConfig( + num_turns=2, + num_generations=2, + agent_devices="cpu", + critic_devices="cpu", + ) - MAACConfig() - MAACConfig(num_turns=2, num_generations=1) - MAACConfig(critic_type="q") + MAACConfig(agent_devices="cpu", critic_devices="cpu") + MAACConfig( + num_turns=2, + num_generations=1, + agent_devices="cpu", + critic_devices="cpu", + ) + MAACConfig(critic_type="q", agent_devices="cpu", critic_devices="cpu") def test_magrpo_config_constraints(): @@ -85,6 +113,8 @@ def test_magrpo_config_constraints(): _assert_invalid(MAGRPOConfig, "num_generations", 1) _assert_invalid(MAGRPOConfig, "parallel_training", "invalid") _assert_invalid(MAGRPOConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + MAGRPOConfig() - MAGRPOConfig() - MAGRPOConfig(num_generations=2) + MAGRPOConfig(agent_devices="cpu") + MAGRPOConfig(num_generations=2, agent_devices="cpu") diff --git a/tests/test_model_loading.py b/tests/test_model_loading.py index 623f514..242faeb 100644 --- a/tests/test_model_loading.py +++ b/tests/test_model_loading.py @@ -18,6 +18,18 @@ def _reward_func(*_args, **_kwargs): return [0.0] +def _iac_cfg(**kwargs): + return IACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _maac_cfg(**kwargs): + return MAACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _magrpo_cfg(**kwargs): + return MAGRPOConfig(agent_devices="cpu", **kwargs) + + @pytest.fixture(scope="session") def tokenizer_05(): return AutoTokenizer.from_pretrained(MODEL_NAME_05) @@ -40,7 +52,7 @@ def _cleanup(*objs): def test_magrpo_model_name(): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( 
agent_model=MODEL_NAME_05, num_agents=2, @@ -53,7 +65,7 @@ def test_magrpo_model_name(): def test_magrpo_pretrained(tokenizer_05, model_05, model_06): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agents=[model_05, model_06], tokenizer=tokenizer_05, @@ -66,7 +78,7 @@ def test_magrpo_pretrained(tokenizer_05, model_05, model_06): def test_maac_model_name(): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agent_model=MODEL_NAME_05, critics=[MODEL_NAME_06], @@ -79,7 +91,7 @@ def test_maac_model_name(): def test_maac_pretrained(tokenizer_05, model_05, model_06): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agents=[model_05, model_06], critics=[model_06], @@ -93,7 +105,7 @@ def test_maac_pretrained(tokenizer_05, model_05, model_06): def test_maac_critics_len_mismatch(tokenizer_05, model_05, model_06): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) with pytest.raises(ValueError, match="critics length"): MAACTrainer( agents=[model_05, model_06], @@ -105,7 +117,7 @@ def test_maac_critics_len_mismatch(tokenizer_05, model_05, model_06): def test_iac_model_name_critics(tokenizer_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) trainer = IACTrainer( agent_model=MODEL_NAME_05, critics=[model_06, model_06], @@ -119,7 +131,7 @@ def test_iac_model_name_critics(tokenizer_05, model_06): def test_iac_model_and_agents_names_conflict(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, match="conflict"): IACTrainer( agent_model=MODEL_NAME_05, @@ -131,7 +143,7 @@ def test_iac_model_and_agents_names_conflict(tokenizer_05): def test_iac_model_and_agents_names_match(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) trainer = IACTrainer( agent_model=MODEL_NAME_05, agents=[MODEL_NAME_05, MODEL_NAME_05], @@ -144,7 +156,7 @@ def test_iac_model_and_agents_names_match(tokenizer_05): def test_iac_model_and_agents_len_mismatch(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, match="agents length"): IACTrainer( agent_model=MODEL_NAME_05, @@ -161,7 +173,7 @@ def test_iac_model_and_agents_len_mismatch(tokenizer_05): ids=["shared_homo", "shared_hetero"], ) def test_iac_shared_heads(agents_case, tokenizer_05, model_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) agents = [model_05, model_05] if agents_case == "homo" else [model_05, model_06] trainer = IACTrainer( agents=agents, @@ -175,7 +187,7 @@ def test_iac_shared_heads(agents_case, tokenizer_05, model_05, model_06): def test_iac_shared_heads_rejects_critics(tokenizer_05, model_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, 
match="use_separate_critic"): IACTrainer( agents=[model_05, model_05], @@ -187,7 +199,7 @@ def test_iac_shared_heads_rejects_critics(tokenizer_05, model_05): def test_iac_critics_len_mismatch(tokenizer_05, model_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) with pytest.raises(ValueError, match="critics length"): IACTrainer( agent_model=MODEL_NAME_05, @@ -204,7 +216,7 @@ def test_iac_critics_len_mismatch(tokenizer_05, model_05): ids=["critic_match", "critic_swapped"], ) def test_iac_separate_critics(critics_case, tokenizer_05, model_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) critics = [model_05, model_06] if critics_case == "match" else [model_06, model_05] trainer = IACTrainer( agents=[model_05, model_06], diff --git a/tests/test_trainer_constraints.py b/tests/test_trainer_constraints.py index 1c4d545..4e37f1c 100644 --- a/tests/test_trainer_constraints.py +++ b/tests/test_trainer_constraints.py @@ -18,6 +18,18 @@ def _external_transition(**_kwargs): return [""] * 1 +def _iac_cfg(**kwargs): + return IACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _maac_cfg(**kwargs): + return MAACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _magrpo_cfg(**kwargs): + return MAGRPOConfig(agent_devices="cpu", **kwargs) + + @pytest.fixture(scope="session") def dummy_tokenizer(): return SimpleNamespace( @@ -58,11 +70,11 @@ def tiny_model_b(): "factory, match", [ ( - lambda: IACTrainer(agent_model="dummy", reward_func=None, args=IACConfig()), + lambda: IACTrainer(agent_model="dummy", reward_func=None, args=_iac_cfg()), "reward_func", ), ( - lambda: IACTrainer(reward_func=_reward_func, args=IACConfig()), + lambda: IACTrainer(reward_func=_reward_func, args=_iac_cfg()), "Either agent_model or agents", ), ( @@ -70,7 +82,7 @@ def tiny_model_b(): agents=[object()], tokenizer=SimpleNamespace(pad_token="x", eos_token="x", pad_token_id=0), reward_func=_reward_func, - args=IACConfig(num_agents=1, num_turns=2), + args=_iac_cfg(num_agents=1, num_turns=2), ), "external_transition", ), @@ -79,18 +91,18 @@ def tiny_model_b(): agent_model="dummy", critics=[object()], reward_func=_reward_func, - args=IACConfig(use_separate_critic=False), + args=_iac_cfg(use_separate_critic=False), ), "use_separate_critic", ), ( lambda: MAACTrainer( - agent_model="dummy", reward_func=None, args=MAACConfig() + agent_model="dummy", reward_func=None, args=_maac_cfg() ), "reward_func", ), ( - lambda: MAACTrainer(reward_func=_reward_func, args=MAACConfig()), + lambda: MAACTrainer(reward_func=_reward_func, args=_maac_cfg()), "Either agent_model or agents", ), ( @@ -98,18 +110,18 @@ def tiny_model_b(): agents=[object()], tokenizer=SimpleNamespace(pad_token="x", eos_token="x", pad_token_id=0), reward_func=_reward_func, - args=MAACConfig(num_agents=1, num_turns=2), + args=_maac_cfg(num_agents=1, num_turns=2), ), "external_transition", ), ( lambda: MAGRPOTrainer( - agent_model="dummy", reward_func=None, args=MAGRPOConfig() + agent_model="dummy", reward_func=None, args=_magrpo_cfg() ), "reward_func", ), ( - lambda: MAGRPOTrainer(reward_func=_reward_func, args=MAGRPOConfig()), + lambda: MAGRPOTrainer(reward_func=_reward_func, args=_magrpo_cfg()), "Either agent_model or agents", ), ], @@ -131,7 +143,7 @@ def test_trainer_early_constraints(factory, match): def 
test_iac_separate_critic_requires_critics(dummy_tokenizer, tiny_model_a): - args = IACConfig(num_agents=1, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=1, use_separate_critic=True, num_turns=1) with pytest.raises(ValueError, match="critics must be provided"): IACTrainer( agents=[tiny_model_a], @@ -142,7 +154,7 @@ def test_iac_separate_critic_requires_critics(dummy_tokenizer, tiny_model_a): def test_iac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=True, num_turns=1) with pytest.raises(ValueError, match="critics length"): IACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -154,7 +166,7 @@ def test_iac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_iac_valid_shared_heads(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=False, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=dummy_tokenizer, @@ -166,7 +178,7 @@ def test_iac_valid_shared_heads(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_iac_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=False, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=[dummy_tokenizer, dummy_tokenizer], @@ -179,7 +191,7 @@ def test_iac_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b) def test_iac_rejects_tokenizer_len_mismatch( dummy_tokenizer, tiny_model_a, tiny_model_b ): - args = IACConfig(num_agents=2, use_separate_critic=False, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) with pytest.raises(ValueError, match="tokenizers length"): IACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -190,7 +202,7 @@ def test_iac_rejects_tokenizer_len_mismatch( def test_iac_valid_separate_critics(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=True, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], critics=[tiny_model_b, tiny_model_a], @@ -204,7 +216,7 @@ def test_iac_valid_separate_critics(dummy_tokenizer, tiny_model_a, tiny_model_b) def test_iac_multiturn_with_transition(dummy_tokenizer, tiny_model_a): - args = IACConfig(num_agents=1, use_separate_critic=False, num_turns=2) + args = _iac_cfg(num_agents=1, use_separate_critic=False, num_turns=2) trainer = IACTrainer( agents=[tiny_model_a], tokenizer=dummy_tokenizer, @@ -222,12 +234,12 @@ def test_iac_multiturn_num_generations_mismatch(dummy_tokenizer, tiny_model_a): tokenizer=dummy_tokenizer, reward_func=_reward_func, external_transition=_external_transition, - args=IACConfig(num_agents=1, num_turns=2, num_generations=2), + args=_iac_cfg(num_agents=1, num_turns=2, num_generations=2), ) def test_maac_requires_critics(dummy_tokenizer, tiny_model_a): - args = MAACConfig(num_agents=1, num_turns=1) + args = _maac_cfg(num_agents=1, num_turns=1) with pytest.raises(ValueError, match="critics must be provided"): MAACTrainer( agents=[tiny_model_a], @@ -238,7 +250,7 @@ def test_maac_requires_critics(dummy_tokenizer, tiny_model_a): def test_maac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): - 
args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) with pytest.raises(ValueError, match="critics length"): MAACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -250,7 +262,7 @@ def test_maac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_maac_valid(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agents=[tiny_model_a, tiny_model_b], critics=[tiny_model_a], @@ -263,7 +275,7 @@ def test_maac_valid(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_maac_multiturn_with_transition(dummy_tokenizer, tiny_model_a): - args = MAACConfig(num_agents=1, num_turns=2, num_generations=1) + args = _maac_cfg(num_agents=1, num_turns=2, num_generations=1) trainer = MAACTrainer( agents=[tiny_model_a], critics=[tiny_model_a], @@ -284,14 +296,14 @@ def test_maac_multiturn_num_generations_mismatch(dummy_tokenizer, tiny_model_a): tokenizer=dummy_tokenizer, reward_func=_reward_func, external_transition=_external_transition, - args=MAACConfig(num_agents=1, num_turns=2, num_generations=2), + args=_maac_cfg(num_agents=1, num_turns=2, num_generations=2), ) def test_magrpo_requires_transition_for_multiturn( dummy_tokenizer, tiny_model_a, tiny_model_b ): - args = MAGRPOConfig(num_agents=2, num_turns=2, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=2, num_generations=2) with pytest.raises(ValueError, match="external_transition"): MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], @@ -302,7 +314,7 @@ def test_magrpo_requires_transition_for_multiturn( def test_magrpo_multiturn_with_transition(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAGRPOConfig(num_agents=2, num_turns=2, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=2, num_generations=2) trainer = MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=dummy_tokenizer, @@ -314,7 +326,7 @@ def test_magrpo_multiturn_with_transition(dummy_tokenizer, tiny_mo def test_magrpo_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=[dummy_tokenizer, dummy_tokenizer], @@ -333,7 +345,7 @@ def _fake_from_pretrained(*_args, **_kwargs): monkeypatch.setattr( "transformers.AutoModelForCausalLM.from_pretrained", _fake_from_pretrained ) - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agent_model="dummy", agents=["dummy", "dummy"], @@ -354,7 +366,7 @@ def _fake_from_pretrained(*_args, **_kwargs): monkeypatch.setattr( "transformers.AutoModelForCausalLM.from_pretrained", _fake_from_pretrained ) - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) with pytest.raises(ValueError, match="conflict"): MAGRPOTrainer( agent_model="dummy", From 5891e145c6a17300946ca88ec43a0ff29a4bbecd Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 09:15:05 -0500 Subject: [PATCH 18/21] Use single-sample generation during evaluation --- comlrl/trainers/actor_critic/ac_base.py | 24 ++++++++------ comlrl/trainers/actor_critic/iac.py | 10 ++++-- comlrl/trainers/actor_critic/maac.py | 6 +++- comlrl/trainers/reinforce/magrpo.py | 7 ++-- docs/content/docs/dev/changelog.md | 2 +- 
.../user-guide/training-parallelization.md | 4 +++ tests/test_distributed_metrics.py | 32 +++++++++++++++++++ 7 files changed, 69 insertions(+), 16 deletions(-) diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index f9363ac..1ae6523 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -368,18 +368,22 @@ def evaluate(self) -> Dict[str, float]: turn_groups: Dict[int, List[Any]] = {} seen = 0 - with torch.no_grad(): - for batch in dataloader: - for item in batch: - rollouts = self._collect_rollouts(item) - for sample in rollouts: - t_idx = int(sample.metadata.get("turn_idx", 0)) - turn_groups.setdefault(t_idx, []).append(sample) - seen += 1 + self._in_eval = True + try: + with torch.no_grad(): + for batch in dataloader: + for item in batch: + rollouts = self._collect_rollouts(item) + for sample in rollouts: + t_idx = int(sample.metadata.get("turn_idx", 0)) + turn_groups.setdefault(t_idx, []).append(sample) + seen += 1 + if seen >= num_samples: + break if seen >= num_samples: break - if seen >= num_samples: - break + finally: + self._in_eval = False eval_log: Dict[str, float] = {} for turn_idx, samples in sorted(turn_groups.items()): diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 4693c8a..5c2304e 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -465,7 +465,9 @@ def _generate_rollout( "do_sample": True, "temperature": self.args.temperature, "top_p": self.args.top_p, - "num_return_sequences": num_ret, + "num_return_sequences": ( + 1 if bool(getattr(self, "_in_eval", False)) else num_ret + ), "num_beams": 1, } if self.args.top_k is not None: @@ -558,7 +560,11 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: if num_turns > 1: return self._collect_rollouts_multi_turn(item, num_turns) - num_ret = int(getattr(self.args, "num_generations", 1)) + num_ret = ( + 1 + if bool(getattr(self, "_in_eval", False)) + else int(getattr(self.args, "num_generations", 1)) + ) turn_prompts = [ self._resolve_turn_prompt(item, agent_idx) for agent_idx in range(self.args.num_agents) diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 5696a2b..8b31220 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -435,7 +435,11 @@ def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: prompt_attention_mask = encoded_prompt["attention_mask"] prompt_len = prompt_input_ids.size(1) - num_ret = int(self.args.num_generations) + num_ret = ( + 1 + if bool(getattr(self, "_in_eval", False)) + else int(self.args.num_generations) + ) generation_kwargs: Dict[str, Any] = { "input_ids": prompt_input_ids, "attention_mask": prompt_attention_mask, diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 33e3676..d8164d6 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -495,18 +495,21 @@ def get_eval_dataloader(self) -> Optional[DataLoader]: num_workers=0, ) - def evaluate(self, num_eval_samples: int = 4) -> Dict[str, float]: + def evaluate(self, num_eval_samples: Optional[int] = None) -> Dict[str, float]: """ Unified evaluation that supports both single-turn and multi-turn. Args: - num_eval_samples: Number of samples to evaluate + num_eval_samples: Number of samples to evaluate. Defaults to args.eval_num_samples. 
Returns: Dictionary containing evaluation metrics """ if self.eval_dataset is None: return {} + if num_eval_samples is None: + num_eval_samples = int(getattr(self.args, "eval_num_samples", 4)) + num_eval_samples = int(num_eval_samples) # Storage for completions across turns for all agents all_agent_completions_turns = [[] for _ in range(self.num_agents)] diff --git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md index c908174..80c44f7 100644 --- a/docs/content/docs/dev/changelog.md +++ b/docs/content/docs/dev/changelog.md @@ -10,7 +10,7 @@ weight: 3 - Remove the redundant sampling hyperparameters in algorithms. - Allow multi-gpu training with MP. -## Version 1.3.6 +## Latest Changes - Fixed critical bug of loading heterogeneous models and reform the model loading logics. - Polish the docs. diff --git a/docs/content/docs/user-guide/training-parallelization.md b/docs/content/docs/user-guide/training-parallelization.md index 82a7f8b..ab3e8a5 100644 --- a/docs/content/docs/user-guide/training-parallelization.md +++ b/docs/content/docs/user-guide/training-parallelization.md @@ -32,3 +32,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train_iac.py iac.agent_devices='["cuda:0","cuda:1"]' iac.critic_devices='["cuda:2","cuda:3"]' ``` + +{{% hint note %}} +Note that when `parallel_training=mp`, even with the same models, sampling settings, and seed, training is not deterministic due to non-deterministic GPU scheduling and CPU-side aggregation. +{{% /hint %}} diff --git a/tests/test_distributed_metrics.py b/tests/test_distributed_metrics.py index 5bfb2ca..a51d551 100644 --- a/tests/test_distributed_metrics.py +++ b/tests/test_distributed_metrics.py @@ -57,3 +57,35 @@ def _fake_log(metrics, step): # noqa: ARG001 trainer._log_metrics({"loss": 1.0}) trainer._log_metrics({"reward": 2.0}) assert called["log"] == [{"loss": 1.0}, {"reward": 2.0}] + + +def test_evaluate_calls_collect_rollouts_with_eval_flag(): + class _EvalDummyTrainer(ActorCriticTrainerBase): + def __init__(self): + self.eval_dataset = [{"id": 0}, {"id": 1}] + self.args = SimpleNamespace(eval_batch_size=1, eval_num_samples=1) + self.wandb_initialized = False + self.env_step = 0 + self.dist_env = _ctx() + self.verbose = False + self._collect_calls = 0 + self.eval_flags = [] + + def _collect_rollouts(self, item): # noqa: ARG002 + self._collect_calls += 1 + self.eval_flags.append(bool(getattr(self, "_in_eval", False))) + return [ + SimpleNamespace( + metadata={}, + reward=torch.tensor([1.0]), + returns=torch.tensor([1.5]), + old_value=torch.tensor([0.5]), + ) + ] + + trainer = _EvalDummyTrainer() + metrics = trainer.evaluate() + assert trainer._collect_calls == 1 + assert trainer.eval_flags == [True] + assert trainer._in_eval is False + assert "eval/turn_1/reward_mean" in metrics From 90fe24734a49ed32cbb4dbeeb96620c7388956b9 Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 10:10:08 -0500 Subject: [PATCH 19/21] Add 'none' parallel_training mode --- comlrl/schedulers/device_scheduler.py | 19 +++++++ comlrl/trainers/actor_critic/iac.py | 52 +++++++++++------ comlrl/trainers/actor_critic/maac.py | 56 ++++++++++++------- comlrl/trainers/reinforce/magrpo.py | 37 +++++++----- .../user-guide/training-parallelization.md | 6 +- tests/test_config_constraints.py | 56 ++++++++-----------  6 files changed, 142 insertions(+), 84 deletions(-) diff --git a/comlrl/schedulers/device_scheduler.py b/comlrl/schedulers/device_scheduler.py index bab08f3..5a813dc 100644 --- a/comlrl/schedulers/device_scheduler.py +++ 
b/comlrl/schedulers/device_scheduler.py @@ -70,6 +70,25 @@ def resolve_devices( raise ValueError(f"Unsupported {kind} spec: {spec!r}.") + @staticmethod + def resolve_single_device(*specs: Optional[DeviceSpec]) -> torch.device: + for spec in specs: + if spec is None: + continue + if isinstance(spec, str): + if spec.lower() == "auto": + return DeviceScheduler._auto_devices(1)[0] + return torch.device(spec) + if isinstance(spec, Sequence): + if len(spec) == 0: + raise ValueError("Device spec list must be non-empty.") + first = spec[0] + if isinstance(first, str) and first.lower() == "auto": + return DeviceScheduler._auto_devices(1)[0] + return torch.device(first) + raise ValueError(f"Unsupported device spec: {spec!r}.") + return DeviceScheduler._auto_devices(1)[0] + @staticmethod def devices_disjoint(device_groups: Iterable[Sequence[torch.device]]) -> bool: seen = set() diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 5c2304e..43a4288 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -47,7 +47,7 @@ class IACConfig: value_head_hidden_dim: Optional[int] = None num_agents: int = 2 num_turns: int = 2 - parallel_training: str = "mp" + parallel_training: str = "none" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -87,13 +87,21 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_training or "mp").lower() - if mode != "mp": - raise ValueError("parallel_training only supports: mp.") - if self.agent_devices is None: - raise ValueError("parallel_training='mp' requires explicit agent_devices.") - if self.critic_devices is None: - raise ValueError("parallel_training='mp' requires explicit critic_devices.") + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp": + if self.agent_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit agent_devices." + ) + if self.critic_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit critic_devices." 
+ ) + self.parallel_training = mode @dataclass @@ -174,16 +182,26 @@ def __init__( self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() self.parallel_training = ( - str(getattr(self.args, "parallel_training", "mp")).strip().lower() - ) - if self.parallel_training != "mp": - raise ValueError("parallel_training only supports: mp.") - self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( - self.args.num_agents, - getattr(self.args, "agent_devices", None), - getattr(self.args, "critic_devices", None), - use_separate_critic=self.args.use_separate_critic, + str(getattr(self.args, "parallel_training", "none")).strip().lower() ) + if self.parallel_training == "null": + self.parallel_training = "none" + if self.parallel_training not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if self.parallel_training == "mp": + self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( + self.args.num_agents, + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + use_separate_critic=self.args.use_separate_critic, + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + ) + self.agent_devices = [single_device] * self.args.num_agents + self.critic_devices = [single_device] * self.args.num_agents self.device = self.agent_devices[0] self.dist_env = local_context(self.device) diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 8b31220..49c0395 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -43,7 +43,7 @@ class MAACConfig: num_agents: int = 2 num_generations: int = 1 num_turns: int = 2 - parallel_training: str = "mp" + parallel_training: str = "none" agent_devices: Optional[Union[str, Sequence[str]]] = None critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False @@ -81,13 +81,21 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") - mode = str(self.parallel_training or "mp").lower() - if mode != "mp": - raise ValueError("parallel_training only supports: mp.") - if self.agent_devices is None: - raise ValueError("parallel_training='mp' requires explicit agent_devices.") - if self.critic_devices is None: - raise ValueError("parallel_training='mp' requires explicit critic_devices.") + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp": + if self.agent_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit agent_devices." + ) + if self.critic_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit critic_devices." 
+ ) + self.parallel_training = mode class MAACTrainer(ActorCriticTrainerBase): @@ -147,18 +155,28 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.parallel_training = ( - str(getattr(self.args, "parallel_training", "mp")).strip().lower() - ) - if self.parallel_training != "mp": - raise ValueError("parallel_training only supports: mp.") - self.agent_devices = DeviceScheduler.resolve_devices( - getattr(self.args, "agent_devices", None), - self.args.num_agents, - kind="agent_devices", - ) - self.critic_device = DeviceScheduler.assign_shared_critic_device( - self.agent_devices, getattr(self.args, "critic_devices", None) + str(getattr(self.args, "parallel_training", "none")).strip().lower() ) + if self.parallel_training == "null": + self.parallel_training = "none" + if self.parallel_training not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if self.parallel_training == "mp": + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.args.num_agents, + kind="agent_devices", + ) + self.critic_device = DeviceScheduler.assign_shared_critic_device( + self.agent_devices, getattr(self.args, "critic_devices", None) + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + ) + self.agent_devices = [single_device] * self.args.num_agents + self.critic_device = single_device self.device = self.agent_devices[0] self.dist_env = local_context(self.device) self._parallel_update_enabled = False diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index d8164d6..44e97ee 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -37,7 +37,7 @@ class MAGRPOConfig: agent_learning_rate: float = 5.0e-6 logging_steps: int = 50 num_agents: int = 2 - parallel_training: str = "mp" + parallel_training: str = "none" agent_devices: Optional[Union[str, Sequence[str]]] = None # Sampling/generation @@ -87,11 +87,14 @@ def __post_init__(self) -> None: self.train_batch_size = self.rollout_buffer_size if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") - mode = str(self.parallel_training or "mp").lower() - if mode != "mp": - raise ValueError("parallel_training only supports: mp.") - if self.agent_devices is None: + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp" and self.agent_devices is None: raise ValueError("parallel_training='mp' requires explicit agent_devices.") + self.parallel_training = mode @dataclass @@ -159,10 +162,12 @@ def __init__( ): self.args = args if args is not None else self.default_config_cls() self.parallel_training = ( - str(getattr(self.args, "parallel_training", "mp")).strip().lower() + str(getattr(self.args, "parallel_training", "none")).strip().lower() ) - if self.parallel_training != "mp": - raise ValueError("parallel_training only supports: mp.") + if self.parallel_training == "null": + self.parallel_training = "none" + if self.parallel_training not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") self.dist_env = local_context() self.device = self.dist_env.device @@ -210,11 +215,17 @@ def __init__( self.num_agents = expected_count self.model_name = model_name - 
self.agent_devices = DeviceScheduler.resolve_devices( - getattr(self.args, "agent_devices", None), - self.num_agents, - kind="agent_devices", - ) + if self.parallel_training == "mp": + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.num_agents, + kind="agent_devices", + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None) + ) + self.agent_devices = [single_device] * self.num_agents self.device = self.agent_devices[0] self.dist_env = local_context(self.device) if actor_sources and all(isinstance(src, str) for src in actor_sources): diff --git a/docs/content/docs/user-guide/training-parallelization.md b/docs/content/docs/user-guide/training-parallelization.md index ab3e8a5..b52baa3 100644 --- a/docs/content/docs/user-guide/training-parallelization.md +++ b/docs/content/docs/user-guide/training-parallelization.md @@ -5,13 +5,13 @@ weight: 6 --- CoMLRL supports fine-tuning multi-LLM systems with larger models and more agents when multiple GPUs are available. -Currently, CoMLRL supports one training parallelization mode `mp` (model parallelization). +Users can configure training parallelization with `iac.parallel_training`. +Currently, `parallel_training` supports two modes: `none` (or `null`), the default for single-device training, and `mp`, which schedules models across explicit agent/critic devices. {{% hint success %}} We will support more parallelization modes (e.g., [data parallelization](https://docs.pytorch.org/docs/stable/elastic/run.html), [multi-node training](ray.io)) in the future. {{% /hint %}} - ## Model Parallelization When `parallel_training=mp`, CoMLRL requires explicit `agent_devices` / `critic_devices` configuration and deploys the agents and critics accordingly. @@ -34,5 +34,5 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train_iac.py ``` {{% hint note %}} -Note that when `parallel_training=mp`, even with the same models, sampling settings, and seed, training is not deterministic due to non-deterministic GPU scheduling and CPU-side aggregation. +Note that when devices change, training is not deterministic due to non-deterministic GPU scheduling and CPU-side aggregation. 
{{% /hint %}} diff --git a/tests/test_config_constraints.py b/tests/test_config_constraints.py index f3082ab..c87ee5f 100644 --- a/tests/test_config_constraints.py +++ b/tests/test_config_constraints.py @@ -34,25 +34,19 @@ def test_iac_config_constraints(): _assert_invalid(IACConfig, "parallel_training", "invalid") _assert_invalid(IACConfig, "parallel_training", "auto") with pytest.raises(ValueError, match="agent_devices"): - IACConfig(critic_devices="cpu") + IACConfig(parallel_training="mp", critic_devices="cpu") with pytest.raises(ValueError, match="critic_devices"): - IACConfig(agent_devices="cpu") + IACConfig(parallel_training="mp", agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): - IACConfig( - num_turns=2, - num_generations=2, - agent_devices="cpu", - critic_devices="cpu", - ) + IACConfig(num_turns=2, num_generations=2) + IACConfig() + IACConfig(parallel_training="none") + IACConfig(parallel_training="null") + IACConfig(parallel_training="mp", agent_devices="cpu", critic_devices="cpu") IACConfig(agent_devices="cpu", critic_devices="cpu") - IACConfig( - num_turns=2, - num_generations=1, - agent_devices="cpu", - critic_devices="cpu", - ) - IACConfig(critic_type="q", agent_devices="cpu", critic_devices="cpu") + IACConfig(num_turns=2, num_generations=1) + IACConfig(critic_type="q") def test_maac_config_constraints(): @@ -74,25 +68,19 @@ def test_maac_config_constraints(): _assert_invalid(MAACConfig, "parallel_training", "invalid") _assert_invalid(MAACConfig, "parallel_training", "auto") with pytest.raises(ValueError, match="agent_devices"): - MAACConfig(critic_devices="cpu") + MAACConfig(parallel_training="mp", critic_devices="cpu") with pytest.raises(ValueError, match="critic_devices"): - MAACConfig(agent_devices="cpu") + MAACConfig(parallel_training="mp", agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): - MAACConfig( - num_turns=2, - num_generations=2, - agent_devices="cpu", - critic_devices="cpu", - ) + MAACConfig(num_turns=2, num_generations=2) + MAACConfig() + MAACConfig(parallel_training="none") + MAACConfig(parallel_training="null") + MAACConfig(parallel_training="mp", agent_devices="cpu", critic_devices="cpu") MAACConfig(agent_devices="cpu", critic_devices="cpu") - MAACConfig( - num_turns=2, - num_generations=1, - agent_devices="cpu", - critic_devices="cpu", - ) - MAACConfig(critic_type="q", agent_devices="cpu", critic_devices="cpu") + MAACConfig(num_turns=2, num_generations=1) + MAACConfig(critic_type="q") def test_magrpo_config_constraints(): @@ -114,7 +102,11 @@ def test_magrpo_config_constraints(): _assert_invalid(MAGRPOConfig, "parallel_training", "invalid") _assert_invalid(MAGRPOConfig, "parallel_training", "auto") with pytest.raises(ValueError, match="agent_devices"): - MAGRPOConfig() + MAGRPOConfig(parallel_training="mp") + MAGRPOConfig() + MAGRPOConfig(parallel_training="none") + MAGRPOConfig(parallel_training="null") + MAGRPOConfig(parallel_training="mp", agent_devices="cpu") MAGRPOConfig(agent_devices="cpu") - MAGRPOConfig(num_generations=2, agent_devices="cpu") + MAGRPOConfig(num_generations=2) From a0e9b77c7ac05d9da373e5dd143d46495985134a Mon Sep 17 00:00:00 2001 From: N!no Date: Mon, 16 Feb 2026 13:13:02 -0500 Subject: [PATCH 20/21] Update iac.py --- comlrl/trainers/actor_critic/iac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 43a4288..db49d73 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ 
b/comlrl/trainers/actor_critic/iac.py @@ -627,7 +627,7 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: value = data["values"][i] reward = float(rewards_matrix[agent_idx][i]) reward_tensor = torch.tensor( - [reward], device=self._agent_device(agent_idx), dtype=torch.float32 + [reward], device=value.device, dtype=torch.float32 ) returns = reward_tensor.clone() advantage = returns - value From 72f442bbec10b6e5b3a2cfa8344aca2a92c94c22 Mon Sep 17 00:00:00 2001 From: N!no Date: Tue, 17 Feb 2026 10:19:22 -0500 Subject: [PATCH 21/21] Remove redundant detach calls --- comlrl/trainers/actor_critic/iac.py | 50 ++++++++++++---------- comlrl/trainers/actor_critic/maac.py | 44 +++++++++++------------- comlrl/trainers/reinforce/magrpo.py | 6 ++-- 3 files changed, 44 insertions(+), 56 deletions(-) diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index db49d73..4e40c8a 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -626,11 +626,11 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: logprob = data["logprobs"][i] value = data["values"][i] reward = float(rewards_matrix[agent_idx][i]) - reward_tensor = torch.tensor( - [reward], device=value.device, dtype=torch.float32 - ) - returns = reward_tensor.clone() - advantage = returns - value + reward_cpu = torch.tensor([reward], dtype=torch.float32) + value_cpu = value.detach().cpu() + returns_cpu = reward_cpu.clone() + advantage_cpu = returns_cpu - value_cpu.to(dtype=returns_cpu.dtype) + logprob_cpu = logprob.detach().cpu() rollouts.append( RolloutSample( @@ -644,14 +644,14 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=returns.detach().cpu(), - advantage=advantage.detach().cpu(), + old_logprob=logprob_cpu, + old_value=value_cpu, + reward=reward_cpu, + returns=returns_cpu, + advantage=advantage_cpu, metadata={ "char_length": data["char_lengths"][i], - "value_target": returns.detach().cpu(), + "value_target": returns_cpu, }, ) ) @@ -741,11 +741,9 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: logprob = data["logprobs"][0] value = data["values"][0] reward_val = float(rewards_matrix[agent_idx][0]) - reward_tensor = torch.tensor( - [reward_val], - device=self._agent_device(agent_idx), - dtype=torch.float32, - ) + reward_cpu = torch.tensor([reward_val], dtype=torch.float32) + value_cpu = value.detach().cpu() + logprob_cpu = logprob.detach().cpu() completion_text = data["completions"][0] sample = RolloutSample( @@ -756,11 +754,11 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value_cpu, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "char_length": data["char_lengths"][0], @@ -787,19 +785,15 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: target = r + gamma * next_v else: target = r - sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() - sample.metadata["value_target"] = 
torch.tensor([target]).detach().cpu() + sample.metadata["adv_target"] = torch.tensor([target]).cpu() + sample.metadata["value_target"] = torch.tensor([target]).cpu() for agent_idx in range(self.args.num_agents): future = 0.0 for sample in reversed(per_agent_samples[agent_idx]): immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future - sample.returns = ( - torch.tensor([future], device=self._agent_device(agent_idx)) - .detach() - .cpu() - ) + sample.returns = torch.tensor([future], dtype=torch.float32) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 49c0395..57b4099 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -583,9 +583,6 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: attn = data["attention_mask"][i] resp_len = data["response_lens"][i] reward = float(rewards_matrix[agent_idx][i]) - reward_tensor = torch.tensor( - [reward], device=self._agent_device(agent_idx) - ) logprob, _ = self._policy_eval( self.agents[agent_idx], @@ -605,6 +602,8 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: joint_mask = critic_pack["attention_mask"] joint_len = int(critic_pack["prompt_len"]) value = critic_pack["value"].detach().cpu() + reward_cpu = torch.tensor([reward], dtype=torch.float32) + logprob_cpu = logprob.detach().cpu() rollouts.append( RolloutSample( agent_idx=agent_idx, @@ -617,25 +616,25 @@ def _generate_agent(agent_idx: int) -> Dict[str, Any]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "joint_input_ids": joint_ids.detach().cpu(), "joint_attention_mask": joint_mask.detach().cpu(), "joint_prompt_len": joint_len, "turn_idx": 0, - "adv_target": reward_tensor.detach().cpu(), + "adv_target": reward_cpu, }, ) ) for sample in rollouts: r = float(sample.reward.view(-1)[0].item()) - sample.metadata["value_target"] = torch.tensor([r]).detach().cpu() + sample.metadata["value_target"] = torch.tensor([r]).cpu() if self.metrics_callback is not None: extra = self.metrics_callback(rollouts) @@ -744,9 +743,7 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: attn = data["attention_mask"][0] resp_len = data["response_lens"][0] reward_val = float(rewards_matrix[agent_idx][0]) - reward_tensor = torch.tensor( - [reward_val], device=self._agent_device(agent_idx) - ) + reward_cpu = torch.tensor([reward_val], dtype=torch.float32) logprob, _ = self._policy_eval( self.agents[agent_idx], @@ -758,6 +755,7 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: ) value = joint_value.detach().cpu() + logprob_cpu = logprob.detach().cpu() completion_text = data["completion_texts"][0] sample = RolloutSample( agent_idx=agent_idx, @@ -767,11 +765,11 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - 
advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "joint_input_ids": joint_ids.detach().cpu(), @@ -800,19 +798,15 @@ def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: target = r + gamma * next_v else: target = r - sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() - sample.metadata["value_target"] = torch.tensor([target]).detach().cpu() + sample.metadata["adv_target"] = torch.tensor([target]).cpu() + sample.metadata["value_target"] = torch.tensor([target]).cpu() for agent_idx in range(self.args.num_agents): future = 0.0 for sample in reversed(per_agent_samples[agent_idx]): immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future - sample.returns = ( - torch.tensor([future], device=self._agent_device(agent_idx)) - .detach() - .cpu() - ) + sample.returns = torch.tensor([future], dtype=torch.float32) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index 44e97ee..e27f200 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -1410,12 +1410,12 @@ def _compute_rewards( def _pack_completions_for_buffer( self, completions_data: Dict[str, Any] ) -> Dict[str, Any]: - prompt_ids = completions_data["prompt_input_ids"].detach().cpu() + prompt_ids = completions_data["prompt_input_ids"].cpu() completion_ids = completions_data["completion_input_ids"] if completion_ids and isinstance(completion_ids[0], list): - packed_completion_ids = [[t.detach().cpu() for t in completion_ids[0]]] + packed_completion_ids = [[t.cpu() for t in completion_ids[0]]] else: - packed_completion_ids = [[t.detach().cpu() for t in completion_ids]] + packed_completion_ids = [[t.cpu() for t in completion_ids]] return { "prompt_input_ids": prompt_ids, "completion_input_ids": packed_completion_ids,