diff --git a/ocean/nmmo3/binding.c b/ocean/nmmo3/binding.c index b2d5fb4100..62c2ab0c49 100644 --- a/ocean/nmmo3/binding.c +++ b/ocean/nmmo3/binding.c @@ -4,6 +4,7 @@ #define ACT_SIZES {26} #define OBS_TENSOR_T ByteTensor +#define MY_STATE #define Env MMO #include "vecenv.h" @@ -30,6 +31,55 @@ void my_init(Env* env, Dict* kwargs) { init(env); } +#define NMMO3_STATE_COLS 10 + +// Exports terrain, packed entity rows, and the tick counter for rendering +// and telemetry. Rows are players first, then enemies: +// (kind, r, c, hp, hp_max, comb_lvl, prof_lvl, dir, anim, in_combat), +// kind 0 = player, 1 = enemy. Terrain is allocated once in init and only +// rewritten by c_reset, so it is exported zero-copy; positions and tick +// live in scratch valid until the next my_state call and are copied by +// the binding layer. +int my_state(void* e, StateField* fields, int max_fields) { + if (max_fields < 3) { + return 0; + } + MMO* env = (MMO*)e; + int num_entities = env->num_agents + env->num_enemies; + + static int* positions = NULL; + static int positions_cap = 0; + if (num_entities * NMMO3_STATE_COLS > positions_cap) { + positions_cap = num_entities * NMMO3_STATE_COLS; + positions = realloc(positions, positions_cap * sizeof(int)); + } + for (int i = 0; i < num_entities; i++) { + int kind = i >= env->num_agents; + Entity* ent = kind ? &env->enemies[i - env->num_agents] : &env->players[i]; + int* row = &positions[i * NMMO3_STATE_COLS]; + row[0] = kind; + row[1] = ent->r; + row[2] = ent->c; + row[3] = ent->hp; + row[4] = ent->hp_max; + row[5] = ent->comb_lvl; + row[6] = ent->prof_lvl; + row[7] = ent->dir; + row[8] = ent->anim; + row[9] = ent->in_combat; + } + + static int tick; + tick = env->tick; + + fields[0] = (StateField){"terrain", env->terrain, "int8", 2, + {env->height, env->width}, PUFF_STATE_ZERO_COPY}; + fields[1] = (StateField){"positions", positions, "int32", 2, + {num_entities, NMMO3_STATE_COLS}, 0}; + fields[2] = (StateField){"tick", &tick, "int32", 1, {1}, 0}; + return 3; +} + void my_log(Log* log, Dict* out) { dict_set(out, "perf", log->perf); dict_set(out, "score", log->score); diff --git a/src/bindings.cu b/src/bindings.cu index 64be61194d..cf20d605ce 100644 --- a/src/bindings.cu +++ b/src/bindings.cu @@ -381,6 +381,45 @@ py::dict vec_log(VecEnv& ve) { return result; } +static size_t state_dtype_size(const char* dtype) { + if (!strcmp(dtype, "int8") || !strcmp(dtype, "uint8")) return 1; + if (!strcmp(dtype, "int16") || !strcmp(dtype, "uint16")) return 2; + if (!strcmp(dtype, "int32") || !strcmp(dtype, "uint32") || !strcmp(dtype, "float32")) return 4; + if (!strcmp(dtype, "int64") || !strcmp(dtype, "uint64") || !strcmp(dtype, "float64")) return 8; + throw std::runtime_error(std::string("my_state: unknown dtype ") + dtype); +} + +// Snapshot of env-exported state (my_state hook). Fields are copied into +// Python-owned bytes unless the env flags them PUFF_STATE_ZERO_COPY, in which +// case a read-only view of the C buffer is returned (invalidated by close()). +// Returns an empty dict for envs that do not implement the hook. +py::dict vec_state(VecEnv& ve, int env_id) { + if (env_id < 0 || env_id >= ve.vec->size) + throw std::runtime_error("state: env_id out of range"); + StateField fields[PUFF_MAX_STATE_FIELDS]; + int n = my_state(static_vec_env_at(ve.vec, env_id), fields, PUFF_MAX_STATE_FIELDS); + py::dict result; + for (int i = 0; i < n; i++) { + size_t count = 1; + py::tuple shape(fields[i].ndim); + for (int d = 0; d < fields[i].ndim; d++) { + shape[d] = fields[i].dims[d]; + count *= (size_t)fields[i].dims[d]; + } + size_t nbytes = count * state_dtype_size(fields[i].dtype); + py::dict entry; + if (fields[i].flags & PUFF_STATE_ZERO_COPY) { + entry["data"] = py::memoryview::from_memory(fields[i].data, (py::ssize_t)nbytes); + } else { + entry["data"] = py::bytes((const char*)fields[i].data, nbytes); + } + entry["dtype"] = fields[i].dtype; + entry["shape"] = shape; + result[fields[i].name] = entry; + } + return result; +} + void vec_close(VecEnv& ve) { static_vec_close(ve.vec); ve.vec = nullptr; @@ -614,6 +653,7 @@ PYBIND11_MODULE(_C, m) { .def("gpu_step", &gpu_vec_step_py) .def("cpu_step", &cpu_vec_step_py) .def("render", [](VecEnv& ve, int env_id) { static_vec_render(ve.vec, env_id); }) + .def("state", &vec_state, py::arg("env_id") = 0) .def("log", &vec_log) .def("close", &vec_close); diff --git a/src/bindings_cpu.cpp b/src/bindings_cpu.cpp index 5ba4dc81e5..da1c172297 100644 --- a/src/bindings_cpu.cpp +++ b/src/bindings_cpu.cpp @@ -151,6 +151,45 @@ static py::dict vec_log(VecEnv& ve) { return result; } +static size_t state_dtype_size(const char* dtype) { + if (!strcmp(dtype, "int8") || !strcmp(dtype, "uint8")) return 1; + if (!strcmp(dtype, "int16") || !strcmp(dtype, "uint16")) return 2; + if (!strcmp(dtype, "int32") || !strcmp(dtype, "uint32") || !strcmp(dtype, "float32")) return 4; + if (!strcmp(dtype, "int64") || !strcmp(dtype, "uint64") || !strcmp(dtype, "float64")) return 8; + throw std::runtime_error(std::string("my_state: unknown dtype ") + dtype); +} + +// Snapshot of env-exported state (my_state hook). Fields are copied into +// Python-owned bytes unless the env flags them PUFF_STATE_ZERO_COPY, in which +// case a read-only view of the C buffer is returned (invalidated by close()). +// Returns an empty dict for envs that do not implement the hook. +static py::dict vec_state(VecEnv& ve, int env_id) { + if (env_id < 0 || env_id >= ve.vec->size) + throw std::runtime_error("state: env_id out of range"); + StateField fields[PUFF_MAX_STATE_FIELDS]; + int n = my_state(static_vec_env_at(ve.vec, env_id), fields, PUFF_MAX_STATE_FIELDS); + py::dict result; + for (int i = 0; i < n; i++) { + size_t count = 1; + py::tuple shape(fields[i].ndim); + for (int d = 0; d < fields[i].ndim; d++) { + shape[d] = fields[i].dims[d]; + count *= (size_t)fields[i].dims[d]; + } + size_t nbytes = count * state_dtype_size(fields[i].dtype); + py::dict entry; + if (fields[i].flags & PUFF_STATE_ZERO_COPY) { + entry["data"] = py::memoryview::from_memory(fields[i].data, (py::ssize_t)nbytes); + } else { + entry["data"] = py::bytes((const char*)fields[i].data, nbytes); + } + entry["dtype"] = fields[i].dtype; + entry["shape"] = shape; + result[fields[i].name] = entry; + } + return result; +} + static void vec_close(VecEnv& ve) { static_vec_close(ve.vec); ve.vec = nullptr; @@ -182,6 +221,7 @@ PYBIND11_MODULE(_C, m) { .def("reset", &vec_reset) .def("cpu_step", &cpu_vec_step_py) .def("render", [](VecEnv& ve, int env_id) { static_vec_render(ve.vec, env_id); }) + .def("state", &vec_state, py::arg("env_id") = 0) .def("log", &vec_log) .def("close", &vec_close); } diff --git a/src/vecenv.h b/src/vecenv.h index 42958d321e..37295f3f67 100644 --- a/src/vecenv.h +++ b/src/vecenv.h @@ -153,6 +153,34 @@ void my_shared_close(void* env); void* my_get(void* env, Dict* out); int my_put(void* env, Dict* kwargs); +// Optional structured state export for rendering/telemetry. An env opts in by +// defining MY_STATE before including vecenv.h and implementing my_state to fill +// `fields` with up to `max_fields` entries, returning the count. By default a +// field's data pointer must remain valid until the next my_state call on the +// same env; the Python binding copies immediately, so per-binding static +// scratch buffers are fine. Fields whose buffer is allocated once and whose +// pointer stays valid for the env's lifetime may set PUFF_STATE_ZERO_COPY in +// flags; the binding then returns a read-only view instead of a copy (views +// are invalidated by close). Must not be called concurrently with env +// stepping. The default implementation exports nothing. +typedef struct StateField { + const char* name; + const void* data; + const char* dtype; // numpy-style: "int8", "uint8", "int32", "float32", ... + int ndim; + int dims[4]; + int flags; +} StateField; + +#define PUFF_MAX_STATE_FIELDS 16 +#define PUFF_STATE_ZERO_COPY 1 + +int my_state(void* env, StateField* fields, int max_fields); + +// Address of env `env_id` inside vec->envs. Implemented in the env's +// translation unit, where sizeof(Env) is known. +void* static_vec_env_at(StaticVec* vec, int env_id); + #ifdef __cplusplus } #endif @@ -738,6 +766,11 @@ int get_num_act_sizes(void) { return (int)(sizeof(_act_sizes) / sizeof(_act_size const char* get_obs_dtype(void) { return dtype_symbol; } size_t get_obs_elem_size(void) { return obs_element_size(); } +void* static_vec_env_at(StaticVec* vec, int env_id) { + assert(env_id >= 0 && env_id < vec->size); + return &((Env*)vec->envs)[env_id]; +} + static inline void _static_vec_env_step(StaticVec* vec) { memset(vec->rewards, 0, vec->total_agents * sizeof(float)); memset(vec->terminals, 0, vec->total_agents * sizeof(float)); @@ -796,4 +829,11 @@ int my_put(void* env, Dict* kwargs) { } #endif +#ifndef MY_STATE +int my_state(void* env, StateField* fields, int max_fields) { + (void)env; (void)fields; (void)max_fields; + return 0; +} +#endif + #endif // OBS_SIZE