Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ocean/nmmo3/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#define ACT_SIZES {26}
#define OBS_TENSOR_T ByteTensor

#define MY_STATE
#define Env MMO
#include "vecenv.h"

Expand All @@ -30,6 +31,55 @@ void my_init(Env* env, Dict* kwargs) {
init(env);
}

#define NMMO3_STATE_COLS 10

// Exports terrain, packed entity rows, and the tick counter for rendering
// and telemetry. Rows are players first, then enemies:
// (kind, r, c, hp, hp_max, comb_lvl, prof_lvl, dir, anim, in_combat),
// kind 0 = player, 1 = enemy. Terrain is allocated once in init and only
// rewritten by c_reset, so it is exported zero-copy; positions and tick
// live in scratch valid until the next my_state call and are copied by
// the binding layer.
int my_state(void* e, StateField* fields, int max_fields) {
if (max_fields < 3) {
return 0;
}
MMO* env = (MMO*)e;
int num_entities = env->num_agents + env->num_enemies;

static int* positions = NULL;
static int positions_cap = 0;
if (num_entities * NMMO3_STATE_COLS > positions_cap) {
positions_cap = num_entities * NMMO3_STATE_COLS;
positions = realloc(positions, positions_cap * sizeof(int));
}
for (int i = 0; i < num_entities; i++) {
int kind = i >= env->num_agents;
Entity* ent = kind ? &env->enemies[i - env->num_agents] : &env->players[i];
int* row = &positions[i * NMMO3_STATE_COLS];
row[0] = kind;
row[1] = ent->r;
row[2] = ent->c;
row[3] = ent->hp;
row[4] = ent->hp_max;
row[5] = ent->comb_lvl;
row[6] = ent->prof_lvl;
row[7] = ent->dir;
row[8] = ent->anim;
row[9] = ent->in_combat;
}

static int tick;
tick = env->tick;

fields[0] = (StateField){"terrain", env->terrain, "int8", 2,
{env->height, env->width}, PUFF_STATE_ZERO_COPY};
fields[1] = (StateField){"positions", positions, "int32", 2,
{num_entities, NMMO3_STATE_COLS}, 0};
fields[2] = (StateField){"tick", &tick, "int32", 1, {1}, 0};
return 3;
}

void my_log(Log* log, Dict* out) {
dict_set(out, "perf", log->perf);
dict_set(out, "score", log->score);
Expand Down
40 changes: 40 additions & 0 deletions src/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,45 @@ py::dict vec_log(VecEnv& ve) {
return result;
}

static size_t state_dtype_size(const char* dtype) {
if (!strcmp(dtype, "int8") || !strcmp(dtype, "uint8")) return 1;
if (!strcmp(dtype, "int16") || !strcmp(dtype, "uint16")) return 2;
if (!strcmp(dtype, "int32") || !strcmp(dtype, "uint32") || !strcmp(dtype, "float32")) return 4;
if (!strcmp(dtype, "int64") || !strcmp(dtype, "uint64") || !strcmp(dtype, "float64")) return 8;
throw std::runtime_error(std::string("my_state: unknown dtype ") + dtype);
}

// Snapshot of env-exported state (my_state hook). Fields are copied into
// Python-owned bytes unless the env flags them PUFF_STATE_ZERO_COPY, in which
// case a read-only view of the C buffer is returned (invalidated by close()).
// Returns an empty dict for envs that do not implement the hook.
py::dict vec_state(VecEnv& ve, int env_id) {
if (env_id < 0 || env_id >= ve.vec->size)
throw std::runtime_error("state: env_id out of range");
StateField fields[PUFF_MAX_STATE_FIELDS];
int n = my_state(static_vec_env_at(ve.vec, env_id), fields, PUFF_MAX_STATE_FIELDS);
py::dict result;
for (int i = 0; i < n; i++) {
size_t count = 1;
py::tuple shape(fields[i].ndim);
for (int d = 0; d < fields[i].ndim; d++) {
shape[d] = fields[i].dims[d];
count *= (size_t)fields[i].dims[d];
}
size_t nbytes = count * state_dtype_size(fields[i].dtype);
py::dict entry;
if (fields[i].flags & PUFF_STATE_ZERO_COPY) {
entry["data"] = py::memoryview::from_memory(fields[i].data, (py::ssize_t)nbytes);
} else {
entry["data"] = py::bytes((const char*)fields[i].data, nbytes);
}
entry["dtype"] = fields[i].dtype;
entry["shape"] = shape;
result[fields[i].name] = entry;
}
return result;
}

void vec_close(VecEnv& ve) {
static_vec_close(ve.vec);
ve.vec = nullptr;
Expand Down Expand Up @@ -614,6 +653,7 @@ PYBIND11_MODULE(_C, m) {
.def("gpu_step", &gpu_vec_step_py)
.def("cpu_step", &cpu_vec_step_py)
.def("render", [](VecEnv& ve, int env_id) { static_vec_render(ve.vec, env_id); })
.def("state", &vec_state, py::arg("env_id") = 0)
.def("log", &vec_log)
.def("close", &vec_close);

Expand Down
40 changes: 40 additions & 0 deletions src/bindings_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,45 @@ static py::dict vec_log(VecEnv& ve) {
return result;
}

static size_t state_dtype_size(const char* dtype) {
if (!strcmp(dtype, "int8") || !strcmp(dtype, "uint8")) return 1;
if (!strcmp(dtype, "int16") || !strcmp(dtype, "uint16")) return 2;
if (!strcmp(dtype, "int32") || !strcmp(dtype, "uint32") || !strcmp(dtype, "float32")) return 4;
if (!strcmp(dtype, "int64") || !strcmp(dtype, "uint64") || !strcmp(dtype, "float64")) return 8;
throw std::runtime_error(std::string("my_state: unknown dtype ") + dtype);
}

// Snapshot of env-exported state (my_state hook). Fields are copied into
// Python-owned bytes unless the env flags them PUFF_STATE_ZERO_COPY, in which
// case a read-only view of the C buffer is returned (invalidated by close()).
// Returns an empty dict for envs that do not implement the hook.
static py::dict vec_state(VecEnv& ve, int env_id) {
if (env_id < 0 || env_id >= ve.vec->size)
throw std::runtime_error("state: env_id out of range");
StateField fields[PUFF_MAX_STATE_FIELDS];
int n = my_state(static_vec_env_at(ve.vec, env_id), fields, PUFF_MAX_STATE_FIELDS);
py::dict result;
for (int i = 0; i < n; i++) {
size_t count = 1;
py::tuple shape(fields[i].ndim);
for (int d = 0; d < fields[i].ndim; d++) {
shape[d] = fields[i].dims[d];
count *= (size_t)fields[i].dims[d];
}
size_t nbytes = count * state_dtype_size(fields[i].dtype);
py::dict entry;
if (fields[i].flags & PUFF_STATE_ZERO_COPY) {
entry["data"] = py::memoryview::from_memory(fields[i].data, (py::ssize_t)nbytes);
} else {
entry["data"] = py::bytes((const char*)fields[i].data, nbytes);
}
entry["dtype"] = fields[i].dtype;
entry["shape"] = shape;
result[fields[i].name] = entry;
}
return result;
}

static void vec_close(VecEnv& ve) {
static_vec_close(ve.vec);
ve.vec = nullptr;
Expand Down Expand Up @@ -182,6 +221,7 @@ PYBIND11_MODULE(_C, m) {
.def("reset", &vec_reset)
.def("cpu_step", &cpu_vec_step_py)
.def("render", [](VecEnv& ve, int env_id) { static_vec_render(ve.vec, env_id); })
.def("state", &vec_state, py::arg("env_id") = 0)
.def("log", &vec_log)
.def("close", &vec_close);
}
40 changes: 40 additions & 0 deletions src/vecenv.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,34 @@ void my_shared_close(void* env);
void* my_get(void* env, Dict* out);
int my_put(void* env, Dict* kwargs);

// Optional structured state export for rendering/telemetry. An env opts in by
// defining MY_STATE before including vecenv.h and implementing my_state to fill
// `fields` with up to `max_fields` entries, returning the count. By default a
// field's data pointer must remain valid until the next my_state call on the
// same env; the Python binding copies immediately, so per-binding static
// scratch buffers are fine. Fields whose buffer is allocated once and whose
// pointer stays valid for the env's lifetime may set PUFF_STATE_ZERO_COPY in
// flags; the binding then returns a read-only view instead of a copy (views
// are invalidated by close). Must not be called concurrently with env
// stepping. The default implementation exports nothing.
typedef struct StateField {
const char* name;
const void* data;
const char* dtype; // numpy-style: "int8", "uint8", "int32", "float32", ...
int ndim;
int dims[4];
int flags;
} StateField;

#define PUFF_MAX_STATE_FIELDS 16
#define PUFF_STATE_ZERO_COPY 1

int my_state(void* env, StateField* fields, int max_fields);

// Address of env `env_id` inside vec->envs. Implemented in the env's
// translation unit, where sizeof(Env) is known.
void* static_vec_env_at(StaticVec* vec, int env_id);

#ifdef __cplusplus
}
#endif
Expand Down Expand Up @@ -738,6 +766,11 @@ int get_num_act_sizes(void) { return (int)(sizeof(_act_sizes) / sizeof(_act_size
const char* get_obs_dtype(void) { return dtype_symbol; }
size_t get_obs_elem_size(void) { return obs_element_size(); }

void* static_vec_env_at(StaticVec* vec, int env_id) {
assert(env_id >= 0 && env_id < vec->size);
return &((Env*)vec->envs)[env_id];
}

static inline void _static_vec_env_step(StaticVec* vec) {
memset(vec->rewards, 0, vec->total_agents * sizeof(float));
memset(vec->terminals, 0, vec->total_agents * sizeof(float));
Expand Down Expand Up @@ -796,4 +829,11 @@ int my_put(void* env, Dict* kwargs) {
}
#endif

#ifndef MY_STATE
int my_state(void* env, StateField* fields, int max_fields) {
(void)env; (void)fields; (void)max_fields;
return 0;
}
#endif

#endif // OBS_SIZE