Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions comlrl/schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .device_scheduler import DeviceScheduler

__all__ = ["DeviceScheduler"]
121 changes: 121 additions & 0 deletions comlrl/schedulers/device_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import Iterable, List, Optional, Sequence, Tuple, Union

import torch

DeviceSpec = Union[str, Sequence[str]]


class DeviceScheduler:
    """Static helpers for placing per-agent actor/critic models on devices.

    A *device spec* (``DeviceSpec``) is either the string ``"auto"``, a single
    device string such as ``"cuda:0"``, or a sequence of device strings.
    ``None`` is treated like ``"auto"``.
    """

    @staticmethod
    def assign_devices(
        num_agents: int,
        agent_devices: Optional[DeviceSpec],
        critic_devices: Optional[DeviceSpec],
        *,
        use_separate_critic: bool,
    ) -> Tuple[List[torch.device], List[torch.device]]:
        """Resolve device lists for ``num_agents`` agents and their critics.

        When ``use_separate_critic`` is False, critics reuse the agents'
        devices. Otherwise the critic spec is resolved independently, falling
        back to the agent spec when ``critic_devices`` is ``None``.

        Returns:
            ``(agent_devices, critic_devices)`` lists, each of length
            ``num_agents``.
        """
        agent_list = DeviceScheduler.resolve_devices(
            agent_devices, num_agents, kind="agent_devices"
        )
        if use_separate_critic:
            critic_spec = (
                critic_devices if critic_devices is not None else agent_devices
            )
            critic_list = DeviceScheduler.resolve_devices(
                critic_spec, num_agents, kind="critic_devices"
            )
        else:
            # Shared placement: copy so callers may mutate one list freely.
            critic_list = list(agent_list)
        return agent_list, critic_list

    @staticmethod
    def assign_shared_critic_device(
        agent_devices: Sequence[torch.device],
        critic_devices: Optional[DeviceSpec],
    ) -> torch.device:
        """Pick the device for a single critic shared by all agents.

        Defaults to the first agent's device when no critic spec is given.
        """
        if critic_devices is None:
            return agent_devices[0]
        return DeviceScheduler.resolve_devices(
            critic_devices, 1, kind="critic_devices"
        )[0]

    @staticmethod
    def resolve_devices(
        spec: Optional[DeviceSpec],
        num_devices: int,
        *,
        kind: str = "devices",
    ) -> List[torch.device]:
        """Expand ``spec`` into a list of exactly ``num_devices`` devices.

        ``None``/``"auto"`` auto-selects (see ``_auto_devices``); a single
        string or a singleton sequence is broadcast to every slot; any other
        sequence must match ``num_devices`` exactly.

        Raises:
            ValueError: if ``num_devices < 1``, the sequence is empty, its
                length mismatches ``num_devices``, or the spec type is
                unsupported.
        """
        if num_devices < 1:
            raise ValueError(f"{kind} count must be >= 1.")

        if spec is None or (isinstance(spec, str) and spec.lower() == "auto"):
            return DeviceScheduler._auto_devices(num_devices)

        if isinstance(spec, str):
            return [torch.device(spec)] * num_devices

        # NOTE: str is also a Sequence, so this branch must come after the
        # str checks above.
        if isinstance(spec, Sequence):
            if len(spec) == 0:
                raise ValueError(f"{kind} must be a non-empty list or 'auto'.")
            if len(spec) == 1:
                # Broadcast a singleton spec to every slot.
                return [torch.device(spec[0])] * num_devices
            if len(spec) != num_devices:
                raise ValueError(
                    f"{kind} length ({len(spec)}) must be 1 or {num_devices}."
                )
            return [torch.device(s) for s in spec]

        raise ValueError(f"Unsupported {kind} spec: {spec!r}.")

    @staticmethod
    def resolve_single_device(*specs: Optional[DeviceSpec]) -> torch.device:
        """Return one device from the first non-``None`` spec (else auto).

        For sequence specs only the first entry is considered.

        Raises:
            ValueError: on an empty sequence or an unsupported spec type.
        """
        for spec in specs:
            if spec is None:
                continue
            if isinstance(spec, str):
                if spec.lower() == "auto":
                    return DeviceScheduler._auto_devices(1)[0]
                return torch.device(spec)
            if isinstance(spec, Sequence):
                if len(spec) == 0:
                    raise ValueError("Device spec list must be non-empty.")
                first = spec[0]
                if isinstance(first, str) and first.lower() == "auto":
                    return DeviceScheduler._auto_devices(1)[0]
                return torch.device(first)
            raise ValueError(f"Unsupported device spec: {spec!r}.")
        return DeviceScheduler._auto_devices(1)[0]

    @staticmethod
    def devices_disjoint(device_groups: Iterable[Sequence[torch.device]]) -> bool:
        """Return True when no device appears in more than one group."""
        seen = set()
        for group in device_groups:
            for device in group:
                key = DeviceScheduler._device_key(device)
                if key in seen:
                    return False
                seen.add(key)
        return True

    @staticmethod
    def _device_key(device: torch.device) -> str:
        """Canonical string key used by ``devices_disjoint``.

        Fix: ``torch.device("cuda")`` has ``index is None`` but refers to the
        current CUDA device (index 0 unless the process changed it), so it is
        normalized to ``cuda:0`` rather than producing the distinct key
        ``"cuda:None"`` — otherwise ``cuda`` and ``cuda:0`` would be reported
        as disjoint even though they normally denote the same GPU.
        """
        if device.type == "cuda":
            index = 0 if device.index is None else device.index
            return f"cuda:{index}"
        return device.type

    @staticmethod
    def _auto_devices(num_devices: int) -> List[torch.device]:
        """Auto-select devices: one CUDA device per slot when enough exist.

        Falls back to CPU when CUDA is unavailable, and colocates every slot
        on ``cuda:0`` when fewer GPUs than slots are present.
        """
        if not torch.cuda.is_available():
            return [torch.device("cpu")] * num_devices

        count = int(torch.cuda.device_count())
        if count < 1:
            return [torch.device("cpu")] * num_devices
        if count < num_devices:
            # Not enough GPUs for one-per-slot; colocate everything on cuda:0.
            return [torch.device("cuda:0")] * num_devices

        return [torch.device(f"cuda:{idx}") for idx in range(num_devices)]
143 changes: 120 additions & 23 deletions comlrl/trainers/actor_critic/ac_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import wandb
Expand All @@ -10,6 +11,51 @@
class ActorCriticTrainerBase:
"""Shared training utilities for actor-critic style trainers."""

def _parallel_agent_mode_enabled(self) -> bool:
if str(getattr(self, "parallel_training", "")).lower() != "mp":
return False
num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0)
if num_agents <= 1:
return False
devices = getattr(self, "agent_devices", None)
if not devices:
return True
unique = {str(device) for device in devices}
return len(unique) > 1

def _run_agent_tasks(
self,
fn,
*,
agent_indices: Optional[List[int]] = None,
parallel: Optional[bool] = None,
) -> List[Any]:
num_agents = int(getattr(getattr(self, "args", None), "num_agents", 0) or 0)
indices = (
list(agent_indices)
if agent_indices is not None
else list(range(max(num_agents, 0)))
)
if not indices:
return []

use_parallel = (
self._parallel_agent_mode_enabled() if parallel is None else bool(parallel)
)
if not use_parallel or len(indices) <= 1:
return [fn(agent_idx) for agent_idx in indices]

results: Dict[int, Any] = {}
max_workers = len(indices)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(fn, agent_idx): agent_idx for agent_idx in indices
}
for future in as_completed(futures):
agent_idx = futures[future]
results[agent_idx] = future.result()
return [results[agent_idx] for agent_idx in indices]

def _filter_model_kwargs(self, cfg: Optional[Dict[str, Any]]) -> Dict[str, Any]:
torch_dtype = None
if isinstance(cfg, dict):
Expand Down Expand Up @@ -64,16 +110,18 @@ def _encode_prompt(
prompt: str,
agent_idx: Optional[int] = None,
tokenizer: Optional[Any] = None,
device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
tokenizer = tokenizer or self._get_tokenizer(agent_idx)
encoded = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
)
target_device = device or self.device
return {
"input_ids": encoded["input_ids"].to(self.device),
"attention_mask": encoded["attention_mask"].to(self.device),
"input_ids": encoded["input_ids"].to(target_device),
"attention_mask": encoded["attention_mask"].to(target_device),
}

def _prepare_advantages(self, rollouts: List[Any]) -> None:
Expand Down Expand Up @@ -139,7 +187,9 @@ def _summarize_rollout_metrics(self, rollouts: List[Any]) -> Dict[str, float]:
return metrics

def _iter_dataloader(self, dataloader, epoch: int, total_epochs: int):
if getattr(self, "verbose", True):
dist_env = getattr(self, "dist_env", None)
is_main = bool(getattr(dist_env, "is_main", True))
if getattr(self, "verbose", True) and is_main:
return enumerate(
tqdm(
dataloader,
Expand All @@ -165,7 +215,12 @@ def _on_epoch_end(
epoch_metrics: Dict[str, List[float]],
) -> None:
summary = self._summarize_epoch_metrics(epoch_metrics)
if summary and getattr(self, "verbose", True):
dist_env = getattr(self, "dist_env", None)
if (
summary
and getattr(self, "verbose", True)
and getattr(dist_env, "is_main", True)
):
print(f"Epoch {epoch + 1}/{total_epochs} metrics: {summary}")

def _tag_metrics(
Expand Down Expand Up @@ -197,10 +252,9 @@ def _process_buffer(
self,
agent_idx: int,
buffer: List[Any],
epoch_metrics: Dict[str, List[float]],
) -> None:
) -> Dict[str, Any]:
if not buffer:
return
return {"metric_values": {}, "log_metrics": {}}

has_turn_idx = any(
"turn_idx" in (getattr(s, "metadata", {}) or {}) for s in buffer
Expand All @@ -213,35 +267,74 @@ def _process_buffer(
buffer.clear()

combined_log: Dict[str, float] = {}
metric_values: Dict[str, List[float]] = {}
for t_idx in sorted(turn_groups.keys()):
samples = turn_groups[t_idx]
metrics = self._update(agent_idx, samples)
tagged = self._tag_metrics(metrics, agent_idx, turn_idx=t_idx)
combined_log.update(tagged)
for key, value in tagged.items():
epoch_metrics[key].append(value)
metric_values.setdefault(key, []).append(value)
return {"metric_values": metric_values, "log_metrics": combined_log}

def _drain_ready_agent_buffers(
self,
ready_agents: List[int],
epoch_metrics: Dict[str, List[float]],
) -> None:
if not ready_agents:
return

unique_ready = sorted({int(idx) for idx in ready_agents})
run_parallel = bool(
getattr(
self,
"_parallel_update_enabled",
self._parallel_agent_mode_enabled(),
)
)

def _process(agent_idx: int) -> Dict[str, Any]:
return self._process_buffer(agent_idx, self.rollout_buffers[agent_idx])

results = self._run_agent_tasks(
_process,
agent_indices=unique_ready,
parallel=run_parallel,
)

combined_log: Dict[str, float] = {}
for result in results:
metric_values = result.get("metric_values", {})
for key, values in metric_values.items():
for value in values:
epoch_metrics[key].append(value)
combined_log.update(result.get("log_metrics", {}))

if combined_log and self._should_log_train():
self._log_metrics(combined_log)

def _run_batch(self, batch, epoch_metrics: Dict[str, List[float]]) -> None:
for item in batch:
rollouts = self._collect_rollouts(item)
ready_agents: List[int] = []
for sample in rollouts:
agent_idx = sample.agent_idx
buffer = self.rollout_buffers[agent_idx]
buffer.append(sample)
if len(buffer) >= self.args.rollout_buffer_size:
self._process_buffer(agent_idx, buffer, epoch_metrics)
ready_agents.append(agent_idx)
if ready_agents:
self._drain_ready_agent_buffers(ready_agents, epoch_metrics)
if self.args.num_agents > 0:
# Count joint-action reward evaluations (one per agent group).
self.env_step += len(rollouts) // self.args.num_agents

def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None:
for agent_idx, buffer in enumerate(self.rollout_buffers):
if not buffer:
continue
self._process_buffer(agent_idx, buffer, epoch_metrics)
ready_agents = [
agent_idx for agent_idx, buffer in enumerate(self.rollout_buffers) if buffer
]
self._drain_ready_agent_buffers(ready_agents, epoch_metrics)

def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
Expand Down Expand Up @@ -275,18 +368,22 @@ def evaluate(self) -> Dict[str, float]:
turn_groups: Dict[int, List[Any]] = {}
seen = 0

with torch.no_grad():
for batch in dataloader:
for item in batch:
rollouts = self._collect_rollouts(item)
for sample in rollouts:
t_idx = int(sample.metadata.get("turn_idx", 0))
turn_groups.setdefault(t_idx, []).append(sample)
seen += 1
self._in_eval = True
try:
with torch.no_grad():
for batch in dataloader:
for item in batch:
rollouts = self._collect_rollouts(item)
for sample in rollouts:
t_idx = int(sample.metadata.get("turn_idx", 0))
turn_groups.setdefault(t_idx, []).append(sample)
seen += 1
if seen >= num_samples:
break
if seen >= num_samples:
break
if seen >= num_samples:
break
finally:
self._in_eval = False

eval_log: Dict[str, float] = {}
for turn_idx, samples in sorted(turn_groups.items()):
Expand Down
Loading