-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathevaluate_custom_models.py
More file actions
62 lines (57 loc) · 1.74 KB
/
evaluate_custom_models.py
File metadata and controls
62 lines (57 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os

import metamon
from metamon.rl import (
    pretrained_vs_pokeagent_ladder,
    LocalPretrainedModel,
    LocalFinetunedModel,
)
from metamon.rl.pretrained import SmallRL
"""
In this example, let's say we trained a new model from scratch with:
python -m metamon.rl.train \\
--run_name gen9v3 \\
--model_gin_config medium_multitaskagent.gin \\
--save_dir ~/metamon_ckpts/ \\
--train_gin_config binary_rl.gin \\
--obs_space TeamPreviewObservationSpace \\
--tokenizer DefaultObservationSpace-v1 \\
--log
"""
# Describe the locally trained run so it can be loaded for evaluation.
# The arguments mirror the flags of the training command quoted above:
# run name, save directory, gin configs, observation space and tokenizer.
MyCustomModel = LocalPretrainedModel(
    # The shell expands "~" in `--save_dir ~/metamon_ckpts/`, but Python does
    # not expand a tilde inside a string literal, so resolve it explicitly.
    # NOTE(review): harmless if LocalPretrainedModel already expands the path.
    amago_ckpt_dir=os.path.expanduser("~/metamon_ckpts/"),
    model_name="gen9v3",  # same value as --run_name
    model_gin_config="medium_multitaskagent.gin",
    train_gin_config="binary_rl.gin",
    # presumably the checkpoint index loaded when none is requested — TODO confirm
    default_checkpoint=40,
    action_space=metamon.interface.DefaultActionSpace(),
    observation_space=metamon.interface.TeamPreviewObservationSpace(),
    tokenizer=metamon.tokenizer.get_tokenizer("DefaultObservationSpace-v1"),
)
"""
Then let's say we finetuned SmallRL to Gen9 with:
python -m metamon.rl.finetune_from_hf \\
--finetune_from_model SmallRL \\
--run_name smallrlfinetune \\
--save_dir ~/metamon_ckpts/ \\
--steps_per_epoch 10000 \\
--epochs 3 \\
--eval_gens 9 \\
--formats gen9ou \\
--log
"""
# Describe the finetuned run produced by the finetune command quoted above.
# It is declared relative to the SmallRL base model it was finetuned from.
MyFinetunedModel = LocalFinetunedModel(
    base_model=SmallRL,
    # The shell expands "~" in `--save_dir ~/metamon_ckpts/`, but Python does
    # not expand a tilde inside a string literal, so resolve it explicitly.
    # NOTE(review): harmless if LocalFinetunedModel already expands the path.
    amago_ckpt_dir=os.path.expanduser("~/metamon_ckpts/"),
    model_name="smallrlfinetune",  # same value as --run_name
    # presumably the checkpoint index loaded when none is requested — TODO confirm
    default_checkpoint=2,
)
# Teams used for the evaluation: the "competitive" set for the gen1ou format.
# A custom set of teams (metamon.env.TeamSet) can be used here instead.
# NOTE(review): the finetune command quoted earlier in this file targets
# gen9ou, while this evaluation uses gen1ou — confirm that is intended.
teams = metamon.env.get_metamon_teams("gen1ou", "competitive")

# Run the finetuned model through pretrained_vs_pokeagent_ladder and keep
# whatever it returns. The format string must match the one used for `teams`.
results = pretrained_vs_pokeagent_ladder(
    pretrained_model=MyFinetunedModel,
    battle_format="gen1ou",
    team_set=teams,
    total_battles=10,
    # Placeholder credentials — presumably to be replaced with your own
    # ladder account before running; avoid committing a real password.
    username="PAC-MyTeamName",
    password="my_password",
)
print(results)