diff --git a/house_build/configs/house_build_iac_config.yaml b/house_build/configs/house_build_iac_config.yaml index f4b034e..05ec2ed 100644 --- a/house_build/configs/house_build_iac_config.yaml +++ b/house_build/configs/house_build_iac_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 dtype: bf16 @@ -21,25 +20,59 @@ critics: null dataset: name: house_build type: house_build + train_split: '[:8]' + eval_split: '[8:]' json_path: ../dataset/data.json - train_split: "[:8]" - eval_split: "[8:]" + +prompt: + use_chat_template: true + +task: + player: + hp: 5 + spider: + atk_high: 3 + atk_low: 1 + num: 3 + max_commands: 600 + limited_resource: true + block_agent1: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air + block_agent2: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_iac_house_build verbose: false + save_final_model: false + save_path: output_iac_house_build external: mode: score_feedback original_prompt: true previous_response: true lim: 20 + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 4 + use_separate_critic: true num_train_epochs: 150 agent_learning_rate: 5e-6 critic_learning_rate: 5e-6 @@ -48,10 +81,6 @@ iac: rollout_buffer_size: 1 train_batch_size: 1 max_new_tokens: 512 - temperature: 0.6 - top_p: 0.6 - top_k: null - use_separate_critic: true discount: 0.9 early_termination_threshold: 0.0 eval_interval: 10 @@ -68,20 +97,7 @@ wandb: project: house_build entity: OpenMLRL run_name: house_build_iac - dir: output - tags: ["iac", "house_build"] - -prompt: - 
use_chat_template: true - -task: - block_agent1: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - block_agent2: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - max_commands: 600 - limited_resource: true - player: - hp: 5 - spider: - num: 3 - atk_low: 1 - atk_high: 3 + dir: output_iac_house_build + tags: + - iac + - house_build diff --git a/house_build/configs/house_build_maac_config.yaml b/house_build/configs/house_build_maac_config.yaml index ebbb4e1..b97e05a 100644 --- a/house_build/configs/house_build_maac_config.yaml +++ b/house_build/configs/house_build_maac_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 dtype: bf16 @@ -21,23 +20,56 @@ critics: null dataset: name: house_build type: house_build + train_split: '[:8]' + eval_split: '[8:]' json_path: ../dataset/data.json - train_split: "[:8]" - eval_split: "[8:]" + +prompt: + use_chat_template: true + +task: + player: + hp: 5 + spider: + atk_high: 3 + atk_low: 1 + num: 3 + max_commands: 600 + limited_resource: true + block_agent1: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air + block_agent2: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_maac_house_build verbose: false + save_final_model: false + save_path: output_maac_house_build external: mode: score_feedback original_prompt: true previous_response: true lim: 20 + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 4 critic_type: v @@ -48,9 +80,6 @@ maac: rollout_buffer_size: 1 train_batch_size: 1 max_new_tokens: 
512 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: 0.0 eval_interval: 10 @@ -67,20 +96,7 @@ wandb: project: house_build entity: OpenMLRL run_name: house_build_maac - dir: output - tags: ["maac", "house_build"] - -prompt: - use_chat_template: true - -task: - block_agent1: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - block_agent2: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - max_commands: 600 - limited_resource: true - player: - hp: 5 - spider: - num: 3 - atk_low: 1 - atk_high: 3 + dir: output_maac_house_build + tags: + - maac + - house_build diff --git a/house_build/configs/house_build_magrpo_config.yaml b/house_build/configs/house_build_magrpo_config.yaml index 6d7f178..e1f2e4b 100644 --- a/house_build/configs/house_build_magrpo_config.yaml +++ b/house_build/configs/house_build_magrpo_config.yaml @@ -3,6 +3,7 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 @@ -15,23 +16,54 @@ critics: null dataset: name: house_build type: house_build + train_split: '[:8]' + eval_split: '[8:]' json_path: ../dataset/data.json - train_split: "[:8]" - eval_split: "[8:]" + +prompt: + use_chat_template: true + +task: + player: + hp: 5 + spider: + atk_high: 3 + atk_low: 1 + num: 3 + max_commands: 600 + limited_resource: true + block_agent1: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air + block_agent2: + - white_concrete + - obsidian + - stone_stairs + - stone_bricks + - planks + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_magrpo_house_build verbose: false + save_final_model: false + save_path: output_magrpo_house_build external: mode: score_feedback original_prompt: true previous_response: true lim: 20 + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 4 num_train_epochs: 20 
@@ -39,9 +71,6 @@ magrpo: logging_steps: 5 num_generations: 2 max_new_tokens: 512 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.1 @@ -61,20 +90,7 @@ wandb: project: house_build entity: OpenMLRL run_name: house_build_magrpo - dir: output - tags: ["magrpo", "house_build"] - -prompt: - use_chat_template: true - -task: - block_agent1: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - block_agent2: [white_concrete, obsidian, stone_stairs, stone_bricks, planks, air] - max_commands: 600 - limited_resource: true - player: - hp: 5 - spider: - num: 3 - atk_low: 1 - atk_high: 3 + dir: output_magrpo_house_build + tags: + - magrpo + - house_build diff --git a/house_build/train/train_iac.py b/house_build/train/train_iac.py index d999472..9a48520 100644 --- a/house_build/train/train_iac.py +++ b/house_build/train/train_iac.py @@ -17,6 +17,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -41,7 +44,10 @@ ) from LLM_Collab_Minecraft.house_build.utils.config import apply_overrides, load_yaml, resolve_path from LLM_Collab_Minecraft.house_build.utils.prompting import apply_prompt_defaults -from LLM_Collab_Minecraft.house_build.utils.trainer_args import get_iac_args +from LLM_Collab_Minecraft.house_build.utils.trainer_args import ( + get_iac_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str, Any]]: @@ -450,7 +456,8 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - iac_args = get_iac_args(cfg, model_name=model_name) + sampling_cfg = get_agent_sampling_config(cfg) + iac_args = 
get_iac_args(cfg, sampling_cfg=sampling_cfg) formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) prompt_to_item: Dict[str, Dict[str, Any]] = {} dataset_prompt_map: Dict[str, Dict[str, Any]] = {} diff --git a/house_build/train/train_maac.py b/house_build/train/train_maac.py index ce5dd1a..8011e50 100644 --- a/house_build/train/train_maac.py +++ b/house_build/train/train_maac.py @@ -17,6 +17,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -41,7 +44,10 @@ ) from LLM_Collab_Minecraft.house_build.utils.config import apply_overrides, load_yaml, resolve_path from LLM_Collab_Minecraft.house_build.utils.prompting import apply_prompt_defaults -from LLM_Collab_Minecraft.house_build.utils.trainer_args import get_maac_args +from LLM_Collab_Minecraft.house_build.utils.trainer_args import ( + get_maac_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str, Any]]: @@ -450,7 +456,8 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - maac_args = get_maac_args(cfg, model_name=model_name) + sampling_cfg = get_agent_sampling_config(cfg) + maac_args = get_maac_args(cfg, sampling_cfg=sampling_cfg) formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) prompt_to_item: Dict[str, Dict[str, Any]] = {} dataset_prompt_map: Dict[str, Dict[str, Any]] = {} diff --git a/house_build/train/train_magrpo.py b/house_build/train/train_magrpo.py index d74b599..919e08b 100644 --- a/house_build/train/train_magrpo.py +++ b/house_build/train/train_magrpo.py @@ -17,9 +17,12 @@ REPO_ROOT = 
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore -from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from transformers import AutoTokenizer # type: ignore import torch # type: ignore from comlrl.trainers.reinforce import MAGRPOTrainer # type: ignore @@ -41,7 +44,10 @@ ) from LLM_Collab_Minecraft.house_build.utils.config import apply_overrides, load_yaml, resolve_path from LLM_Collab_Minecraft.house_build.utils.prompting import apply_prompt_defaults -from LLM_Collab_Minecraft.house_build.utils.trainer_args import get_trainer_args +from LLM_Collab_Minecraft.house_build.utils.trainer_args import ( + get_trainer_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str, Any]]: @@ -420,11 +426,7 @@ def main() -> int: ): raise ValueError("agents must be a list of model names.") agent_names = [str(x) for x in agent_names] - model_kwargs: Dict[str, Any] = {} - dtype = _map_dtype(model_cfg.get("dtype") or model_cfg.get("torch_dtype")) - if dtype is not None: - model_kwargs["torch_dtype"] = dtype tokenizer_source = agent_names[0] if agent_names else model_name if not tokenizer_source: @@ -438,17 +440,8 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - agents = [] - if agent_names: - for name in agent_names: - agent = AutoModelForCausalLM.from_pretrained(name, **model_kwargs) - agents.append(agent) - else: - for _ in range(num_agents): - agent = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) - agents.append(agent) - - magrpo_args = get_trainer_args(cfg) + sampling_cfg = get_agent_sampling_config(cfg) + magrpo_args = get_trainer_args(cfg, sampling_cfg=sampling_cfg) formatters = 
_build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) reward_func = get_reward_function(cfg=cfg, num_agents=num_agents) @@ -527,8 +520,12 @@ def main() -> int: trainer_kwargs: Dict[str, Any] = { "agent_model": model_name or None, - "agents": agents, + "agents": agent_names, "num_agents": num_agents, + "model_config": { + "torch_dtype": dtype, + "special_tokens": model_cfg.get("special_tokens", {}), + }, "reward_func": reward_func, "formatters": formatters, "args": magrpo_args, diff --git a/house_build/utils/trainer_args.py b/house_build/utils/trainer_args.py index 2e32e51..8ff6a66 100644 --- a/house_build/utils/trainer_args.py +++ b/house_build/utils/trainer_args.py @@ -86,12 +86,67 @@ def _as_bool(x: Any, default: bool) -> bool: return bool(x) -def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: +def _as_device_spec(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + if s.lower() in ("none", "null", ""): + return None + return s + if isinstance(x, (list, tuple)): + return [str(v) for v in x] + return str(x) + + +def get_agent_sampling_config(cfg: Dict[str, Any]) -> Dict[str, Any]: + model_cfg = cfg.get("agent_model") + if not isinstance(model_cfg, dict): + raise ValueError("agent_model must be a mapping.") + missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg] + if missing: + raise ValueError( + f"agent_model is missing required sampling fields: {', '.join(missing)}" + ) + + def _require_float(key: str) -> float: + value = model_cfg.get(key) + if value is None or isinstance(value, bool): + raise ValueError(f"agent_model.{key} must be provided as a float.") + try: + return float(value) + except Exception as exc: + raise ValueError(f"agent_model.{key} must be a float, got {value!r}.") from exc + + top_k_raw = model_cfg.get("top_k") + if isinstance(top_k_raw, str) and top_k_raw.strip().lower() in ("none", "null", ""): + top_k_val: Optional[int] = None + elif top_k_raw is None: + 
top_k_val = None + else: + try: + top_k_val = int(float(top_k_raw)) + except Exception as exc: + raise ValueError( + f"agent_model.top_k must be an integer or null, got {top_k_raw!r}." + ) from exc + + return { + "temperature": _require_float("temperature"), + "top_p": _require_float("top_p"), + "top_k": top_k_val, + } + + +def get_trainer_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAGRPOConfig: tr = cfg.get("magrpo") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} - lr_val = tr.get("agent_learning_rate", 3e-5) + lr_val = tr.get("agent_learning_rate", 1e-5) joint_mode = tr.get("joint_mode", tr.get("joint_action_mode", None)) joint_mode_str = str(joint_mode or "aligned").strip().lower() @@ -101,37 +156,40 @@ def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: joint_mode_str = "cross" candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 3), 3), - "agent_learning_rate": _as_float(lr_val, 3e-5), - "logging_steps": _as_int(tr.get("logging_steps", 50), 50), - "num_generations": _as_int(tr.get("num_generations", 4), 4), + "num_turns": _as_int(tr.get("num_turns", 4), 4), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 20), 20), + "agent_learning_rate": _as_float(lr_val, 1e-5), + "logging_steps": _as_int(tr.get("logging_steps", 5), 5), + "num_generations": _as_int(tr.get("num_generations", 2), 2), "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), - "temperature": _as_float(tr.get("temperature", 0.2), 0.2), - "top_p": _as_float(tr.get("top_p", 0.95), 0.95), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), } - if "top_k" in tr: - candidate["top_k"] = _as_opt_int(tr.get("top_k", None), None) candidate.update( { + "parallel_training": str(tr.get("parallel_training", 
"none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), "discount": _as_float(tr.get("discount", tr.get("gamma", 0.9)), 0.9), "joint_mode": joint_mode_str, + "early_termination_threshold": _as_opt_float( + tr.get("early_termination_threshold", -0.1), -0.1 + ), } ) - if "early_termination_threshold" in tr: - candidate["early_termination_threshold"] = _as_opt_float( - tr.get("early_termination_threshold", None), None - ) candidate.update( { - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2), - "train_batch_size": _as_opt_int(tr.get("train_batch_size", None), None), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), + "train_batch_size": _as_opt_int(tr.get("train_batch_size", 1), 1), "advantage_normalization": _as_bool( tr.get("advantage_normalization", True), True ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 2), 2), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), } ) @@ -149,38 +207,47 @@ def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: return cfg_obj -def get_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> MAACConfig: +def get_maac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAACConfig: tr = cfg.get("maac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), + "num_turns": _as_int(tr.get("num_turns", 4), 4), + "num_train_epochs": 
_as_int(tr.get("num_train_epochs", 150), 150), "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), "critic_learning_rate": _as_float( tr.get("critic_learning_rate", 5e-6), 5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": _as_int(tr.get("num_generations", 1), 1), + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "critic_type": str(tr.get("critic_type", "v")), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", 0.0), 0.0 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 10), 10), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), - "logging_steps": _as_int(tr.get("logging_steps", 
1), 1), + "logging_steps": _as_int(tr.get("logging_steps", 40), 40), } try: @@ -195,44 +262,53 @@ def get_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> M return MAACConfig(**filtered) -def get_iac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> IACConfig: +def get_iac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> IACConfig: tr = cfg.get("iac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} use_separate_critic = _as_bool(tr.get("use_separate_critic", True), True) adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), + "num_turns": _as_int(tr.get("num_turns", 4), 4), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 150), 150), "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), "critic_learning_rate": _as_opt_float( tr.get("critic_learning_rate", 5e-6), 5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "value_clip_range": _as_opt_float(tr.get("value_clip_range", 0.05), 0.05), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": 
_as_int(tr.get("num_generations", 1), 1), "use_separate_critic": use_separate_critic, + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "critic_value_head_hidden_dim": _as_opt_int( tr.get("critic_value_head_hidden_dim", None), None ), "value_head_hidden_dim": _as_opt_int(tr.get("value_head_hidden_dim", None), None), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", 0.0), 0.0 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 10), 10), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), - "logging_steps": _as_int(tr.get("logging_steps", 1), 1), + "logging_steps": _as_int(tr.get("logging_steps", 40), 40), } try: diff --git a/str_build/configs/str_build_iac_config.yaml b/str_build/configs/str_build_iac_config.yaml index f9ded34..b197cc8 100644 --- a/str_build/configs/str_build_iac_config.yaml +++ b/str_build/configs/str_build_iac_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 dtype: bf16 @@ -21,38 +20,56 @@ critics: null dataset: name: str_build type: str_build + train_split: '[:8]' + eval_split: '[8:]' csv_path: ../dataset/data.csv - train_split: "[:8]" - eval_split: "[8:]" - spacing: 2 local_z: 0 + spacing: 2 + 
+prompt: + use_chat_template: true + provide_graph: false + +task: + max_commands: 300 + block_agent1: + - oak_planks + - stone + - air + block_agent2: + - white_concrete + - obsidian + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_iac_str_build verbose: false + save_final_model: false + save_path: output_iac_str_build external: mode: position_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 4 + use_separate_critic: true num_train_epochs: 150 - agent_learning_rate: 2.5e-6 - critic_learning_rate: 2.5e-6 + agent_learning_rate: 2.5e-06 + critic_learning_rate: 2.5e-06 value_loss_coef: 0.6 value_clip_range: 0.05 rollout_buffer_size: 1 train_batch_size: 1 max_new_tokens: 512 - temperature: 0.6 - top_p: 0.6 - top_k: null - use_separate_critic: true discount: 0.9 early_termination_threshold: -0.1 eval_interval: 10 @@ -69,14 +86,7 @@ wandb: project: str_build entity: OpenMLRL run_name: str_build_iac - dir: output - tags: ["iac", "str_build"] - -prompt: - provide_graph: false - use_chat_template: true - -task: - block_agent1: [oak_planks, stone, air] - block_agent2: [white_concrete, obsidian, air] - max_commands: 300 + dir: output_iac_str_build + tags: + - iac + - str_build diff --git a/str_build/configs/str_build_maac_config.yaml b/str_build/configs/str_build_maac_config.yaml index b37ff94..fd44b13 100644 --- a/str_build/configs/str_build_maac_config.yaml +++ b/str_build/configs/str_build_maac_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 dtype: bf16 @@ -21,37 +20,55 @@ critics: null dataset: name: str_build type: 
str_build + train_split: '[:8]' + eval_split: '[8:]' csv_path: ../dataset/data.csv - train_split: "[:8]" - eval_split: "[8:]" - spacing: 2 local_z: 0 + spacing: 2 + +prompt: + use_chat_template: true + provide_graph: false + +task: + max_commands: 300 + block_agent1: + - oak_planks + - stone + - air + block_agent2: + - white_concrete + - obsidian + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_maac_str_build verbose: false + save_final_model: false + save_path: output_maac_str_build external: mode: position_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 4 critic_type: v num_train_epochs: 150 - agent_learning_rate: 2.5e-6 - critic_learning_rate: 2.5e-6 + agent_learning_rate: 2.5e-06 + critic_learning_rate: 2.5e-06 value_loss_coef: 0.6 rollout_buffer_size: 1 train_batch_size: 1 max_new_tokens: 512 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.1 eval_interval: 10 @@ -68,14 +85,7 @@ wandb: project: str_build entity: OpenMLRL run_name: str_build_maac - dir: output - tags: ["maac", "str_build"] - -prompt: - provide_graph: false - use_chat_template: true - -task: - block_agent1: [oak_planks, stone, air] - block_agent2: [white_concrete, obsidian, air] - max_commands: 300 + dir: output_maac_str_build + tags: + - maac + - str_build diff --git a/str_build/configs/str_build_magrpo_config.yaml b/str_build/configs/str_build_magrpo_config.yaml index ea683d8..b6e25cc 100644 --- a/str_build/configs/str_build_magrpo_config.yaml +++ b/str_build/configs/str_build_magrpo_config.yaml @@ -3,6 +3,7 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 dtype: bf16 @@ -15,24 +16,43 @@ critics: null dataset: name: str_build type: str_build + train_split: '[:8]' + eval_split: '[8:]' csv_path: 
../dataset/data.csv - train_split: "[:8]" - eval_split: "[8:]" - spacing: 2 local_z: 0 + spacing: 2 + +prompt: + use_chat_template: true + provide_graph: false + +task: + max_commands: 300 + block_agent1: + - oak_planks + - stone + - air + block_agent2: + - white_concrete + - obsidian + - air output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_magrpo_str_build verbose: false + save_final_model: false + save_path: output_magrpo_str_build external: mode: position_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 4 num_train_epochs: 20 @@ -40,9 +60,6 @@ magrpo: logging_steps: 1 num_generations: 2 max_new_tokens: 512 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.1 @@ -62,14 +79,7 @@ wandb: project: str_build entity: OpenMLRL run_name: str_build_magrpo - dir: output - tags: ["magrpo", "str_build"] - -prompt: - provide_graph: false - use_chat_template: true - -task: - block_agent1: [oak_planks, stone, air] - block_agent2: [white_concrete, obsidian, air] - max_commands: 300 + dir: output_magrpo_str_build + tags: + - magrpo + - str_build diff --git a/str_build/train/train_iac.py b/str_build/train/train_iac.py index 3abdb06..3c0c26d 100644 --- a/str_build/train/train_iac.py +++ b/str_build/train/train_iac.py @@ -16,6 +16,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -32,7 +35,10 @@ from LLM_Collab_Minecraft.str_build.utils.config import apply_overrides, load_yaml, resolve_path from 
LLM_Collab_Minecraft.str_build.utils.prompting import apply_graph_setting, apply_prompt_defaults from LLM_Collab_Minecraft.str_build.utils.str_builder import load_tasks_from_csv -from LLM_Collab_Minecraft.str_build.utils.trainer_args import get_iac_args +from LLM_Collab_Minecraft.str_build.utils.trainer_args import ( + get_iac_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str, Any]]: @@ -317,7 +323,8 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - iac_args = get_iac_args(cfg, model_name=model_name) + sampling_cfg = get_agent_sampling_config(cfg) + iac_args = get_iac_args(cfg, sampling_cfg=sampling_cfg) formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) prompt_to_item: Dict[str, Dict[str, Any]] = {} dataset_prompt_map: Dict[str, Dict[str, Any]] = {} diff --git a/str_build/train/train_maac.py b/str_build/train/train_maac.py index 8631687..06a7a29 100644 --- a/str_build/train/train_maac.py +++ b/str_build/train/train_maac.py @@ -16,6 +16,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -32,7 +35,10 @@ from LLM_Collab_Minecraft.str_build.utils.config import apply_overrides, load_yaml, resolve_path from LLM_Collab_Minecraft.str_build.utils.prompting import apply_graph_setting, apply_prompt_defaults from LLM_Collab_Minecraft.str_build.utils.str_builder import load_tasks_from_csv -from LLM_Collab_Minecraft.str_build.utils.trainer_args import get_maac_args +from LLM_Collab_Minecraft.str_build.utils.trainer_args import ( + get_maac_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: 
Any) -> List[Dict[str, Any]]: @@ -317,7 +323,8 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - maac_args = get_maac_args(cfg, model_name=model_name) + sampling_cfg = get_agent_sampling_config(cfg) + maac_args = get_maac_args(cfg, sampling_cfg=sampling_cfg) formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) prompt_to_item: Dict[str, Dict[str, Any]] = {} dataset_prompt_map: Dict[str, Dict[str, Any]] = {} diff --git a/str_build/train/train_magrpo.py b/str_build/train/train_magrpo.py index 57e29d9..186bc43 100644 --- a/str_build/train/train_magrpo.py +++ b/str_build/train/train_magrpo.py @@ -16,9 +16,12 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import Dataset # type: ignore -from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from transformers import AutoTokenizer # type: ignore import torch # type: ignore from comlrl.trainers.reinforce import MAGRPOTrainer # type: ignore @@ -32,7 +35,10 @@ from LLM_Collab_Minecraft.str_build.utils.config import apply_overrides, load_yaml, resolve_path from LLM_Collab_Minecraft.str_build.utils.prompting import apply_graph_setting, apply_prompt_defaults from LLM_Collab_Minecraft.str_build.utils.str_builder import load_tasks_from_csv -from LLM_Collab_Minecraft.str_build.utils.trainer_args import get_trainer_args +from LLM_Collab_Minecraft.str_build.utils.trainer_args import ( + get_trainer_args, + get_agent_sampling_config, +) def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str, Any]]: @@ -287,11 +293,7 @@ def main() -> int: ): raise ValueError("agents must be a list of model names.") agent_names = [str(x) for x in agent_names] - model_kwargs: Dict[str, Any] = {} - dtype = 
_map_dtype(model_cfg.get("dtype") or model_cfg.get("torch_dtype")) - if dtype is not None: - model_kwargs["torch_dtype"] = dtype tokenizer_source = agent_names[0] if agent_names else model_name if not tokenizer_source: @@ -305,17 +307,9 @@ def main() -> int: tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - agents = [] - if agent_names: - for name in agent_names: - agent = AutoModelForCausalLM.from_pretrained(name, **model_kwargs) - agents.append(agent) - else: - for _ in range(num_agents): - agent = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) - agents.append(agent) - - magrpo_args = get_trainer_args(cfg) + sampling_cfg = get_agent_sampling_config(cfg) + magrpo_args = get_trainer_args(cfg, sampling_cfg=sampling_cfg) + dtype = _map_dtype(model_cfg.get("dtype") or model_cfg.get("torch_dtype")) formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer) reward_func = get_reward_function(cfg=cfg, num_agents=num_agents) @@ -394,8 +388,12 @@ def main() -> int: trainer_kwargs: Dict[str, Any] = { "agent_model": model_name or None, - "agents": agents, + "agents": agent_names, "num_agents": num_agents, + "model_config": { + "torch_dtype": dtype, + "special_tokens": model_cfg.get("special_tokens", {}), + }, "reward_func": reward_func, "formatters": formatters, "args": magrpo_args, diff --git a/str_build/utils/trainer_args.py b/str_build/utils/trainer_args.py index 3fc7ded..901b9d2 100644 --- a/str_build/utils/trainer_args.py +++ b/str_build/utils/trainer_args.py @@ -86,12 +86,67 @@ def _as_bool(x: Any, default: bool) -> bool: return bool(x) -def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: +def _as_device_spec(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + if s.lower() in ("none", "null", ""): + return None + return s + if isinstance(x, (list, tuple)): + return [str(v) for v in x] + return str(x) + + +def get_agent_sampling_config(cfg: Dict[str, Any]) -> Dict[str, Any]: + model_cfg = cfg.get("agent_model") + if not isinstance(model_cfg, dict): + raise
ValueError("agent_model must be a mapping.") + missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg] + if missing: + raise ValueError( + f"agent_model is missing required sampling fields: {', '.join(missing)}" + ) + + def _require_float(key: str) -> float: + value = model_cfg.get(key) + if value is None or isinstance(value, bool): + raise ValueError(f"agent_model.{key} must be provided as a float.") + try: + return float(value) + except Exception as exc: + raise ValueError(f"agent_model.{key} must be a float, got {value!r}.") from exc + + top_k_raw = model_cfg.get("top_k") + if isinstance(top_k_raw, str) and top_k_raw.strip().lower() in ("none", "null", ""): + top_k_val: Optional[int] = None + elif top_k_raw is None: + top_k_val = None + else: + try: + top_k_val = int(float(top_k_raw)) + except Exception as exc: + raise ValueError( + f"agent_model.top_k must be an integer or null, got {top_k_raw!r}." + ) from exc + + return { + "temperature": _require_float("temperature"), + "top_p": _require_float("top_p"), + "top_k": top_k_val, + } + + +def get_trainer_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAGRPOConfig: tr = cfg.get("magrpo") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} - lr_val = tr.get("agent_learning_rate", 3e-5) + lr_val = tr.get("agent_learning_rate", 5e-6) joint_mode = tr.get("joint_mode", tr.get("joint_action_mode", None)) joint_mode_str = str(joint_mode or "aligned").strip().lower() @@ -101,37 +156,40 @@ def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: joint_mode_str = "cross" candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 3), 3), - "agent_learning_rate": _as_float(lr_val, 3e-5), - "logging_steps": _as_int(tr.get("logging_steps", 50), 50), - "num_generations": _as_int(tr.get("num_generations", 4), 4), + "num_turns": _as_int(tr.get("num_turns", 
4), 4), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 20), 20), + "agent_learning_rate": _as_float(lr_val, 5e-6), + "logging_steps": _as_int(tr.get("logging_steps", 1), 1), + "num_generations": _as_int(tr.get("num_generations", 2), 2), "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), - "temperature": _as_float(tr.get("temperature", 0.2), 0.2), - "top_p": _as_float(tr.get("top_p", 0.95), 0.95), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), } - if "top_k" in tr: - candidate["top_k"] = _as_opt_int(tr.get("top_k", None), None) candidate.update( { + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), "discount": _as_float(tr.get("discount", 0.9), 0.9), "joint_mode": joint_mode_str, + "early_termination_threshold": _as_opt_float( + tr.get("early_termination_threshold", -0.1), -0.1 + ), } ) - if "early_termination_threshold" in tr: - candidate["early_termination_threshold"] = _as_opt_float( - tr.get("early_termination_threshold", None), None - ) candidate.update( { - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2), - "train_batch_size": _as_opt_int(tr.get("train_batch_size", None), None), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), + "train_batch_size": _as_opt_int(tr.get("train_batch_size", 1), 1), "advantage_normalization": _as_bool( tr.get("advantage_normalization", True), True ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 2), 2), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", 
False), False + ), } ) @@ -149,38 +207,47 @@ def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: return cfg_obj -def get_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> MAACConfig: +def get_maac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAACConfig: tr = cfg.get("maac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), - "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), + "num_turns": _as_int(tr.get("num_turns", 4), 4), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 150), 150), + "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 2.5e-6), 2.5e-6), "critic_learning_rate": _as_float( - tr.get("critic_learning_rate", 5e-6), 5e-6 + tr.get("critic_learning_rate", 2.5e-6), 2.5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": _as_int(tr.get("num_generations", 1), 1), + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + 
"agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "critic_type": str(tr.get("critic_type", "v")), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", -0.1), -0.1 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 10), 10), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), - "logging_steps": _as_int(tr.get("logging_steps", 1), 1), + "logging_steps": _as_int(tr.get("logging_steps", 20), 20), } try: @@ -195,44 +262,53 @@ def get_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> M return MAACConfig(**filtered) -def get_iac_args(cfg: Dict[str, Any], *, model_name: Optional[str] = None) -> IACConfig: +def get_iac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> IACConfig: tr = cfg.get("iac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} use_separate_critic = _as_bool(tr.get("use_separate_critic", True), True) adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), - "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), + "num_turns": _as_int(tr.get("num_turns", 4), 4), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 150), 150), + "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 2.5e-6), 2.5e-6), "critic_learning_rate": 
_as_opt_float( - tr.get("critic_learning_rate", 5e-6), 5e-6 + tr.get("critic_learning_rate", 2.5e-6), 2.5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "value_clip_range": _as_opt_float(tr.get("value_clip_range", 0.05), 0.05), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": _as_int(tr.get("num_generations", 1), 1), "use_separate_critic": use_separate_critic, + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "critic_value_head_hidden_dim": _as_opt_int( tr.get("critic_value_head_hidden_dim", None), None ), "value_head_hidden_dim": _as_opt_int(tr.get("value_head_hidden_dim", None), None), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", -0.1), -0.1 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 10), 10), + 
"eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), - "logging_steps": _as_int(tr.get("logging_steps", 1), 1), + "logging_steps": _as_int(tr.get("logging_steps", 20), 20), } try: