diff --git a/comlrl/schedulers/__init__.py b/comlrl/schedulers/__init__.py new file mode 100644 index 0000000..fcd442b --- /dev/null +++ b/comlrl/schedulers/__init__.py @@ -0,0 +1,3 @@ +from .device_scheduler import DeviceScheduler + +__all__ = ["DeviceScheduler"] diff --git a/comlrl/schedulers/device_scheduler.py b/comlrl/schedulers/device_scheduler.py new file mode 100644 index 0000000..5a813dc --- /dev/null +++ b/comlrl/schedulers/device_scheduler.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Iterable, List, Optional, Sequence, Tuple, Union + +import torch + +DeviceSpec = Union[str, Sequence[str]] + + +class DeviceScheduler: + @staticmethod + def assign_devices( + num_agents: int, + agent_devices: Optional[DeviceSpec], + critic_devices: Optional[DeviceSpec], + *, + use_separate_critic: bool, + ) -> Tuple[List[torch.device], List[torch.device]]: + agent_list = DeviceScheduler.resolve_devices( + agent_devices, num_agents, kind="agent_devices" + ) + if use_separate_critic: + critic_spec = ( + critic_devices if critic_devices is not None else agent_devices + ) + critic_list = DeviceScheduler.resolve_devices( + critic_spec, num_agents, kind="critic_devices" + ) + else: + critic_list = list(agent_list) + return agent_list, critic_list + + @staticmethod + def assign_shared_critic_device( + agent_devices: Sequence[torch.device], + critic_devices: Optional[DeviceSpec], + ) -> torch.device: + if critic_devices is None: + return agent_devices[0] + return DeviceScheduler.resolve_devices( + critic_devices, 1, kind="critic_devices" + )[0] + + @staticmethod + def resolve_devices( + spec: Optional[DeviceSpec], + num_devices: int, + *, + kind: str = "devices", + ) -> List[torch.device]: + if num_devices < 1: + raise ValueError(f"{kind} count must be >= 1.") + + if spec is None or (isinstance(spec, str) and spec.lower() == "auto"): + return DeviceScheduler._auto_devices(num_devices) + + if isinstance(spec, str): + return [torch.device(spec)] * num_devices + + if isinstance(spec, Sequence): + if len(spec) == 0: + raise ValueError(f"{kind} must be a non-empty list or 'auto'.") + if len(spec) == 1: + return [torch.device(spec[0])] * num_devices + if len(spec) != num_devices: + raise ValueError( + f"{kind} length ({len(spec)}) must be 1 or {num_devices}." 
+ ) + return [torch.device(s) for s in spec] + + raise ValueError(f"Unsupported {kind} spec: {spec!r}.") + + @staticmethod + def resolve_single_device(*specs: Optional[DeviceSpec]) -> torch.device: + for spec in specs: + if spec is None: + continue + if isinstance(spec, str): + if spec.lower() == "auto": + return DeviceScheduler._auto_devices(1)[0] + return torch.device(spec) + if isinstance(spec, Sequence): + if len(spec) == 0: + raise ValueError("Device spec list must be non-empty.") + first = spec[0] + if isinstance(first, str) and first.lower() == "auto": + return DeviceScheduler._auto_devices(1)[0] + return torch.device(first) + raise ValueError(f"Unsupported device spec: {spec!r}.") + return DeviceScheduler._auto_devices(1)[0] + + @staticmethod + def devices_disjoint(device_groups: Iterable[Sequence[torch.device]]) -> bool: + seen = set() + for group in device_groups: + for device in group: + key = DeviceScheduler._device_key(device) + if key in seen: + return False + seen.add(key) + return True + + @staticmethod + def _device_key(device: torch.device) -> str: + if device.type == "cuda": + return f"cuda:{device.index}" + return device.type + + @staticmethod + def _auto_devices(num_devices: int) -> List[torch.device]: + if not torch.cuda.is_available(): + return [torch.device("cpu")] * num_devices + + count = int(torch.cuda.device_count()) + if count < 1: + return [torch.device("cpu")] * num_devices + if count < num_devices: + return [torch.device("cuda:0")] * num_devices + + indices = list(range(count)) + return [torch.device(f"cuda:{idx}") for idx in indices[:num_devices]] diff --git a/comlrl/trainers/actor_critic/ac_base.py b/comlrl/trainers/actor_critic/ac_base.py index 49eea91..1ae6523 100644 --- a/comlrl/trainers/actor_critic/ac_base.py +++ b/comlrl/trainers/actor_critic/ac_base.py @@ -1,5 +1,6 @@ from collections import defaultdict from typing import Any, Dict, List, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed import torch import wandb @@ -10,6 +11,51 @@ class ActorCriticTrainerBase: """Shared training utilities for actor-critic style trainers.""" + def _parallel_agent_mode_enabled(self) -> bool: + if str(getattr(self, "parallel_training", "")).lower() != "mp": + return False + num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0) + if num_agents <= 1: + return False + devices = getattr(self, "agent_devices", None) + if not devices: + return True + unique = {str(device) for device in devices} + return len(unique) > 1 + + def _run_agent_tasks( + self, + fn, + *, + agent_indices: Optional[List[int]] = None, + parallel: Optional[bool] = None, + ) -> List[Any]: + num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0) + indices = ( + list(agent_indices) + if agent_indices is not None + else list(range(max(num_agents, 0))) + ) + if not indices: + return [] + + use_parallel = ( + self._parallel_agent_mode_enabled() if parallel is None else bool(parallel) + ) + if not use_parallel or len(indices) <= 1: + return [fn(agent_idx) for agent_idx in indices] + + results: Dict[int, Any] = {} + max_workers = len(indices) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(fn, agent_idx): agent_idx for agent_idx in indices + } + for future in as_completed(futures): + agent_idx = futures[future] + results[agent_idx] = future.result() + return [results[agent_idx] for agent_idx in indices] + def _filter_model_kwargs(self, cfg: Optional[Dict[str, Any]]) -> Dict[str, Any]: 
torch_dtype = None if isinstance(cfg, dict): @@ -64,6 +110,7 @@ def _encode_prompt( prompt: str, agent_idx: Optional[int] = None, tokenizer: Optional[Any] = None, + device: Optional[torch.device] = None, ) -> Dict[str, torch.Tensor]: tokenizer = tokenizer or self._get_tokenizer(agent_idx) encoded = tokenizer( @@ -71,9 +118,10 @@ def _encode_prompt( return_tensors="pt", truncation=True, ) + target_device = device or self.device return { - "input_ids": encoded["input_ids"].to(self.device), - "attention_mask": encoded["attention_mask"].to(self.device), + "input_ids": encoded["input_ids"].to(target_device), + "attention_mask": encoded["attention_mask"].to(target_device), } def _prepare_advantages(self, rollouts: List[Any]) -> None: @@ -139,7 +187,9 @@ def _summarize_rollout_metrics(self, rollouts: List[Any]) -> Dict[str, float]: return metrics def _iter_dataloader(self, dataloader, epoch: int, total_epochs: int): - if getattr(self, "verbose", True): + dist_env = getattr(self, "dist_env", None) + is_main = bool(getattr(dist_env, "is_main", True)) + if getattr(self, "verbose", True) and is_main: return enumerate( tqdm( dataloader, @@ -165,7 +215,12 @@ def _on_epoch_end( epoch_metrics: Dict[str, List[float]], ) -> None: summary = self._summarize_epoch_metrics(epoch_metrics) - if summary and getattr(self, "verbose", True): + dist_env = getattr(self, "dist_env", None) + if ( + summary + and getattr(self, "verbose", True) + and getattr(dist_env, "is_main", True) + ): print(f"Epoch {epoch + 1}/{total_epochs} metrics: {summary}") def _tag_metrics( @@ -197,10 +252,9 @@ def _process_buffer( self, agent_idx: int, buffer: List[Any], - epoch_metrics: Dict[str, List[float]], - ) -> None: + ) -> Dict[str, Any]: if not buffer: - return + return {"metric_values": {}, "log_metrics": {}} has_turn_idx = any( "turn_idx" in (getattr(s, "metadata", {}) or {}) for s in buffer @@ -213,13 +267,49 @@ def _process_buffer( buffer.clear() combined_log: Dict[str, float] = {} + metric_values: Dict[str, List[float]] = {} for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] metrics = self._update(agent_idx, samples) tagged = self._tag_metrics(metrics, agent_idx, turn_idx=t_idx) combined_log.update(tagged) for key, value in tagged.items(): - epoch_metrics[key].append(value) + metric_values.setdefault(key, []).append(value) + return {"metric_values": metric_values, "log_metrics": combined_log} + + def _drain_ready_agent_buffers( + self, + ready_agents: List[int], + epoch_metrics: Dict[str, List[float]], + ) -> None: + if not ready_agents: + return + + unique_ready = sorted({int(idx) for idx in ready_agents}) + run_parallel = bool( + getattr( + self, + "_parallel_update_enabled", + self._parallel_agent_mode_enabled(), + ) + ) + + def _process(agent_idx: int) -> Dict[str, Any]: + return self._process_buffer(agent_idx, self.rollout_buffers[agent_idx]) + + results = self._run_agent_tasks( + _process, + agent_indices=unique_ready, + parallel=run_parallel, + ) + + combined_log: Dict[str, float] = {} + for result in results: + metric_values = result.get("metric_values", {}) + for key, values in metric_values.items(): + for value in values: + epoch_metrics[key].append(value) + combined_log.update(result.get("log_metrics", {})) if combined_log and self._should_log_train(): self._log_metrics(combined_log) @@ -227,21 +317,24 @@ def _process_buffer( def _run_batch(self, batch, epoch_metrics: Dict[str, List[float]]) -> None: for item in batch: rollouts = self._collect_rollouts(item) + ready_agents: List[int] = [] for sample 
in rollouts: agent_idx = sample.agent_idx buffer = self.rollout_buffers[agent_idx] buffer.append(sample) if len(buffer) >= self.args.rollout_buffer_size: - self._process_buffer(agent_idx, buffer, epoch_metrics) + ready_agents.append(agent_idx) + if ready_agents: + self._drain_ready_agent_buffers(ready_agents, epoch_metrics) if self.args.num_agents > 0: # Count joint-action reward evaluations (one per agent group). self.env_step += len(rollouts) // self.args.num_agents def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None: - for agent_idx, buffer in enumerate(self.rollout_buffers): - if not buffer: - continue - self._process_buffer(agent_idx, buffer, epoch_metrics) + ready_agents = [ + agent_idx for agent_idx, buffer in enumerate(self.rollout_buffers) if buffer + ] + self._drain_ready_agent_buffers(ready_agents, epoch_metrics) def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: @@ -275,18 +368,22 @@ def evaluate(self) -> Dict[str, float]: turn_groups: Dict[int, List[Any]] = {} seen = 0 - with torch.no_grad(): - for batch in dataloader: - for item in batch: - rollouts = self._collect_rollouts(item) - for sample in rollouts: - t_idx = int(sample.metadata.get("turn_idx", 0)) - turn_groups.setdefault(t_idx, []).append(sample) - seen += 1 + self._in_eval = True + try: + with torch.no_grad(): + for batch in dataloader: + for item in batch: + rollouts = self._collect_rollouts(item) + for sample in rollouts: + t_idx = int(sample.metadata.get("turn_idx", 0)) + turn_groups.setdefault(t_idx, []).append(sample) + seen += 1 + if seen >= num_samples: + break if seen >= num_samples: break - if seen >= num_samples: - break + finally: + self._in_eval = False eval_log: Dict[str, float] = {} for turn_idx, samples in sorted(turn_groups.items()): diff --git a/comlrl/trainers/actor_critic/iac.py b/comlrl/trainers/actor_critic/iac.py index 60d3e8c..4e40c8a 100644 --- a/comlrl/trainers/actor_critic/iac.py +++ b/comlrl/trainers/actor_critic/iac.py @@ -10,6 +10,8 @@ from datasets import Dataset, IterableDataset from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase from comlrl.models.actor_critic import CausalLMWithValueHead +from comlrl.schedulers import DeviceScheduler +from comlrl.utils.distributed import local_context, unwrap_model from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -45,6 +47,9 @@ class IACConfig: value_head_hidden_dim: Optional[int] = None num_agents: int = 2 num_turns: int = 2 + parallel_training: str = "none" + agent_devices: Optional[Union[str, Sequence[str]]] = None + critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False discount: float = 0.9 num_generations: int = 1 @@ -82,6 +87,21 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp": + if self.agent_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit agent_devices." + ) + if self.critic_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit critic_devices." 
+ ) + self.parallel_training = mode @dataclass @@ -161,8 +181,29 @@ def __init__( self.metrics_callback = metrics_callback self.model_config = model_config or {} self.critic_type = (self.args.critic_type or "v").lower() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.parallel_training = ( + str(getattr(self.args, "parallel_training", "none")).strip().lower() + ) + if self.parallel_training == "null": + self.parallel_training = "none" + if self.parallel_training not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if self.parallel_training == "mp": + self.agent_devices, self.critic_devices = DeviceScheduler.assign_devices( + self.args.num_agents, + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + use_separate_critic=self.args.use_separate_critic, + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + ) + self.agent_devices = [single_device] * self.args.num_agents + self.critic_devices = [single_device] * self.args.num_agents + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) self.agents: List[CausalLMWithValueHead] = [] self.critics: List[CausalLMWithValueHead] = [] @@ -188,7 +229,7 @@ def __init__( expected_count=self.args.num_agents, model_label="agent_model", ) - for actor_source in actor_sources: + for idx, actor_source in enumerate(actor_sources): if actor_source is None: raise ValueError("A policy model identifier or instance is required.") if isinstance(actor_source, CausalLMWithValueHead): @@ -219,7 +260,7 @@ def __init__( value_head_hidden_dim=self.args.value_head_hidden_dim, attach_value_head=attach_value, ) - agent_model.to(self.device) + agent_model.to(self.agent_devices[idx]) self.agents.append(agent_model) if self.args.use_separate_critic: @@ -234,7 +275,7 @@ def __init__( expected_count=self.args.num_agents, model_label="critic_model", ) - for critic_source in critic_sources: + for idx, critic_source in enumerate(critic_sources): if isinstance(critic_source, CausalLMWithValueHead): critic_model = critic_source elif isinstance(critic_source, PreTrainedModel): @@ -261,7 +302,7 @@ def __init__( value_head_hidden_dim=self.args.critic_value_head_hidden_dim, attach_value_head=True, ) - critic_model.to(self.device) + critic_model.to(self.critic_devices[idx]) self.critics.append(critic_model) else: self.critics = [] @@ -274,6 +315,7 @@ def __init__( apply_tokenizer_specials(tok, models) else: apply_tokenizer_specials(self.tokenizer, [*self.agents, *self.critics]) + self.agent_optimizers = [] self.critic_optimizers = [] @@ -297,7 +339,7 @@ def __init__( self.wandb_config = wandb_config self.wandb_initialized = False self._last_train_log_step = -1 - if wandb_config is not None: + if wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.verbose = True if isinstance(self.wandb_config, dict): @@ -307,7 +349,17 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _agent_device(self, agent_idx: int) -> torch.device: + return self.agent_devices[agent_idx] + + def _critic_device(self, agent_idx: int) -> torch.device: + if self.args.use_separate_critic: + return self.critic_devices[agent_idx] + return self.agent_devices[agent_idx] + def _init_wandb(self) -> None: + if not self.dist_env.is_main: + return if self.wandb_initialized: return if wandb is None: @@ -416,7 
+468,10 @@ def _generate_rollout( agent_idx: int, num_ret: int, ) -> Dict[str, Any]: - encoded_prompt = self._encode_prompt(prompt, agent_idx=agent_idx) + agent_device = self._agent_device(agent_idx) + encoded_prompt = self._encode_prompt( + prompt, agent_idx=agent_idx, device=agent_device + ) prompt_input_ids = encoded_prompt["input_ids"] prompt_attention_mask = encoded_prompt["attention_mask"] prompt_len = prompt_input_ids.size(1) @@ -428,13 +483,16 @@ def _generate_rollout( "do_sample": True, "temperature": self.args.temperature, "top_p": self.args.top_p, - "num_return_sequences": num_ret, + "num_return_sequences": ( + 1 if bool(getattr(self, "_in_eval", False)) else num_ret + ), "num_beams": 1, } if self.args.top_k is not None: generation_kwargs["top_k"] = self.args.top_k - sequences = agent_model.generate(**generation_kwargs) + generation_model = unwrap_model(agent_model) + sequences = generation_model.generate(**generation_kwargs) if sequences.size(1) <= prompt_len: raise RuntimeError("Model produced an empty completion during rollout.") @@ -453,17 +511,20 @@ def _generate_rollout( tokenizer.decode(seq[:resp_len], skip_special_tokens=True) ) - full_attention_mask = torch.ones_like(sequences, device=self.device) + full_attention_mask = torch.ones_like(sequences, device=agent_device) with torch.no_grad(): if self.args.use_separate_critic: critic_model = self.critics[agent_idx] if critic_model is None: raise RuntimeError("Critic model missing for agent.") + critic_device = self._critic_device(agent_idx) + critic_sequences = sequences.to(critic_device) + critic_mask = full_attention_mask.to(critic_device) value = self._value_for_critic_type( critic_model, - sequences, - full_attention_mask, + critic_sequences, + critic_mask, prompt_len, response_lens, ) @@ -517,17 +578,27 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: if num_turns > 1: return self._collect_rollouts_multi_turn(item, num_turns) - prompts: List[str] = [] - completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] - num_ret = int(getattr(self.args, "num_generations", 1)) + num_ret = ( + 1 + if bool(getattr(self, "_in_eval", False)) + else int(getattr(self.args, "num_generations", 1)) + ) + turn_prompts = [ + self._resolve_turn_prompt(item, agent_idx) + for agent_idx in range(self.args.num_agents) + ] - for agent_idx, agent_model in enumerate(self.agents): - prompt = self._resolve_turn_prompt(item, agent_idx) + def _generate_agent(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] + prompt = turn_prompts[agent_idx] gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret) - completions_per_agent.append(gen["completions"]) - rollout_data.append({"agent_idx": agent_idx, **gen}) - prompts.append(prompt) + return {"agent_idx": agent_idx, **gen} + + rollout_data = self._run_agent_tasks(_generate_agent) + prompts: List[str] = [entry["prompt"] for entry in rollout_data] + completions_per_agent: List[List[str]] = [ + entry["completions"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, @@ -555,11 +626,11 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: logprob = data["logprobs"][i] value = data["values"][i] reward = float(rewards_matrix[agent_idx][i]) - reward_tensor = torch.tensor( - [reward], device=self.device, dtype=torch.float32 - ) - returns = reward_tensor.clone() - advantage = returns - value + reward_cpu = torch.tensor([reward], dtype=torch.float32) + value_cpu = 
value.detach().cpu() + returns_cpu = reward_cpu.clone() + advantage_cpu = returns_cpu - value_cpu.to(dtype=returns_cpu.dtype) + logprob_cpu = logprob.detach().cpu() rollouts.append( RolloutSample( @@ -573,14 +644,14 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=returns.detach().cpu(), - advantage=advantage.detach().cpu(), + old_logprob=logprob_cpu, + old_value=value_cpu, + reward=reward_cpu, + returns=returns_cpu, + advantage=advantage_cpu, metadata={ "char_length": data["char_lengths"][i], - "value_target": returns.detach().cpu(), + "value_target": returns_cpu, }, ) ) @@ -634,13 +705,17 @@ def _collect_rollouts_multi_turn( ] completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] - for agent_idx, agent_model in enumerate(self.agents): + for agent_idx, prompt in enumerate(turn_prompts): + prompt_history[agent_idx].append(prompt) + + def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] prompt = turn_prompts[agent_idx] gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret=1) - completions_per_agent.append(gen["completions"]) - rollout_data.append({"agent_idx": agent_idx, **gen}) - prompt_history[agent_idx].append(prompt) + return {"agent_idx": agent_idx, **gen} + + rollout_data = self._run_agent_tasks(_generate_agent_turn) + completions_per_agent = [entry["completions"] for entry in rollout_data] rewards = call_reward_function( self.reward_func, @@ -666,9 +741,9 @@ def _collect_rollouts_multi_turn( logprob = data["logprobs"][0] value = data["values"][0] reward_val = float(rewards_matrix[agent_idx][0]) - reward_tensor = torch.tensor( - [reward_val], device=self.device, dtype=torch.float32 - ) + reward_cpu = torch.tensor([reward_val], dtype=torch.float32) + value_cpu = value.detach().cpu() + logprob_cpu = logprob.detach().cpu() completion_text = data["completions"][0] sample = RolloutSample( @@ -679,11 +754,11 @@ def _collect_rollouts_multi_turn( attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value_cpu, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "char_length": data["char_lengths"][0], @@ -710,17 +785,15 @@ def _collect_rollouts_multi_turn( target = r + gamma * next_v else: target = r - sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() - sample.metadata["value_target"] = torch.tensor([target]).detach().cpu() + sample.metadata["adv_target"] = torch.tensor([target]).cpu() + sample.metadata["value_target"] = torch.tensor([target]).cpu() for agent_idx in range(self.args.num_agents): future = 0.0 for sample in reversed(per_agent_samples[agent_idx]): immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future - sample.returns = ( - torch.tensor([future], device=self.device).detach().cpu() - ) + sample.returns = torch.tensor([future], dtype=torch.float32) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None @@ -879,10 
+952,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa actor_losses: List[torch.Tensor] = [] value_losses: List[torch.Tensor] = [] + agent_device = self._agent_device(agent_idx) + critic_device = self._critic_device(agent_idx) for sample in batch: - sequences = sample.full_input_ids.to(self.device).unsqueeze(0) - attention_mask = sample.attention_mask.to(self.device).unsqueeze(0) + sequences = sample.full_input_ids.to(agent_device).unsqueeze(0) + attention_mask = sample.attention_mask.to(agent_device).unsqueeze(0) # Policy log-prob uses full sequence; value uses prompt-only baseline. logprob, _ = self._policy_eval( @@ -896,10 +971,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa if self.args.use_separate_critic: if critic_model is None: raise RuntimeError("Critic model not initialised.") + critic_sequences = sequences.to(critic_device) + critic_attention_mask = attention_mask.to(critic_device) value = self._critic_eval( critic_model, - sequences, - attention_mask, + critic_sequences, + critic_attention_mask, sample.prompt_len, sample.response_len, ) @@ -912,10 +989,13 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa [sample.response_len], ) - old_value = sample.old_value.to(self.device, dtype=value.dtype) - old_logprob = sample.old_logprob.to(self.device) - advantage = sample.normalized_advantage.to(self.device, dtype=value.dtype) - returns = sample.returns.to(self.device, dtype=value.dtype) + value_device = value.device + old_value = sample.old_value.to(value_device, dtype=value.dtype) + old_logprob = sample.old_logprob.to(agent_device) + policy_advantage = sample.normalized_advantage.to( + agent_device, dtype=logprob.dtype + ) + returns = sample.returns.to(value_device, dtype=value.dtype) if ( not torch.isfinite(logprob).all() @@ -924,17 +1004,17 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa raise FloatingPointError( "Encountered non-finite logprob during AC step." 
) - if not torch.isfinite(advantage).all(): + if not torch.isfinite(policy_advantage).all(): raise FloatingPointError("Advantage contains non-finite values.") if not torch.isfinite(returns).all(): raise FloatingPointError("Returns contain non-finite values.") - policy_loss = -(logprob * advantage) + policy_loss = -(logprob * policy_advantage) value_target = sample.metadata.get("value_target") if value_target is None: raise RuntimeError("value_target missing for critic update.") - value_target = value_target.to(self.device, dtype=value.dtype) + value_target = value_target.to(value_device, dtype=value.dtype) if ( self.args.value_clip_range is not None and not self.args.use_separate_critic @@ -1025,16 +1105,18 @@ def _update( return averaged def save_model(self, output_dir: str) -> None: + if not self.dist_env.is_main: + return os.makedirs(output_dir, exist_ok=True) if self.args.num_agents == 1: - actor = self.agents[0] + actor = unwrap_model(self.agents[0]) actor.model.save_pretrained(output_dir) if actor.value_head is not None: torch.save( actor.value_head.state_dict(), os.path.join(output_dir, "value_head.pt"), ) - critic = self.critics[0] if self.critics else None + critic = unwrap_model(self.critics[0]) if self.critics else None if critic is not None: critic_dir = os.path.join(output_dir, "critic") os.makedirs(critic_dir, exist_ok=True) @@ -1046,6 +1128,7 @@ def save_model(self, output_dir: str) -> None: ) else: for agent_idx, actor in enumerate(self.agents): + actor = unwrap_model(actor) agent_dir = os.path.join(output_dir, f"agent_{agent_idx}") os.makedirs(agent_dir, exist_ok=True) actor.model.save_pretrained(agent_dir) @@ -1056,7 +1139,7 @@ def save_model(self, output_dir: str) -> None: ) if not self.critics or agent_idx >= len(self.critics): continue - critic = self.critics[agent_idx] + critic = unwrap_model(self.critics[agent_idx]) critic_dir = os.path.join(agent_dir, "critic") os.makedirs(critic_dir, exist_ok=True) critic.model.save_pretrained(critic_dir) diff --git a/comlrl/trainers/actor_critic/maac.py b/comlrl/trainers/actor_critic/maac.py index 380c5b3..57b4099 100644 --- a/comlrl/trainers/actor_critic/maac.py +++ b/comlrl/trainers/actor_critic/maac.py @@ -11,6 +11,8 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase import wandb from comlrl.models.actor_critic import CausalLMWithValueHead +from comlrl.schedulers import DeviceScheduler +from comlrl.utils.distributed import local_context, unwrap_model from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import resolve_model_sources from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths @@ -41,6 +43,9 @@ class MAACConfig: num_agents: int = 2 num_generations: int = 1 num_turns: int = 2 + parallel_training: str = "none" + agent_devices: Optional[Union[str, Sequence[str]]] = None + critic_devices: Optional[Union[str, Sequence[str]]] = None external_prompt_passthrough: bool = False discount: float = 0.9 critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) @@ -76,6 +81,21 @@ def __post_init__(self) -> None: raise ValueError("eval_batch_size must be >= 1.") if self.logging_steps < 1: raise ValueError("logging_steps must be >= 1.") + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp": + if self.agent_devices is None: + raise ValueError( + "parallel_training='mp' requires 
explicit agent_devices." + ) + if self.critic_devices is None: + raise ValueError( + "parallel_training='mp' requires explicit critic_devices." + ) + self.parallel_training = mode class MAACTrainer(ActorCriticTrainerBase): @@ -134,8 +154,32 @@ def __init__( self.eval_dataset = eval_dataset self.metrics_callback = metrics_callback self.model_config = model_config or {} - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.parallel_training = ( + str(getattr(self.args, "parallel_training", "none")).strip().lower() + ) + if self.parallel_training == "null": + self.parallel_training = "none" + if self.parallel_training not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if self.parallel_training == "mp": + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.args.num_agents, + kind="agent_devices", + ) + self.critic_device = DeviceScheduler.assign_shared_critic_device( + self.agent_devices, getattr(self.args, "critic_devices", None) + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None), + getattr(self.args, "critic_devices", None), + ) + self.agent_devices = [single_device] * self.args.num_agents + self.critic_device = single_device + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) + self._parallel_update_enabled = False tokenizers = resolve_tokenizers(agent_model, tokenizer, agents) if isinstance(tokenizers, list): @@ -154,7 +198,7 @@ def __init__( expected_count=self.args.num_agents, model_label="agent_model", ) - for actor_source in actor_sources: + for idx, actor_source in enumerate(actor_sources): if actor_source is None: raise ValueError("agent_model must be provided for MAAC.") if isinstance(actor_source, CausalLMWithValueHead): @@ -183,7 +227,7 @@ def __init__( attach_value_head=False, value_head_hidden_dim=None, ) - agent_model.to(self.device) + agent_model.to(self.agent_devices[idx]) self.agents.append(agent_model) critic_sources, _critic_name = resolve_model_sources( @@ -225,7 +269,7 @@ def __init__( "critic_value_head_hidden_dim" ), ) - critic_model_instance.to(self.device) + critic_model_instance.to(self.critic_device) self.critics: List[CausalLMWithValueHead] = [critic_model_instance] if self.tokenizers and len(self.tokenizers) == len(self.agents): @@ -234,6 +278,7 @@ def __init__( apply_tokenizer_specials(self.tokenizers[0], [self.critics[0]]) else: apply_tokenizer_specials(self.tokenizer, [*self.agents, self.critics[0]]) + self.formatters = build_formatters(formatters, self.args.num_agents) try: self._reward_signature = inspect.signature(reward_func) @@ -258,7 +303,7 @@ def __init__( self.wandb_initialized = False self.env_step = 0 self._last_train_log_step = -1 - if wandb_config is not None: + if wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.verbose = True if isinstance(self.wandb_config, dict): @@ -268,7 +313,12 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _agent_device(self, agent_idx: int) -> torch.device: + return self.agent_devices[agent_idx] + def _init_wandb(self) -> None: + if not self.dist_env.is_main: + return if self.wandb_config is None: self.wandb_config = {} wandb_project = self.wandb_config.get("project", "comlrl") @@ -377,7 +427,11 @@ def _build_critic_input( return base + "\n\n" + "\n\n".join(action_lines) def _critic_value_from_text(self, critic_input: 
str) -> Dict[str, Any]: - encoded = self._encode_prompt(critic_input, tokenizer=self._get_tokenizer(0)) + encoded = self._encode_prompt( + critic_input, + tokenizer=self._get_tokenizer(0), + device=self.critic_device, + ) ids = encoded["input_ids"] mask = encoded["attention_mask"] prompt_len = ids.size(1) @@ -391,12 +445,19 @@ def _critic_value_from_text(self, critic_input: str) -> Dict[str, Any]: } def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: - encoded_prompt = self._encode_prompt(prompt, agent_idx=agent_idx) + agent_device = self._agent_device(agent_idx) + encoded_prompt = self._encode_prompt( + prompt, agent_idx=agent_idx, device=agent_device + ) prompt_input_ids = encoded_prompt["input_ids"] prompt_attention_mask = encoded_prompt["attention_mask"] prompt_len = prompt_input_ids.size(1) - num_ret = int(self.args.num_generations) + num_ret = ( + 1 + if bool(getattr(self, "_in_eval", False)) + else int(self.args.num_generations) + ) generation_kwargs: Dict[str, Any] = { "input_ids": prompt_input_ids, "attention_mask": prompt_attention_mask, @@ -410,7 +471,8 @@ def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: if self.args.top_k is not None: generation_kwargs["top_k"] = self.args.top_k - sequences = agent_model.generate(**generation_kwargs) + generation_model = unwrap_model(agent_model) + sequences = generation_model.generate(**generation_kwargs) if sequences.size(1) <= prompt_len: raise RuntimeError("Model produced an empty completion during rollout.") @@ -433,7 +495,7 @@ def _generate(self, agent_model, prompt: str, agent_idx: int) -> Dict[str, Any]: "prompt": prompt, "prompt_len": prompt_len, "sequences": sequences, - "attention_mask": torch.ones_like(sequences, device=self.device), + "attention_mask": torch.ones_like(sequences, device=agent_device), "response_lens": response_lens, "completions": completion_texts, } @@ -443,26 +505,31 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: if num_turns > 1: return self._collect_rollouts_multi_turn(item, num_turns) - prompts: List[str] = [] - completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] num_ret = int(self.args.num_generations) + turn_prompts = [ + self._resolve_turn_prompt(item, agent_idx) + for agent_idx in range(self.args.num_agents) + ] - for agent_idx, agent_model in enumerate(self.agents): - prompt = self._resolve_turn_prompt(item, agent_idx) + def _generate_agent(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] + prompt = turn_prompts[agent_idx] gen = self._generate(agent_model, prompt, agent_idx) - prompts.append(prompt) - completions_per_agent.append(gen["completions"]) - rollout_data.append( - { - "agent_idx": agent_idx, - "prompt": prompt, - "prompt_len": gen["prompt_len"], - "sequences": gen["sequences"], - "attention_mask": gen["attention_mask"], - "response_lens": gen["response_lens"], - } - ) + return { + "agent_idx": agent_idx, + "prompt": prompt, + "prompt_len": gen["prompt_len"], + "sequences": gen["sequences"], + "attention_mask": gen["attention_mask"], + "response_lens": gen["response_lens"], + "completion_texts": gen["completions"], + } + + rollout_data = self._run_agent_tasks(_generate_agent) + prompts: List[str] = [entry["prompt"] for entry in rollout_data] + completions_per_agent: List[List[str]] = [ + entry["completion_texts"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, @@ -516,7 +583,6 @@ def _collect_rollouts(self, item: 
Dict[str, Any]) -> List[RolloutSample]: attn = data["attention_mask"][i] resp_len = data["response_lens"][i] reward = float(rewards_matrix[agent_idx][i]) - reward_tensor = torch.tensor([reward], device=self.device) logprob, _ = self._policy_eval( self.agents[agent_idx], @@ -536,6 +602,8 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: joint_mask = critic_pack["attention_mask"] joint_len = int(critic_pack["prompt_len"]) value = critic_pack["value"].detach().cpu() + reward_cpu = torch.tensor([reward], dtype=torch.float32) + logprob_cpu = logprob.detach().cpu() rollouts.append( RolloutSample( agent_idx=agent_idx, @@ -548,25 +616,25 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "joint_input_ids": joint_ids.detach().cpu(), "joint_attention_mask": joint_mask.detach().cpu(), "joint_prompt_len": joint_len, "turn_idx": 0, - "adv_target": reward_tensor.detach().cpu(), + "adv_target": reward_cpu, }, ) ) for sample in rollouts: r = float(sample.reward.view(-1)[0].item()) - sample.metadata["value_target"] = torch.tensor([r]).detach().cpu() + sample.metadata["value_target"] = torch.tensor([r]).cpu() if self.metrics_callback is not None: extra = self.metrics_callback(rollouts) @@ -621,23 +689,27 @@ def _collect_rollouts_multi_turn( ] completions_per_agent: List[List[str]] = [] - rollout_data: List[Dict[str, Any]] = [] - for agent_idx, agent_model in enumerate(self.agents): + for agent_idx, prompt in enumerate(turn_prompts): + prompt_history[agent_idx].append(prompt) + + def _generate_agent_turn(agent_idx: int) -> Dict[str, Any]: + agent_model = self.agents[agent_idx] prompt = turn_prompts[agent_idx] gen = self._generate(agent_model, prompt, agent_idx) - completions_per_agent.append(gen["completions"]) - rollout_data.append( - { - "agent_idx": agent_idx, - "prompt": prompt, - "prompt_len": gen["prompt_len"], - "sequences": gen["sequences"], - "attention_mask": gen["attention_mask"], - "response_lens": gen["response_lens"], - "completion_texts": gen["completions"], - } - ) - prompt_history[agent_idx].append(prompt) + return { + "agent_idx": agent_idx, + "prompt": prompt, + "prompt_len": gen["prompt_len"], + "sequences": gen["sequences"], + "attention_mask": gen["attention_mask"], + "response_lens": gen["response_lens"], + "completion_texts": gen["completions"], + } + + rollout_data = self._run_agent_tasks(_generate_agent_turn) + completions_per_agent = [ + entry["completion_texts"] for entry in rollout_data + ] rewards = call_reward_function( self.reward_func, @@ -671,7 +743,7 @@ def _collect_rollouts_multi_turn( attn = data["attention_mask"][0] resp_len = data["response_lens"][0] reward_val = float(rewards_matrix[agent_idx][0]) - reward_tensor = torch.tensor([reward_val], device=self.device) + reward_cpu = torch.tensor([reward_val], dtype=torch.float32) logprob, _ = self._policy_eval( self.agents[agent_idx], @@ -683,6 +755,7 @@ def _collect_rollouts_multi_turn( ) value = joint_value.detach().cpu() + logprob_cpu = logprob.detach().cpu() completion_text = 
data["completion_texts"][0] sample = RolloutSample( agent_idx=agent_idx, @@ -692,11 +765,11 @@ def _collect_rollouts_multi_turn( attention_mask=attn.detach().cpu(), prompt_len=data["prompt_len"], response_len=resp_len, - old_logprob=logprob.detach().cpu(), - old_value=value.detach().cpu(), - reward=reward_tensor.detach().cpu(), - returns=reward_tensor.detach().cpu(), - advantage=torch.zeros_like(reward_tensor).detach().cpu(), + old_logprob=logprob_cpu, + old_value=value, + reward=reward_cpu, + returns=reward_cpu.clone(), + advantage=torch.zeros_like(reward_cpu), normalized_advantage=None, metadata={ "joint_input_ids": joint_ids.detach().cpu(), @@ -725,17 +798,15 @@ def _collect_rollouts_multi_turn( target = r + gamma * next_v else: target = r - sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() - sample.metadata["value_target"] = torch.tensor([target]).detach().cpu() + sample.metadata["adv_target"] = torch.tensor([target]).cpu() + sample.metadata["value_target"] = torch.tensor([target]).cpu() for agent_idx in range(self.args.num_agents): future = 0.0 for sample in reversed(per_agent_samples[agent_idx]): immediate = float(sample.reward.view(-1)[0].item()) future = immediate + gamma * future - sample.returns = ( - torch.tensor([future], device=self.device).detach().cpu() - ) + sample.returns = torch.tensor([future], dtype=torch.float32) sample.advantage = torch.zeros_like(sample.returns) sample.normalized_advantage = None @@ -835,10 +906,12 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa actor_losses: List[torch.Tensor] = [] value_losses: List[torch.Tensor] = [] + agent_device = self._agent_device(agent_idx) + critic_device = self.critic_device for sample in batch: - sequences = sample.full_input_ids.to(self.device).unsqueeze(0) - attention_mask = sample.attention_mask.to(self.device).unsqueeze(0) + sequences = sample.full_input_ids.to(agent_device).unsqueeze(0) + attention_mask = sample.attention_mask.to(agent_device).unsqueeze(0) logprob, _ = self._policy_eval( agent_model, @@ -849,30 +922,33 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa output_values=False, ) - joint_ids = sample.metadata["joint_input_ids"].to(self.device) - joint_mask = sample.metadata["joint_attention_mask"].to(self.device) + joint_ids = sample.metadata["joint_input_ids"].to(critic_device) + joint_mask = sample.metadata["joint_attention_mask"].to(critic_device) joint_len = int(sample.metadata["joint_prompt_len"]) value = self._value_on_prompt_only( self.critics[0], joint_ids, joint_mask, joint_len ) - old_value = sample.old_value.to(self.device, dtype=value.dtype) - advantage = sample.normalized_advantage.to(self.device, dtype=value.dtype) + value_device = value.device + old_value = sample.old_value.to(value_device, dtype=value.dtype) + policy_advantage = sample.normalized_advantage.to( + agent_device, dtype=logprob.dtype + ) value_target = sample.metadata.get("value_target") if value_target is None: raise RuntimeError("value_target missing for critic update.") - returns = value_target.to(self.device, dtype=value.dtype) + returns = value_target.to(value_device, dtype=value.dtype) if not torch.isfinite(logprob).all(): raise FloatingPointError( "Encountered non-finite logprob during AC step." 
) - if not torch.isfinite(advantage).all(): + if not torch.isfinite(policy_advantage).all(): raise FloatingPointError("Advantage contains non-finite values.") if not torch.isfinite(returns).all(): raise FloatingPointError("Returns contain non-finite values.") - policy_loss = -(logprob * advantage) + policy_loss = -(logprob * policy_advantage) value_error = (returns - value) ** 2 actor_losses.append(policy_loss) @@ -948,22 +1024,26 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: self._log_metrics(epoch_log) summary = self._summarize_epoch_metrics(epoch_metrics) - if summary and getattr(self, "verbose", True): + if summary and getattr(self, "verbose", True) and self.dist_env.is_main: to_print = epoch_log if epoch_log else summary print(f"Epoch {epoch + 1}/{total_epochs} metrics: {to_print}") def save_model(self, output_dir: str) -> None: + if not self.dist_env.is_main: + return os.makedirs(output_dir, exist_ok=True) for agent_idx, actor in enumerate(self.agents): + actor = unwrap_model(actor) agent_dir = os.path.join(output_dir, f"agent_{agent_idx}") os.makedirs(agent_dir, exist_ok=True) actor.model.save_pretrained(agent_dir) critic_dir = os.path.join(output_dir, "critic") os.makedirs(critic_dir, exist_ok=True) - self.critics[0].model.save_pretrained(critic_dir) - if self.critics[0].value_head is not None: + critic = unwrap_model(self.critics[0]) + critic.model.save_pretrained(critic_dir) + if critic.value_head is not None: torch.save( - self.critics[0].value_head.state_dict(), + critic.value_head.state_dict(), os.path.join(critic_dir, "value_head.pt"), ) if self.tokenizers: diff --git a/comlrl/trainers/reinforce/magrpo.py b/comlrl/trainers/reinforce/magrpo.py index d6289b4..e27f200 100644 --- a/comlrl/trainers/reinforce/magrpo.py +++ b/comlrl/trainers/reinforce/magrpo.py @@ -3,6 +3,7 @@ import random from dataclasses import dataclass import itertools +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Type, Sequence import numpy as np @@ -13,6 +14,11 @@ from tqdm import tqdm # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizerBase +from comlrl.schedulers import DeviceScheduler +from comlrl.utils.distributed import ( + local_context, + unwrap_model, +) from comlrl.utils.formatters import build_formatters from comlrl.utils.model_loading import infer_model_name, resolve_model_sources from comlrl.utils.reward_utils import call_reward_function @@ -31,6 +37,8 @@ class MAGRPOConfig: agent_learning_rate: float = 5.0e-6 logging_steps: int = 50 num_agents: int = 2 + parallel_training: str = "none" + agent_devices: Optional[Union[str, Sequence[str]]] = None # Sampling/generation num_generations: int = 4 @@ -79,6 +87,14 @@ def __post_init__(self) -> None: self.train_batch_size = self.rollout_buffer_size if self.train_batch_size < 1: raise ValueError("train_batch_size must be >= 1.") + mode = str(self.parallel_training or "none").strip().lower() + if mode == "null": + mode = "none" + if mode not in {"none", "mp"}: + raise ValueError("parallel_training only supports: none, mp.") + if mode == "mp" and self.agent_devices is None: + raise ValueError("parallel_training='mp' requires explicit agent_devices.") + self.parallel_training = mode @dataclass @@ -144,9 +160,17 @@ def __init__( eval_aggregator: Optional[Callable] = None, args: Optional[MAGRPOConfig] = None, ): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.args = args if args is not None else 
self.default_config_cls() +        self.args = args if args is not None else self.default_config_cls() +        self.parallel_training = ( +            str(getattr(self.args, "parallel_training", "none")).strip().lower() +        ) +        if self.parallel_training == "null": +            self.parallel_training = "none" +        if self.parallel_training not in {"none", "mp"}: +            raise ValueError("parallel_training only supports: none, mp.") +        self.dist_env = local_context() +        self.device = self.dist_env.device +         if agent_model is None and agents is None: raise ValueError("Either agent_model or agents must be provided.") if ( @@ -191,6 +215,19 @@ def __init__( self.num_agents = expected_count self.model_name = model_name + if self.parallel_training == "mp": + self.agent_devices = DeviceScheduler.resolve_devices( + getattr(self.args, "agent_devices", None), + self.num_agents, + kind="agent_devices", + ) + else: + single_device = DeviceScheduler.resolve_single_device( + getattr(self.args, "agent_devices", None) + ) + self.agent_devices = [single_device] * self.num_agents + self.device = self.agent_devices[0] + self.dist_env = local_context(self.device) if actor_sources and all(isinstance(src, str) for src in actor_sources): from transformers import AutoModelForCausalLM @@ -210,6 +247,9 @@ def __init__( self.agents = list(actor_sources) self.critics = [] + for idx, agent in enumerate(self.agents): + agent.to(self.agent_devices[idx]) + tokenizers = resolve_tokenizers(agent_model, tokenizer, actor_sources) if isinstance(tokenizers, list): self.tokenizers = tokenizers @@ -259,7 +299,7 @@ def __init__( self.wandb_config = wandb_config self.wandb_initialized = False - if self.wandb_config is not None: + if self.wandb_config is not None and self.dist_env.is_main: self._init_wandb() self.dataset_type = dataset_type or None @@ -283,8 +323,52 @@ def __init__( if isinstance(out, dict) and "verbose" in out: self.verbose = bool(out.get("verbose")) + def _parallel_agent_mode_enabled(self) -> bool: + if str(getattr(self, "parallel_training", "")).lower() != "mp": + return False + if int(getattr(self, "num_agents", 0) or 0) <= 1: + return False + devices = getattr(self, "agent_devices", None) + if not devices: + return True + unique = {str(device) for device in devices} + return len(unique) > 1 + + def _run_agent_tasks( + self, + fn, + *, + agent_indices: Optional[List[int]] = None, + parallel: Optional[bool] = None, + ) -> List[Any]: + indices = ( + list(agent_indices) + if agent_indices is not None + else list(range(self.num_agents)) + ) + if not indices: + return [] + use_parallel = ( + self._parallel_agent_mode_enabled() if parallel is None else bool(parallel) + ) + if not use_parallel or len(indices) <= 1: + return [fn(agent_idx) for agent_idx in indices] + + results: Dict[int, Any] = {} + max_workers = len(indices) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(fn, agent_idx): agent_idx for agent_idx in indices + } + for future in as_completed(futures): + agent_idx = futures[future] + results[agent_idx] = future.result() + return [results[agent_idx] for agent_idx in indices] + def _init_wandb(self): """Initialize Weights & Biases for tracking with multi-turn config.""" + if not self.dist_env.is_main: + return if not self.wandb_initialized: if self.wandb_config is None: self.wandb_config = {} @@ -422,18 +506,21 @@ def get_eval_dataloader(self) -> Optional[DataLoader]: num_workers=0, ) - def evaluate(self, num_eval_samples: int = 4) -> Dict[str, float]: + def evaluate(self, num_eval_samples: Optional[int] = None) -> Dict[str, float]: """ Unified
evaluation that supports both single-turn and multi-turn. Args: - num_eval_samples: Number of samples to evaluate + num_eval_samples: Number of samples to evaluate. Defaults to args.eval_num_samples. Returns: Dictionary containing evaluation metrics """ if self.eval_dataset is None: return {} + if num_eval_samples is None: + num_eval_samples = int(getattr(self.args, "eval_num_samples", 4)) + num_eval_samples = int(num_eval_samples) # Storage for completions across turns for all agents all_agent_completions_turns = [[] for _ in range(self.num_agents)] @@ -561,8 +648,10 @@ def _evaluate_sample( "External transition must return a list or tuple of external prompts for each agent" ) - # Generate and extract one completion from each agent for evaluation - for agent_idx in range(self.num_agents): + # Generate and extract one completion from each agent for evaluation. + # In MP mode this executes generation concurrently, while keeping + # synchronous consumption order by agent index. + def _generate_eval_agent(agent_idx: int) -> Dict[str, Any]: agent_completions = self._generate_completions_with_external_prompts( self.agents[agent_idx], [batch_item], @@ -571,12 +660,17 @@ def _evaluate_sample( max_new_tokens=self.args.max_new_tokens, external_prompts=agent_external_prompts[agent_idx], ) - # Extract the completion directly - completion = agent_completions["completions"][0][0] - # Record prompt used this turn - used_prompt = agent_completions["prompts"][0] - eval_prompt_history[agent_idx].append(used_prompt) - agent_sample_completions[agent_idx].append(completion) + return { + "agent_idx": agent_idx, + "completion": agent_completions["completions"][0][0], + "used_prompt": agent_completions["prompts"][0], + } + + eval_outputs = self._run_agent_tasks(_generate_eval_agent) + for output in eval_outputs: + agent_idx = int(output["agent_idx"]) + eval_prompt_history[agent_idx].append(output["used_prompt"]) + agent_sample_completions[agent_idx].append(output["completion"]) # Compute immediate reward at this turn (single joint sample) agent_completions_for_reward = [ @@ -652,9 +746,8 @@ def train(self, **kwargs): if self.wandb_config is not None and not self.wandb_initialized: self._init_wandb() - device = self.device - for agent in self.agents: - agent.to(device) + for agent_idx, agent in enumerate(self.agents): + agent.to(self.agent_devices[agent_idx]) agent.train() # Create the data pipeline for generating examples @@ -682,7 +775,6 @@ def train(self, **kwargs): if int(self.args.eval_interval) > 0 and ( batch_idx % int(self.args.eval_interval) == 0 ): - # evaluate() already logs its metrics; avoid duplicate logging here _ = self.evaluate(num_eval_samples=int(self.args.eval_num_samples)) # Process single batch item (batch_size=1 enforced) @@ -695,25 +787,28 @@ def train(self, **kwargs): **kwargs, ) - for agent_idx, buffer in enumerate(self.rollout_buffers): - if buffer: - self._process_buffer(agent_idx, buffer) + ready_agents = [ + agent_idx + for agent_idx, buffer in enumerate(self.rollout_buffers) + if buffer + ] + self._drain_ready_buffers(ready_agents) # Log per-turn epoch averages inline (avoid custom system/* metrics) - if self.wandb_initialized and wandb.run is not None: - epoch_log: Dict[str, Any] = {} - n_turns = max(1, int(self.args.num_turns)) - for turn_idx in range(n_turns): - if epoch_turn_rewards and epoch_turn_rewards[turn_idx]: - epoch_log[f"turn_{turn_idx + 1}/epoch_reward_mean"] = float( - np.mean(epoch_turn_rewards[turn_idx]) - ) - if epoch_turn_returns and epoch_turn_returns[turn_idx]: 
- epoch_log[f"turn_{turn_idx + 1}/epoch_avg_return"] = float( - np.mean(epoch_turn_returns[turn_idx]) - ) - if epoch_log: - wandb.log(epoch_log, step=self.env_step) + epoch_log: Dict[str, Any] = {} + n_turns = max(1, int(self.args.num_turns)) + for turn_idx in range(n_turns): + if epoch_turn_rewards and epoch_turn_rewards[turn_idx]: + epoch_log[f"turn_{turn_idx + 1}/epoch_reward_mean"] = float( + np.mean(epoch_turn_rewards[turn_idx]) + ) + if epoch_turn_returns and epoch_turn_returns[turn_idx]: + epoch_log[f"turn_{turn_idx + 1}/epoch_avg_return"] = float( + np.mean(epoch_turn_returns[turn_idx]) + ) + + if epoch_log and self.wandb_initialized and wandb.run is not None: + wandb.log(epoch_log, step=self.env_step) def _train_step_returns( self, @@ -746,9 +841,8 @@ def build_node( prompt_history_per_agent: Optional[List[List[str]]] = None, response_history_per_agent: Optional[List[List[str]]] = None, ): - comps_per_agent = [] - for agent_idx in range(self.num_agents): - comps = self._generate_completions_with_external_prompts( + def _generate_agent_node(agent_idx: int) -> Dict[str, Any]: + return self._generate_completions_with_external_prompts( self.agents[agent_idx], [batch_item], agent_idx=agent_idx, @@ -759,7 +853,8 @@ def build_node( ), **kwargs, ) - comps_per_agent.append(comps) + + comps_per_agent = self._run_agent_tasks(_generate_agent_node) agent_completions_list = [ comps_per_agent[i]["completions"][0] for i in range(self.num_agents) @@ -995,10 +1090,22 @@ def post_order_update(node): post_order_update(root) + grouped_pending: Dict[int, List[Tuple[int, NodeSample]]] = {} for agent_idx, samples in enumerate(pending_samples): samples.sort(key=lambda s: s.node_env_step) for sample in samples: - self._append_to_buffer(agent_idx, sample) + step = int(sample.node_env_step) + grouped_pending.setdefault(step, []).append((agent_idx, sample)) + + for step in sorted(grouped_pending.keys()): + ready_agents: List[int] = [] + step_samples = sorted(grouped_pending[step], key=lambda x: x[0]) + for agent_idx, sample in step_samples: + buffer = self.rollout_buffers[agent_idx] + buffer.append(sample) + if len(buffer) >= int(self.args.rollout_buffer_size): + ready_agents.append(agent_idx) + self._drain_ready_buffers(ready_agents) # Build per-turn batch summary batch_loss = float(np.mean(np.abs(root.get("returns") or [0.0]))) @@ -1041,7 +1148,8 @@ def _generate_completions( Returns: Dict: A dictionary containing generated completions and associated data """ - device = agent.device + agent_module = unwrap_model(agent) + device = next(agent_module.parameters()).device # Apply the appropriate formatter to create prompts from batch items if prompts_override is not None: @@ -1070,7 +1178,7 @@ def _generate_completions( if self.tokenizer is None: raise ValueError("Tokenizer is required for generating completions") tokenizer = self.tokenizers[agent_idx] - apply_tokenizer_specials(tokenizer, [agent]) + apply_tokenizer_specials(tokenizer, [agent_module]) pad_id = tokenizer.pad_token_id prompt_encodings = tokenizer( @@ -1114,7 +1222,7 @@ def _generate_completions( kwargs.pop("do_sample", None) generation_kwargs.update(kwargs) - generation_output = agent.generate(**generation_kwargs) + generation_output = agent_module.generate(**generation_kwargs) except Exception as e: raise ValueError(f"Generation failed: {str(e)}") @@ -1302,23 +1410,17 @@ def _compute_rewards( def _pack_completions_for_buffer( self, completions_data: Dict[str, Any] ) -> Dict[str, Any]: - prompt_ids = 
completions_data["prompt_input_ids"].detach().cpu() + prompt_ids = completions_data["prompt_input_ids"].cpu() completion_ids = completions_data["completion_input_ids"] if completion_ids and isinstance(completion_ids[0], list): - packed_completion_ids = [[t.detach().cpu() for t in completion_ids[0]]] + packed_completion_ids = [[t.cpu() for t in completion_ids[0]]] else: - packed_completion_ids = [[t.detach().cpu() for t in completion_ids]] + packed_completion_ids = [[t.cpu() for t in completion_ids]] return { "prompt_input_ids": prompt_ids, "completion_input_ids": packed_completion_ids, } - def _append_to_buffer(self, agent_idx: int, sample: NodeSample) -> None: - buffer = self.rollout_buffers[agent_idx] - buffer.append(sample) - if len(buffer) >= int(self.args.rollout_buffer_size): - self._process_buffer(agent_idx, buffer) - def _should_log_train(self, step: int) -> bool: interval = int(getattr(self.args, "logging_steps", 1)) if interval <= 1: @@ -1332,18 +1434,22 @@ def _should_log_train(self, step: int) -> bool: return True return False - def _process_buffer(self, agent_idx: int, buffer: List[NodeSample]) -> None: + def _process_buffer( + self, agent_idx: int, buffer: List[NodeSample] + ) -> Dict[str, Any]: if not buffer: - return + return {"log_entries": []} turn_groups: Dict[int, List[NodeSample]] = {} for sample in buffer: t_idx = int(sample.turn_idx) turn_groups.setdefault(t_idx, []).append(sample) buffer.clear() + + log_entries: List[Dict[str, Any]] = [] for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] self._update_from_samples(agent_idx, samples) - if self.wandb_initialized and wandb.run is not None and samples: + if samples: batch_log: Dict[str, Any] = {} prefix = f"turn_{t_idx + 1}/" batch_log[prefix + "reward_mean"] = float( @@ -1353,8 +1459,48 @@ def _process_buffer(self, agent_idx: int, buffer: List[NodeSample]) -> None: np.mean([s.node_mean_return for s in samples]) ) step = max(s.node_env_step for s in samples) - if self._should_log_train(step): - wandb.log(batch_log, step=step) + log_entries.append( + { + "agent_idx": int(agent_idx), + "step": int(step), + "metrics": batch_log, + } + ) + return {"log_entries": log_entries} + + def _drain_ready_buffers(self, ready_agents: List[int]) -> None: + if not ready_agents: + return + unique_ready = sorted({int(idx) for idx in ready_agents}) + run_parallel = self._parallel_agent_mode_enabled() + + def _process(agent_idx: int) -> Dict[str, Any]: + return self._process_buffer(agent_idx, self.rollout_buffers[agent_idx]) + + results = self._run_agent_tasks( + _process, + agent_indices=unique_ready, + parallel=run_parallel, + ) + + if not (self.wandb_initialized and wandb.run is not None): + return + + all_log_entries: List[Dict[str, Any]] = [] + for result in results: + all_log_entries.extend(result.get("log_entries", [])) + + all_log_entries.sort( + key=lambda entry: ( + int(entry.get("step", 0)), + int(entry.get("agent_idx", 0)), + ) + ) + for entry in all_log_entries: + step = int(entry.get("step", self.env_step)) + metrics = entry.get("metrics") or {} + if metrics and self._should_log_train(step): + wandb.log(metrics, step=step) def _update_from_samples(self, agent_idx: int, samples: List[NodeSample]) -> None: if not samples: @@ -1404,7 +1550,8 @@ def _compute_loss_with_gradients(self, agent, completions_data, returns): Returns: torch.Tensor: The computed loss with gradients attached """ - device = agent.device + agent_module = unwrap_model(agent) + device = next(agent_module.parameters()).device # Make sure we 
have the correct number of rewards if len(returns) == 0: @@ -1502,6 +1649,7 @@ def save_model(self, output_dir): os.makedirs(output_dir, exist_ok=True) for agent_idx, agent in enumerate(self.agents): + agent = unwrap_model(agent) agent_dir = f"{output_dir}/agent_{agent_idx}" os.makedirs(agent_dir, exist_ok=True) diff --git a/comlrl/utils/__init__.py b/comlrl/utils/__init__.py index dcc9bf6..bca4c0d 100644 --- a/comlrl/utils/__init__.py +++ b/comlrl/utils/__init__.py @@ -2,6 +2,14 @@ from .model_loading import infer_model_name, resolve_model_sources from .reward_processor import RewardProcessors from .reward_utils import call_reward_function, normalize_reward_lengths +from .distributed import ( + DistributedContext, + all_gather_objects, + barrier, + is_main_process, + local_context, + unwrap_model, +) from .tokenizer_utils import ( apply_tokenizer_specials, ensure_pad_token, @@ -22,4 +30,10 @@ "RewardProcessors", "call_reward_function", "normalize_reward_lengths", + "DistributedContext", + "local_context", + "unwrap_model", + "is_main_process", + "barrier", + "all_gather_objects", ] diff --git a/comlrl/utils/distributed.py b/comlrl/utils/distributed.py new file mode 100644 index 0000000..b23143e --- /dev/null +++ b/comlrl/utils/distributed.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch + + +@dataclass(frozen=True) +class DistributedContext: + enabled: bool + is_main: bool + device: torch.device + + +def local_context(device: Optional[torch.device] = None) -> DistributedContext: + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return DistributedContext( + enabled=False, + is_main=True, + device=device, + ) + + +def unwrap_model(model: Any) -> Any: + return getattr(model, "module", model) + + +def is_main_process(ctx: Optional[DistributedContext]) -> bool: + if ctx is None: + return True + return bool(ctx.is_main) + + +def barrier(ctx: Optional[DistributedContext]) -> None: # noqa: ARG001 + return + + +def all_gather_objects( + obj: Any, ctx: Optional[DistributedContext] +) -> List[Any]: # noqa: ARG001 + return [obj] + + +def reduce_metrics_dict( + metrics: Dict[str, float], + ctx: Optional[DistributedContext], # noqa: ARG001 +) -> Dict[str, float]: + return dict(metrics) diff --git a/docs/content/docs/dev/changelog.md b/docs/content/docs/dev/changelog.md index 91aa575..80c44f7 100644 --- a/docs/content/docs/dev/changelog.md +++ b/docs/content/docs/dev/changelog.md @@ -5,10 +5,15 @@ weight: 3 --- -## Version 1.3.6 +## Version 1.3.7 -- Fixed critical bug of loading heterogeneous models and reform the model loading logics -- Polish the docs +- Remove the redundant sampling hyperparameters from the algorithms. +- Allow multi-GPU training with MP. + +## Version 1.3.6 + +- Fixed a critical bug in loading heterogeneous models and reformed the model loading logic. +- Polish the docs. ## Version 1.3.5 @@ -33,13 +38,13 @@ weight: 3 ## Version 1.3.2 -- Fix wandb logging issue in MAGRPOTrainer +- Fix wandb logging issue in MAGRPOTrainer. ## Version 1.3.1 -- Allow batch training in MAGRPOTrainer, IACTrainer and MAACTrainer -- Allow multi-turn training in IACTrainer and MAACTrainer -- Change the x-axis from data_step to env_step +- Allow batch training in MAGRPOTrainer, IACTrainer and MAACTrainer. +- Allow multi-turn training in IACTrainer and MAACTrainer. +- Change the x-axis from data_step to env_step.
## Version 1.3.0 diff --git a/docs/content/docs/env/code-completion.md b/docs/content/docs/env/code-completion.md index 0f99d09..2a8a3e7 100644 --- a/docs/content/docs/env/code-completion.md +++ b/docs/content/docs/env/code-completion.md @@ -5,7 +5,4 @@ bookHref: https://github.com/OpenMLRL/LLM_Collab_Code_Completion bookHidden: true --- -Multi-agent autocompletion tasks where each model fills in part of a codebase. -The [LLM_Collab_Code_Completion](https://github.com/OpenMLRL/LLM_Collab_Code_Completion) -project currently focuses on **ClassEval**, which asks teams of LLMs to finish -class skeletons based on docstrings and partially implemented methods. + diff --git a/docs/content/docs/env/coding.md b/docs/content/docs/env/coding.md index 8a139a5..341d6bc 100644 --- a/docs/content/docs/env/coding.md +++ b/docs/content/docs/env/coding.md @@ -4,11 +4,4 @@ weight: 2 bookHref: https://github.com/OpenMLRL/LLM_Collab_Code_Generation --- -A suite of cooperative programming benchmarks where agents propose, critique, and -refine solutions. The environments shipped in -[LLM_Collab_Code_Generation](https://github.com/OpenMLRL/LLM_Collab_Code_Generation) -cover: - -- **MBPP** – mostly basic Python problems for rapid iteration. -- **HumanEval** – handwritten tasks from OpenAI for exact-match grading. -- **CoopHumanEval** – HumanEval variants that explicitly require collaboration. + diff --git a/docs/content/docs/env/minecraft.md b/docs/content/docs/env/minecraft.md index b0a468b..a523014 100644 --- a/docs/content/docs/env/minecraft.md +++ b/docs/content/docs/env/minecraft.md @@ -4,9 +4,4 @@ weight: 4 bookHref: https://github.com/OpenMLRL/LLM_Collab_Minecraft --- -Multi-agent building environments in Minecraft. The -[LLM_Collab_Minecraft](https://github.com/OpenMLRL/LLM_Collab_Minecraft) -repository includes: - -- **[StrBuild](https://github.com/OpenMLRL/LLM_Collab_Minecraft/tree/main/str_build)** – agents collaboratively build structured designs from text. -- **[HouseBuild](https://github.com/OpenMLRL/LLM_Collab_Minecraft/tree/main/house_build)** – agents coordinate materials and steps to construct houses. + diff --git a/docs/content/docs/env/writing.md b/docs/content/docs/env/writing.md index 303fabd..9689122 100644 --- a/docs/content/docs/env/writing.md +++ b/docs/content/docs/env/writing.md @@ -4,10 +4,4 @@ weight: 1 bookHref: https://github.com/OpenMLRL/LLM_Collab_Writing --- -Collaborative summarization and expansion tasks for pairs (or teams) of LLMs. -The reference implementation lives in the -[LLM_Collab_Writing](https://github.com/OpenMLRL/LLM_Collab_Writing) repository -and includes: - -- **TLDR** – distills Reddit threads into concise summaries. -- **ArXiv Introductions** – grows short abstracts into multi-paragraph drafts. 
+ diff --git a/docs/content/docs/user-guide/model-loading.md b/docs/content/docs/user-guide/model-loading.md index a290b33..de887ba 100644 --- a/docs/content/docs/user-guide/model-loading.md +++ b/docs/content/docs/user-guide/model-loading.md @@ -18,7 +18,7 @@ For example, to load 3 _Qwen/Qwen2.5-1.5B_ agents: trainer = MAGRPOTrainer( agent_model="Qwen/Qwen2.5-1.5B", agents=None, - num_agents=3, + args=MAGRPOConfig(num_agents=3, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -34,7 +34,7 @@ For example, to load a _Qwen/Qwen2.5-Coder-3B_ and a _Qwen/Qwen2.5-Coder-7B_: trainer = MAGRPOTrainer( agent_model=None, agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], - num_agents=2, + args=MAGRPOConfig(num_agents=2, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -51,7 +51,7 @@ trainer = MAACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-1.5B"], critic_model="Qwen/Qwen2.5-Coder-7B", critics=None, - num_agents=2, + args=MAACConfig(num_agents=2, temperature=0.7, top_p=0.9, top_k=None), ) ``` @@ -67,7 +67,13 @@ trainer = IACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], critic_model=None, critics=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], - num_agents=2, + args=IACConfig( + num_agents=2, + use_separate_critic=True, + temperature=0.7, + top_p=0.9, + top_k=None, + ), ) ``` @@ -79,7 +85,13 @@ trainer = IACTrainer( agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-7B"], critic_model=None, critics=None, - num_agents=2, + args=IACConfig( + num_agents=2, + use_separate_critic=False, + temperature=0.7, + top_p=0.9, + top_k=None, + ), ) ``` diff --git a/docs/content/docs/user-guide/multi-turn-training.md b/docs/content/docs/user-guide/multi-turn-training.md index b53c0fe..73520f8 100644 --- a/docs/content/docs/user-guide/multi-turn-training.md +++ b/docs/content/docs/user-guide/multi-turn-training.md @@ -1,6 +1,6 @@ --- -title: Multi-Turn Training -linkTitle: Multi-Turn Training +title: Multi-Turn Interaction +linkTitle: Multi-Turn Interaction weight: 5 math: true --- diff --git a/docs/content/docs/user-guide/training-parallelization.md b/docs/content/docs/user-guide/training-parallelization.md new file mode 100644 index 0000000..b52baa3 --- /dev/null +++ b/docs/content/docs/user-guide/training-parallelization.md @@ -0,0 +1,38 @@ +--- +title: Training Parallelization +linkTitle: Training Parallelization +weight: 6 +--- + +CoMLRL supports fine-tuning multi-LLM systems with larger models and more agents when multiple GPUs are available. +Users can configure parallel training with `iac.parallel_training`. +Currently, `parallel_training` supports two modes: `none` (or `null`), the default single-device mode, and `mp`, which schedules model-parallel training across explicitly assigned agent/critic devices. + +{{% hint success %}} +We will support more parallelization modes (e.g., [data parallelization](https://docs.pytorch.org/docs/stable/elastic/run.html), [multi-node training](https://ray.io)) in the future. +{{% /hint %}} + +## Model Parallelization + +When `parallel_training=mp`, CoMLRL requires explicit `agent_devices` / `critic_devices` configuration and deploys the agents and critics accordingly. +Training and inference for each model (agent or critic) run separately on its assigned device. +The responses are aggregated on the CPU and passed to the reward function. The rewards are then broadcast back to all devices for training.
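+For example, an IAC trainer can pin two agents and two separate critics to four GPUs through the config alone, as in the sketch below. This is a minimal, illustrative snippet: the Qwen model names are placeholders, and the tokenizer, reward function, and datasets that a real run requires are omitted.
+
+```python
+trainer = IACTrainer(
+    agents=["Qwen/Qwen2.5-Coder-3B", "Qwen/Qwen2.5-Coder-3B"],
+    critics=["Qwen/Qwen2.5-Coder-7B", "Qwen/Qwen2.5-Coder-7B"],
+    args=IACConfig(
+        num_agents=2,
+        use_separate_critic=True,
+        parallel_training="mp",
+        agent_devices=["cuda:0", "cuda:1"],  # one GPU per agent
+        critic_devices=["cuda:2", "cuda:3"],  # one GPU per critic
+    ),
+)
+```
+
+Passing a single device string (e.g., `agent_devices="cuda:0"`) places all models of that kind on the same device.
+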
+MP makes it possible to train more and larger models than a single GPU can hold, but training throughput is limited by the slowest model. + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python train_iac.py \ + --config configs/iac_xxx.yaml \ + --override \ + agent_model="model_a" \ + agents=None \ + critic_model="model_b" \ + critics=None \ + iac.use_separate_critic=true \ + iac.parallel_training=mp \ + iac.agent_devices='["cuda:0","cuda:1"]' \ + iac.critic_devices='["cuda:2","cuda:3"]' +``` + +{{% hint note %}} +Note that training is not deterministic when device assignments change, due to non-deterministic GPU scheduling and CPU-side aggregation. +{{% /hint %}} diff --git a/tests/test_config_constraints.py b/tests/test_config_constraints.py index 8d0ee88..c87ee5f 100644 --- a/tests/test_config_constraints.py +++ b/tests/test_config_constraints.py @@ -31,10 +31,20 @@ def test_iac_config_constraints(): _assert_invalid_fields(IACConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(IACConfig, "num_generations", 0) _assert_invalid(IACConfig, "critic_type", "x") + _assert_invalid(IACConfig, "parallel_training", "invalid") + _assert_invalid(IACConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + IACConfig(parallel_training="mp", critic_devices="cpu") + with pytest.raises(ValueError, match="critic_devices"): + IACConfig(parallel_training="mp", agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): IACConfig(num_turns=2, num_generations=2) IACConfig() + IACConfig(parallel_training="none") + IACConfig(parallel_training="null") + IACConfig(parallel_training="mp", agent_devices="cpu", critic_devices="cpu") + IACConfig(agent_devices="cpu", critic_devices="cpu") IACConfig(num_turns=2, num_generations=1) IACConfig(critic_type="q") @@ -55,10 +65,20 @@ ) _assert_invalid_fields(MAACConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAACConfig, "critic_type", "x") + _assert_invalid(MAACConfig, "parallel_training", "invalid") + _assert_invalid(MAACConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + MAACConfig(parallel_training="mp", critic_devices="cpu") + with pytest.raises(ValueError, match="critic_devices"): + MAACConfig(parallel_training="mp", agent_devices="cpu") with pytest.raises(ValueError, match="num_generations"): MAACConfig(num_turns=2, num_generations=2) MAACConfig() + MAACConfig(parallel_training="none") + MAACConfig(parallel_training="null") + MAACConfig(parallel_training="mp", agent_devices="cpu", critic_devices="cpu") + MAACConfig(agent_devices="cpu", critic_devices="cpu") MAACConfig(num_turns=2, num_generations=1) MAACConfig(critic_type="q") @@ -79,6 +99,14 @@ ) _assert_invalid_fields(MAGRPOConfig, ["eval_interval", "eval_num_samples"], -1) _assert_invalid(MAGRPOConfig, "num_generations", 1) + _assert_invalid(MAGRPOConfig, "parallel_training", "invalid") + _assert_invalid(MAGRPOConfig, "parallel_training", "auto") + with pytest.raises(ValueError, match="agent_devices"): + MAGRPOConfig(parallel_training="mp") MAGRPOConfig() + MAGRPOConfig(parallel_training="none") + MAGRPOConfig(parallel_training="null") + MAGRPOConfig(parallel_training="mp", agent_devices="cpu") + MAGRPOConfig(agent_devices="cpu") MAGRPOConfig(num_generations=2) diff --git a/tests/test_distributed_metrics.py b/tests/test_distributed_metrics.py new file mode 100644 index 0000000..a51d551 --- /dev/null +++ b/tests/test_distributed_metrics.py @@
-0,0 +1,91 @@ +from types import SimpleNamespace + +import torch + +from comlrl.trainers.actor_critic.ac_base import ActorCriticTrainerBase +from comlrl.utils.distributed import ( + DistributedContext, + all_gather_objects, + barrier, + local_context, + reduce_metrics_dict, +) + + +def _ctx() -> DistributedContext: + return DistributedContext( + enabled=False, + is_main=True, + device=torch.device("cpu"), + ) + + +def test_local_context_defaults_to_single_process(): + ctx = local_context(torch.device("cpu")) + assert ctx.enabled is False + assert ctx.is_main is True + + +def test_reduce_metrics_dict_passthrough(): + metrics = {"loss": 1.5, "reward": 2.5} + reduced = reduce_metrics_dict(metrics, _ctx()) + assert reduced == metrics + assert reduced is not metrics + + +def test_barrier_and_all_gather_are_noop_in_single_process(): + ctx = _ctx() + barrier(ctx) + assert all_gather_objects({"a": 1}, ctx) == [{"a": 1}] + + +def test_ac_base_log_metrics_logs_directly(monkeypatch): + trainer = ActorCriticTrainerBase() + trainer.dist_env = _ctx() + trainer.wandb_initialized = True + trainer.env_step = 5 + trainer.args = SimpleNamespace(logging_steps=1) + trainer._last_train_log_step = -1 + + called = {"log": []} + + def _fake_log(metrics, step): # noqa: ARG001 + called["log"].append(dict(metrics)) + + monkeypatch.setattr("wandb.log", _fake_log) + + trainer._log_metrics({"loss": 1.0}) + trainer._log_metrics({"reward": 2.0}) + assert called["log"] == [{"loss": 1.0}, {"reward": 2.0}] + + +def test_evaluate_calls_collect_rollouts_with_eval_flag(): + class _EvalDummyTrainer(ActorCriticTrainerBase): + def __init__(self): + self.eval_dataset = [{"id": 0}, {"id": 1}] + self.args = SimpleNamespace(eval_batch_size=1, eval_num_samples=1) + self.wandb_initialized = False + self.env_step = 0 + self.dist_env = _ctx() + self.verbose = False + self._collect_calls = 0 + self.eval_flags = [] + + def _collect_rollouts(self, item): # noqa: ARG002 + self._collect_calls += 1 + self.eval_flags.append(bool(getattr(self, "_in_eval", False))) + return [ + SimpleNamespace( + metadata={}, + reward=torch.tensor([1.0]), + returns=torch.tensor([1.5]), + old_value=torch.tensor([0.5]), + ) + ] + + trainer = _EvalDummyTrainer() + metrics = trainer.evaluate() + assert trainer._collect_calls == 1 + assert trainer.eval_flags == [True] + assert trainer._in_eval is False + assert "eval/turn_1/reward_mean" in metrics diff --git a/tests/test_model_loading.py b/tests/test_model_loading.py index 623f514..242faeb 100644 --- a/tests/test_model_loading.py +++ b/tests/test_model_loading.py @@ -18,6 +18,18 @@ def _reward_func(*_args, **_kwargs): return [0.0] +def _iac_cfg(**kwargs): + return IACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _maac_cfg(**kwargs): + return MAACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _magrpo_cfg(**kwargs): + return MAGRPOConfig(agent_devices="cpu", **kwargs) + + @pytest.fixture(scope="session") def tokenizer_05(): return AutoTokenizer.from_pretrained(MODEL_NAME_05) @@ -40,7 +52,7 @@ def _cleanup(*objs): def test_magrpo_model_name(): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agent_model=MODEL_NAME_05, num_agents=2, @@ -53,7 +65,7 @@ def test_magrpo_model_name(): def test_magrpo_pretrained(tokenizer_05, model_05, model_06): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, 
num_generations=2) trainer = MAGRPOTrainer( agents=[model_05, model_06], tokenizer=tokenizer_05, @@ -66,7 +78,7 @@ def test_magrpo_pretrained(tokenizer_05, model_05, model_06): def test_maac_model_name(): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agent_model=MODEL_NAME_05, critics=[MODEL_NAME_06], @@ -79,7 +91,7 @@ def test_maac_model_name(): def test_maac_pretrained(tokenizer_05, model_05, model_06): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agents=[model_05, model_06], critics=[model_06], @@ -93,7 +105,7 @@ def test_maac_pretrained(tokenizer_05, model_05, model_06): def test_maac_critics_len_mismatch(tokenizer_05, model_05, model_06): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) with pytest.raises(ValueError, match="critics length"): MAACTrainer( agents=[model_05, model_06], @@ -105,7 +117,7 @@ def test_maac_critics_len_mismatch(tokenizer_05, model_05, model_06): def test_iac_model_name_critics(tokenizer_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) trainer = IACTrainer( agent_model=MODEL_NAME_05, critics=[model_06, model_06], @@ -119,7 +131,7 @@ def test_iac_model_name_critics(tokenizer_05, model_06): def test_iac_model_and_agents_names_conflict(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, match="conflict"): IACTrainer( agent_model=MODEL_NAME_05, @@ -131,7 +143,7 @@ def test_iac_model_and_agents_names_conflict(tokenizer_05): def test_iac_model_and_agents_names_match(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) trainer = IACTrainer( agent_model=MODEL_NAME_05, agents=[MODEL_NAME_05, MODEL_NAME_05], @@ -144,7 +156,7 @@ def test_iac_model_and_agents_names_match(tokenizer_05): def test_iac_model_and_agents_len_mismatch(tokenizer_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, match="agents length"): IACTrainer( agent_model=MODEL_NAME_05, @@ -161,7 +173,7 @@ def test_iac_model_and_agents_len_mismatch(tokenizer_05): ids=["shared_homo", "shared_hetero"], ) def test_iac_shared_heads(agents_case, tokenizer_05, model_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) agents = [model_05, model_05] if agents_case == "homo" else [model_05, model_06] trainer = IACTrainer( agents=agents, @@ -175,7 +187,7 @@ def test_iac_shared_heads(agents_case, tokenizer_05, model_05, model_06): def test_iac_shared_heads_rejects_critics(tokenizer_05, model_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=False) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=False) with pytest.raises(ValueError, match="use_separate_critic"): IACTrainer( agents=[model_05, model_05], @@ -187,7 +199,7 @@ def test_iac_shared_heads_rejects_critics(tokenizer_05, model_05): def test_iac_critics_len_mismatch(tokenizer_05, model_05): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + 
args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) with pytest.raises(ValueError, match="critics length"): IACTrainer( agent_model=MODEL_NAME_05, @@ -204,7 +216,7 @@ def test_iac_critics_len_mismatch(tokenizer_05, model_05): ids=["critic_match", "critic_swapped"], ) def test_iac_separate_critics(critics_case, tokenizer_05, model_05, model_06): - args = IACConfig(num_agents=2, num_turns=1, use_separate_critic=True) + args = _iac_cfg(num_agents=2, num_turns=1, use_separate_critic=True) critics = [model_05, model_06] if critics_case == "match" else [model_06, model_05] trainer = IACTrainer( agents=[model_05, model_06], diff --git a/tests/test_parallel_agent_tasks.py b/tests/test_parallel_agent_tasks.py new file mode 100644 index 0000000..2562758 --- /dev/null +++ b/tests/test_parallel_agent_tasks.py @@ -0,0 +1,86 @@ +import time +from types import SimpleNamespace + +from comlrl.trainers.actor_critic.ac_base import ActorCriticTrainerBase +from comlrl.trainers.actor_critic.iac import IACTrainer +from comlrl.trainers.actor_critic.maac import MAACTrainer +from comlrl.trainers.reinforce.magrpo import MAGRPOTrainer + + +class _DummyACTrainer(ActorCriticTrainerBase): + def __init__(self, *, parallel_training: str, num_agents: int = 2): + self.parallel_training = parallel_training + self.args = SimpleNamespace(num_agents=num_agents) + self.agent_devices = [f"cuda:{idx}" for idx in range(num_agents)] + + +def test_ac_run_agent_tasks_keeps_index_order_in_mp(): + trainer = _DummyACTrainer(parallel_training="mp", num_agents=2) + completion_order = [] + + def _task(agent_idx: int) -> str: + if agent_idx == 0: + time.sleep(0.05) + completion_order.append(agent_idx) + return f"agent-{agent_idx}" + + outputs = trainer._run_agent_tasks(_task) + assert outputs == ["agent-0", "agent-1"] + assert completion_order == [1, 0] + + +def test_ac_run_agent_tasks_is_sequential_when_mp_disabled(): + trainer = _DummyACTrainer(parallel_training="none", num_agents=2) + completion_order = [] + + def _task(agent_idx: int) -> int: + completion_order.append(agent_idx) + return agent_idx + + outputs = trainer._run_agent_tasks(_task) + assert outputs == [0, 1] + assert completion_order == [0, 1] + + +def test_iac_parallel_updates_enabled_only_for_mp_mode(): + trainer = IACTrainer.__new__(IACTrainer) + trainer.args = SimpleNamespace(num_agents=2) + trainer.agent_devices = ["cuda:0", "cuda:1"] + trainer.parallel_training = "none" + assert trainer._parallel_agent_mode_enabled() is False + trainer.parallel_training = "mp" + assert trainer._parallel_agent_mode_enabled() is True + + +def test_magrpo_run_agent_tasks_keeps_index_order_in_mp(): + trainer = MAGRPOTrainer.__new__(MAGRPOTrainer) + trainer.parallel_training = "mp" + trainer.num_agents = 2 + trainer.agent_devices = ["cuda:0", "cuda:1"] + completion_order = [] + + def _task(agent_idx: int) -> str: + if agent_idx == 0: + time.sleep(0.05) + completion_order.append(agent_idx) + return f"agent-{agent_idx}" + + outputs = trainer._run_agent_tasks(_task) + assert outputs == ["agent-0", "agent-1"] + assert completion_order == [1, 0] + + +def test_maac_parallel_updates_are_always_serialized(): + trainer = MAACTrainer.__new__(MAACTrainer) + trainer._parallel_update_enabled = False + trainer.parallel_training = "mp" + trainer.args = SimpleNamespace(num_agents=2) + trainer.agent_devices = ["cuda:0", "cuda:1"] + run_parallel = bool( + getattr( + trainer, + "_parallel_update_enabled", + trainer._parallel_agent_mode_enabled(), + ) + ) + assert run_parallel is False diff 
--git a/tests/test_trainer_constraints.py b/tests/test_trainer_constraints.py index 1c4d545..4e37f1c 100644 --- a/tests/test_trainer_constraints.py +++ b/tests/test_trainer_constraints.py @@ -18,6 +18,18 @@ def _external_transition(**_kwargs): return [""] * 1 +def _iac_cfg(**kwargs): + return IACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _maac_cfg(**kwargs): + return MAACConfig(agent_devices="cpu", critic_devices="cpu", **kwargs) + + +def _magrpo_cfg(**kwargs): + return MAGRPOConfig(agent_devices="cpu", **kwargs) + + @pytest.fixture(scope="session") def dummy_tokenizer(): return SimpleNamespace( @@ -58,11 +70,11 @@ def tiny_model_b(): "factory, match", [ ( - lambda: IACTrainer(agent_model="dummy", reward_func=None, args=IACConfig()), + lambda: IACTrainer(agent_model="dummy", reward_func=None, args=_iac_cfg()), "reward_func", ), ( - lambda: IACTrainer(reward_func=_reward_func, args=IACConfig()), + lambda: IACTrainer(reward_func=_reward_func, args=_iac_cfg()), "Either agent_model or agents", ), ( @@ -70,7 +82,7 @@ def tiny_model_b(): agents=[object()], tokenizer=SimpleNamespace(pad_token="x", eos_token="x", pad_token_id=0), reward_func=_reward_func, - args=IACConfig(num_agents=1, num_turns=2), + args=_iac_cfg(num_agents=1, num_turns=2), ), "external_transition", ), @@ -79,18 +91,18 @@ def tiny_model_b(): agent_model="dummy", critics=[object()], reward_func=_reward_func, - args=IACConfig(use_separate_critic=False), + args=_iac_cfg(use_separate_critic=False), ), "use_separate_critic", ), ( lambda: MAACTrainer( - agent_model="dummy", reward_func=None, args=MAACConfig() + agent_model="dummy", reward_func=None, args=_maac_cfg() ), "reward_func", ), ( - lambda: MAACTrainer(reward_func=_reward_func, args=MAACConfig()), + lambda: MAACTrainer(reward_func=_reward_func, args=_maac_cfg()), "Either agent_model or agents", ), ( @@ -98,18 +110,18 @@ def tiny_model_b(): agents=[object()], tokenizer=SimpleNamespace(pad_token="x", eos_token="x", pad_token_id=0), reward_func=_reward_func, - args=MAACConfig(num_agents=1, num_turns=2), + args=_maac_cfg(num_agents=1, num_turns=2), ), "external_transition", ), ( lambda: MAGRPOTrainer( - agent_model="dummy", reward_func=None, args=MAGRPOConfig() + agent_model="dummy", reward_func=None, args=_magrpo_cfg() ), "reward_func", ), ( - lambda: MAGRPOTrainer(reward_func=_reward_func, args=MAGRPOConfig()), + lambda: MAGRPOTrainer(reward_func=_reward_func, args=_magrpo_cfg()), "Either agent_model or agents", ), ], @@ -131,7 +143,7 @@ def test_trainer_early_constraints(factory, match): def test_iac_separate_critic_requires_critics(dummy_tokenizer, tiny_model_a): - args = IACConfig(num_agents=1, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=1, use_separate_critic=True, num_turns=1) with pytest.raises(ValueError, match="critics must be provided"): IACTrainer( agents=[tiny_model_a], @@ -142,7 +154,7 @@ def test_iac_separate_critic_requires_critics(dummy_tokenizer, tiny_model_a): def test_iac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=True, num_turns=1) with pytest.raises(ValueError, match="critics length"): IACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -154,7 +166,7 @@ def test_iac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_iac_valid_shared_heads(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=False, 
num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=dummy_tokenizer, @@ -166,7 +178,7 @@ def test_iac_valid_shared_heads(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_iac_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=False, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=[dummy_tokenizer, dummy_tokenizer], @@ -179,7 +191,7 @@ def test_iac_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b) def test_iac_rejects_tokenizer_len_mismatch( dummy_tokenizer, tiny_model_a, tiny_model_b ): - args = IACConfig(num_agents=2, use_separate_critic=False, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=False, num_turns=1) with pytest.raises(ValueError, match="tokenizers length"): IACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -190,7 +202,7 @@ def test_iac_rejects_tokenizer_len_mismatch( def test_iac_valid_separate_critics(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = IACConfig(num_agents=2, use_separate_critic=True, num_turns=1) + args = _iac_cfg(num_agents=2, use_separate_critic=True, num_turns=1) trainer = IACTrainer( agents=[tiny_model_a, tiny_model_b], critics=[tiny_model_b, tiny_model_a], @@ -204,7 +216,7 @@ def test_iac_valid_separate_critics(dummy_tokenizer, tiny_model_a, tiny_model_b) def test_iac_multiturn_with_transition(dummy_tokenizer, tiny_model_a): - args = IACConfig(num_agents=1, use_separate_critic=False, num_turns=2) + args = _iac_cfg(num_agents=1, use_separate_critic=False, num_turns=2) trainer = IACTrainer( agents=[tiny_model_a], tokenizer=dummy_tokenizer, @@ -222,12 +234,12 @@ def test_iac_multiturn_num_generations_mismatch(dummy_tokenizer, tiny_model_a): tokenizer=dummy_tokenizer, reward_func=_reward_func, external_transition=_external_transition, - args=IACConfig(num_agents=1, num_turns=2, num_generations=2), + args=_iac_cfg(num_agents=1, num_turns=2, num_generations=2), ) def test_maac_requires_critics(dummy_tokenizer, tiny_model_a): - args = MAACConfig(num_agents=1, num_turns=1) + args = _maac_cfg(num_agents=1, num_turns=1) with pytest.raises(ValueError, match="critics must be provided"): MAACTrainer( agents=[tiny_model_a], @@ -238,7 +250,7 @@ def test_maac_requires_critics(dummy_tokenizer, tiny_model_a): def test_maac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) with pytest.raises(ValueError, match="critics length"): MAACTrainer( agents=[tiny_model_a, tiny_model_b], @@ -250,7 +262,7 @@ def test_maac_critic_len_mismatch(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_maac_valid(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAACConfig(num_agents=2, num_turns=1) + args = _maac_cfg(num_agents=2, num_turns=1) trainer = MAACTrainer( agents=[tiny_model_a, tiny_model_b], critics=[tiny_model_a], @@ -263,7 +275,7 @@ def test_maac_valid(dummy_tokenizer, tiny_model_a, tiny_model_b): def test_maac_multiturn_with_transition(dummy_tokenizer, tiny_model_a): - args = MAACConfig(num_agents=1, num_turns=2, num_generations=1) + args = _maac_cfg(num_agents=1, num_turns=2, num_generations=1) trainer = MAACTrainer( agents=[tiny_model_a], critics=[tiny_model_a], @@ -284,14 +296,14 @@ def 
test_maac_multiturn_num_generations_mismatch(dummy_tokenizer, tiny_model_a): tokenizer=dummy_tokenizer, reward_func=_reward_func, external_transition=_external_transition, - args=MAACConfig(num_agents=1, num_turns=2, num_generations=2), + args=_maac_cfg(num_agents=1, num_turns=2, num_generations=2), ) def test_magrpo_requires_transition_for_multiturn( dummy_tokenizer, tiny_model_a, tiny_model_b ): - args = MAGRPOConfig(num_agents=2, num_turns=2, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=2, num_generations=2) with pytest.raises(ValueError, match="external_transition"): MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], @@ -302,7 +314,7 @@ def test_magrpo_requires_transition_for_multiturn( def test_magrpo_multiturn_with_transition(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAGRPOConfig(num_agents=2, num_turns=2, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=2, num_generations=2) trainer = MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=dummy_tokenizer, @@ -314,7 +326,7 @@ def test_magrpo_multiturn_with_transition(dummy_tokenizer, tiny_model_a, tiny_mo def test_magrpo_accepts_tokenizer_list(dummy_tokenizer, tiny_model_a, tiny_model_b): - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agents=[tiny_model_a, tiny_model_b], tokenizer=[dummy_tokenizer, dummy_tokenizer], @@ -333,7 +345,7 @@ def _fake_from_pretrained(*_args, **_kwargs): monkeypatch.setattr( "transformers.AutoModelForCausalLM.from_pretrained", _fake_from_pretrained ) - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) trainer = MAGRPOTrainer( agent_model="dummy", agents=["dummy", "dummy"], @@ -354,7 +366,7 @@ def _fake_from_pretrained(*_args, **_kwargs): monkeypatch.setattr( "transformers.AutoModelForCausalLM.from_pretrained", _fake_from_pretrained ) - args = MAGRPOConfig(num_agents=2, num_turns=1, num_generations=2) + args = _magrpo_cfg(num_agents=2, num_turns=1, num_generations=2) with pytest.raises(ValueError, match="conflict"): MAGRPOTrainer( agent_model="dummy",