Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ inactive_agent_threshold = 0.4
init_step = 0
; options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only"
control_mode = "control_vehicles"
; Controller used by agent 0, the canonical SDC/target.
; options: "static", "policy", "replay", "idm"
sdc_controller = "policy"
; Controller used by non-SDC vehicles.
; options: "static", "policy", "replay", "idm"
non_sdc_controller = "policy"
; Controller used by non-vehicle agents. "auto" uses the same value as non_sdc_controller, except that "idm" is replaced by "replay".
; options: "auto", "static", "policy", "replay", "idm"
non_vehicle_controller = "auto"
Comment on lines +73 to +79
; Options: "create_all_valid", "create_only_controlled"
init_mode = "create_all_valid"
; Enable computation of evaluation-only metrics
Expand Down
23 changes: 23 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,20 @@ static PyObject *my_get(PyObject *dict, Env *env) {
}
Py_DECREF(tmp);

tmp = PyLong_FromLong(a->controller);
if (!tmp) {
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}
if (PyDict_SetItemString(agent, "controller", tmp) < 0) {
Py_DECREF(tmp);
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}
Py_DECREF(tmp);

tmp = PyLong_FromLong(a->mark_as_expert);
if (!tmp) {
Py_DECREF(agent);
Expand Down Expand Up @@ -1714,6 +1728,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
int s_map_counter = starting_map_counter;
int init_mode = unpack(kwargs, "init_mode");
int control_mode = unpack(kwargs, "control_mode");
int sdc_controller = unpack(kwargs, "sdc_controller");
int non_sdc_controller = unpack(kwargs, "non_sdc_controller");
int non_vehicle_controller = unpack(kwargs, "non_vehicle_controller");
int simulation_mode = unpack(kwargs, "simulation_mode");
int init_step = unpack(kwargs, "init_step");
int seed = unpack(kwargs, "seed");
Expand Down Expand Up @@ -1850,6 +1867,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
Drive *env = calloc(1, sizeof(Drive));
env->init_mode = init_mode;
env->control_mode = control_mode;
env->sdc_controller = sdc_controller;
env->non_sdc_controller = non_sdc_controller;
env->non_vehicle_controller = non_vehicle_controller;
env->simulation_mode = simulation_mode;
env->init_step = init_step;
env->num_max_agents = max_agents_per_env;
Expand Down Expand Up @@ -1974,6 +1994,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->timestep = init_step;
env->init_mode = (int) unpack(kwargs, "init_mode");
env->control_mode = (int) unpack(kwargs, "control_mode");
env->sdc_controller = (int) unpack(kwargs, "sdc_controller");
env->non_sdc_controller = (int) unpack(kwargs, "non_sdc_controller");
env->non_vehicle_controller = (int) unpack(kwargs, "non_vehicle_controller");
env->simulation_mode = (int) unpack(kwargs, "simulation_mode");
env->reward_conditioning = (bool) unpack(kwargs, "reward_conditioning");
env->reward_randomization = (bool) unpack(kwargs, "reward_randomization");
Expand Down
1 change: 1 addition & 0 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct Agent {
int num_goals_reached;
int active_agent;
int mark_as_expert;
int controller;
float cumulative_displacement;
int displacement_sample_count;
float path_progression;
Expand Down
77 changes: 69 additions & 8 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@
#define CONTROL_WOSAC 2
#define CONTROL_SDC_ONLY 3

// Controller modes: the per-agent motion source stored in Agent.controller
// and dispatched on by move_agent_with_controller().
// STATIC: the agent is not moved during the step.
#define CONTROLLER_STATIC 0
// POLICY: the agent is moved by move_dynamics() when it has an action index (>= 0).
#define CONTROLLER_POLICY 1
// REPLAY: the agent is moved by move_expert(), only when simulation_mode is SIMULATION_REPLAY.
#define CONTROLLER_REPLAY 2
// IDM: the agent is moved by move_idm().
#define CONTROLLER_IDM 3

// Simulation modes
#define SIMULATION_GIGAFLOW 0
#define SIMULATION_REPLAY 1
Expand Down Expand Up @@ -419,6 +425,9 @@ struct Drive {
int *tracks_to_predict;
int init_mode;
int control_mode;
int sdc_controller;
int non_sdc_controller;
int non_vehicle_controller;
int simulation_mode;
int termination_mode;
float inactive_agent_threshold;
Expand Down Expand Up @@ -3366,6 +3375,28 @@ static bool should_control_agent(Drive *env, int agent_idx) {
return agent->route_length != 0;
}

// Choose the CONTROLLER_* value for the agent at agent_idx.
//
//   env               - environment holding the agents and the configured
//                       sdc / non_sdc / non_vehicle controllers
//   agent_idx         - index into env->agents
//   is_active         - non-zero when the caller treats the agent as active
//   replay_by_default - non-zero forces CONTROLLER_REPLAY, ignoring the configuration
//
// Returns the configured controller for the agent's role (EGO_IDX, VEHICLE,
// or anything else), except that a CONTROLLER_POLICY request for an agent
// that is not active is turned into CONTROLLER_STATIC.
static int resolve_agent_controller(Drive *env, int agent_idx, int is_active, int replay_by_default) {
    int controller;

    // A replay default wins over every configured controller.
    if (replay_by_default) {
        return CONTROLLER_REPLAY;
    }

    if (agent_idx == EGO_IDX) {
        controller = env->sdc_controller;
    } else {
        const Agent *candidate = &env->agents[agent_idx];
        controller = (candidate->type == VEHICLE) ? env->non_sdc_controller : env->non_vehicle_controller;
    }

    // Policy control was requested for an agent that is not active: keep it static.
    const int policy_but_inactive = (controller == CONTROLLER_POLICY) && !is_active;
    return policy_but_inactive ? CONTROLLER_STATIC : controller;
}

void set_active_agents(Drive *env) {
// Initialize
env->active_agent_count = 0; // Policy-controlled agents
Expand Down Expand Up @@ -3400,6 +3431,8 @@ void set_active_agents(Drive *env) {

for (int i = 0; i < successfully_created; i++) {
env->active_agent_indices[i] = active_agent_indices[i];
env->agents[active_agent_indices[i]].controller
= resolve_agent_controller(env, active_agent_indices[i], 1, 0);
}
free(active_agent_indices);

Expand Down Expand Up @@ -3452,12 +3485,15 @@ void set_active_agents(Drive *env) {
active_agent_indices[env->active_agent_count] = i;
env->active_agent_count++;
env->agents[i].active_agent = 1;
env->agents[i].controller = resolve_agent_controller(env, i, 1, 0);
} else if (is_log_replay || env->init_mode != INIT_ONLY_CONTROLLABLE_AGENTS) {
// In log-replay mode, all non-controlled agents become expert_static
static_agent_indices[env->static_agent_count] = i;
env->static_agent_count++;
env->agents[i].active_agent = 0;
if (is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents) {
int replay_by_default
= is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents;
env->agents[i].controller = resolve_agent_controller(env, i, 0, replay_by_default);
if (env->agents[i].controller == CONTROLLER_REPLAY) {
expert_static_agent_indices[env->expert_static_agent_count] = i;
env->expert_static_agent_count++;
env->agents[i].mark_as_expert = 1;
Expand Down Expand Up @@ -5147,6 +5183,32 @@ static void move_dynamics(Drive *env, int action_idx, int agent_idx) {
return;
}

#include "idm.h"

// Advance one agent for the current step according to its controller.
//
//   env        - environment owning the agent, the actions and simulation_mode
//   action_idx - index passed to move_dynamics() for policy agents; a negative
//                value means no policy action is available
//   agent_idx  - index into env->agents
//
// Static agents, and any controller value not listed below, are left untouched.
static void move_agent_with_controller(Drive *env, int action_idx, int agent_idx) {
    switch (env->agents[agent_idx].controller) {
    case CONTROLLER_IDM:
        move_idm(env, agent_idx);
        break;

    case CONTROLLER_REPLAY:
        // NOTE(review): replay agents stay in place outside SIMULATION_REPLAY —
        // confirm this is intended.
        if (env->simulation_mode == SIMULATION_REPLAY) {
            move_expert(env, env->actions, agent_idx);
        }
        break;

    case CONTROLLER_POLICY:
        // Policy agents only move when they have an action.
        if (action_idx >= 0) {
            move_dynamics(env, action_idx, agent_idx);
        }
        break;

    case CONTROLLER_STATIC:
    default:
        break;
    }
}

Comment on lines +5186 to +5211
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can directly add this code where it is used.

static inline void sample_erratic_flags(Drive *env, Agent *agent) {
agent->is_blind_partner
= (env->partner_blindness_prob > 0.0f && random_uniform(0.0f, 1.0f) < env->partner_blindness_prob) ? 1 : 0;
Expand Down Expand Up @@ -5277,18 +5339,17 @@ void c_step(Drive *env) {
env->timestep++;

// -> 1. Apply actions and move agents
// Move static experts
for (int i = 0; i < env->expert_static_agent_count; i++) {
int expert_idx = env->expert_static_agent_indices[i];
move_expert(env, env->actions, expert_idx);
// Move background agents according to their per-agent controller.
for (int i = 0; i < env->static_agent_count; i++) {
int background_idx = env->static_agent_indices[i];
Comment on lines +5343 to +5344
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep expert_static_agent_count; you don't want to move parked cars.

move_agent_with_controller(env, -1, background_idx);
}
// Move active agents with policy actions
for (int i = 0; i < env->active_agent_count; i++) {
env->logs[i].score = 0.0f;
env->logs[i].episode_length += 1;
int agent_idx = env->active_agent_indices[i];
move_dynamics(env, i, agent_idx);
// move_expert(env, env->actions, agent_idx);
move_agent_with_controller(env, i, agent_idx);
}

// -> 2. Compute metrics and rewards
Expand Down
43 changes: 43 additions & 0 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(
num_eval_scenarios=16,
init_mode="create_all_valid",
control_mode="control_vehicles",
sdc_controller="policy",
non_sdc_controller="policy",
non_vehicle_controller="auto",
map_dir=None,
target_type="static",
goal_on_lane=True,
Expand Down Expand Up @@ -236,6 +239,9 @@ def __init__(
self.init_step = init_step
self.init_mode_str = init_mode
self.control_mode_str = control_mode
self.sdc_controller_str = sdc_controller
self.non_sdc_controller_str = non_sdc_controller
self.non_vehicle_controller_str = non_vehicle_controller
self.simulation_mode_str = simulation_mode
self.map_dir = map_dir
# map_dir may point either at a directory containing .bin files or at
Expand Down Expand Up @@ -265,6 +271,34 @@ def __init__(
"control_mode must be one of 'control_vehicles', 'control_agents', 'control_wosac', or "
f"'control_sdc_only'. Got: {self.control_mode_str}"
)

controller_values = {
"static": binding.CONTROLLER_STATIC,
"policy": binding.CONTROLLER_POLICY,
"replay": binding.CONTROLLER_REPLAY,
"idm": binding.CONTROLLER_IDM,
}
controller_options = "'static', 'policy', 'replay', or 'idm'"
if self.sdc_controller_str not in controller_values:
raise ValueError(f"sdc_controller must be one of {controller_options}. Got: {self.sdc_controller_str}")
if self.non_sdc_controller_str not in controller_values:
raise ValueError(
f"non_sdc_controller must be one of {controller_options}. Got: {self.non_sdc_controller_str}"
)
if self.non_vehicle_controller_str == "auto":
if self.non_sdc_controller_str == "idm":
self.non_vehicle_controller_str = "replay"
else:
self.non_vehicle_controller_str = self.non_sdc_controller_str
elif self.non_vehicle_controller_str not in controller_values:
raise ValueError(
f"non_vehicle_controller must be 'auto' or one of {controller_options}. "
f"Got: {self.non_vehicle_controller_str}"
)
self.sdc_controller = controller_values[self.sdc_controller_str]
self.non_sdc_controller = controller_values[self.non_sdc_controller_str]
self.non_vehicle_controller = controller_values[self.non_vehicle_controller_str]
Comment on lines +275 to +300

if self.init_mode_str == "create_all_valid":
self.init_mode = 0
elif self.init_mode_str == "create_only_controlled":
Expand Down Expand Up @@ -322,6 +356,9 @@ def __init__(
eval_mode=self.eval_mode,
init_mode=self.init_mode,
control_mode=self.control_mode,
sdc_controller=self.sdc_controller,
non_sdc_controller=self.non_sdc_controller,
non_vehicle_controller=self.non_vehicle_controller,
simulation_mode=self.simulation_mode,
init_step=self.init_step,
seed=self.random_seed,
Expand Down Expand Up @@ -413,6 +450,9 @@ def _env_init_kwargs(self, map_file, max_agents):
"init_step": self.init_step,
"init_mode": self.init_mode,
"control_mode": self.control_mode,
"sdc_controller": self.sdc_controller,
"non_sdc_controller": self.non_sdc_controller,
"non_vehicle_controller": self.non_vehicle_controller,
"simulation_mode": self.simulation_mode,
"reward_conditioning": self.reward_conditioning,
"reward_randomization": self.reward_randomization,
Expand Down Expand Up @@ -502,6 +542,9 @@ def step(self, actions):
eval_mode=self.eval_mode,
init_mode=self.init_mode,
control_mode=self.control_mode,
sdc_controller=self.sdc_controller,
non_sdc_controller=self.non_sdc_controller,
non_vehicle_controller=self.non_vehicle_controller,
simulation_mode=self.simulation_mode,
init_step=self.init_step,
map_files=self.map_files,
Expand Down
Loading
Loading