Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/ark/comm/queriable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ def __init__(
):
super().__init__(node_name, session, clock, channel, data_collector)
self._handler = handler
self._queryable = self._session.declare_queryable(self._channel, self._on_query)
self._queryable = self._session.declare_queryable(self._channel,
self._on_query,
complete=False)
print(f"Declared queryable on channel: {self._channel}")

def core_registration(self):
    # TODO: announce this queryable to the ark core registry — not implemented yet,
    # currently only logs a placeholder message.
    print("..todo: register with ark core..")

def _on_query(self, query: zenoh.Query) -> None:
# If we were closed, ignore queries
if not self._active:
print("Received query on closed Queryable, ignoring")
return

try:
# Zenoh query may or may not include a payload.
# For your use-case, the request is always in query.value (bytes)
raw = bytes(query.value) if query.value is not None else b""
raw = bytes(query.payload) if query.payload is not None else b""
if not raw:
print("Received query with no payload, ignoring")
return # nothing to do

req_env = Envelope()
req_env.ParseFromString(raw)

# Decode request protobuf
req_type = msgs.get(req_env.payload_msg_type)
# req_type = msgs.get(req_env.payload_msg_type)
req_type = msgs.get(req_env.msg_type)
if req_type is None:
# Unknown message type: ignore (or reply error later)
print(f"Unknown message type '{req_env.msg_type}' in query, ignoring")
return

req_msg = req_type()
Expand All @@ -60,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None:
resp_env.sent_seq_index = self._seq_index
resp_env.src_node_name = self._node_name
resp_env.channel = self._channel
resp_env.msg_type = resp_msg.DESCRIPTOR.full_name
resp_env.payload = resp_msg.SerializeToString()

self._seq_index += 1

resp_env = Envelope.pack(self._node_name, self._clock, resp_msg)
query.reply(resp_env.SerializeToString())
with query:
query.reply(query.key_expr, resp_env.SerializeToString())

if self._data_collector:
self._data_collector.append(req_env.SerializeToString())
Expand All @@ -73,4 +82,9 @@ def _on_query(self, query: zenoh.Query) -> None:
except Exception:
# Keep it minimal: don't kill the zenoh callback thread
# You can add logging here if desired
print("Error processing query:")
# write the traceback to stdout for debugging
import traceback
traceback.print_exc()

return
21 changes: 12 additions & 9 deletions src/ark/comm/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from google.protobuf.message import Message
from ark.data.data_collector import DataCollector
from ark.comm.end_point import EndPoint
from ark_msgs.registry import msgs


class Querier(EndPoint):
Expand All @@ -11,12 +12,16 @@ def __init__(
self,
node_name: str,
session: zenoh.Session,
query_target,
clock,
channel: str,
data_collector: DataCollector | None,
):
super().__init__(node_name, session, clock, channel, data_collector)
self._querier = self._session.declare_querier(self._channel)
self._querier = self._session.declare_querier(self._channel,
target=query_target)
print(f"Declared querier on channel: {self._channel}")
self._query_selector = zenoh.Selector(self._channel)

def core_registration(self):
    # TODO: announce this querier to the ark core registry — not implemented yet,
    # currently only logs a placeholder message.
    print("..todo: register with ark core..")
Expand Down Expand Up @@ -48,18 +53,21 @@ def query(
else:
raise TypeError("req must be a protobuf Message or bytes")

replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout)
replies = self._querier.get(payload=req_env.SerializeToString())

for reply in replies:
if reply.ok is None:
continue

resp_env = Envelope()
resp_env.ParseFromString(bytes(reply.ok))
resp_env.ParseFromString(bytes(reply.ok.payload))
resp_env.dst_node_name = self._node_name
resp_env.recv_timestamp = self._clock.now()

resp = resp_env.extract_message()
try:
resp = resp_env.extract_message()
except Exception as e:
continue

self._seq_index += 1

Expand All @@ -69,11 +77,6 @@ def query(

return resp

else:
raise TimeoutError(
f"No OK reply received for query on '{self._channel}' within {timeout}s"
)

def close(self):
    """Release base endpoint resources, then undeclare the zenoh querier."""
    super().close()
    self._querier.undeclare()
1 change: 1 addition & 0 deletions src/ark/diff/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ark.diff.variable import Variable
80 changes: 80 additions & 0 deletions src/ark/diff/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from ark_msgs import Value


class Variable:
    """A named scalar wrapped in a torch tensor for cross-node autodiff.

    Two modes:
      * "input"  — holds a leaf tensor (``requires_grad=True``), a timestamped
        value history for replay, and a ``_grads`` cache of
        ``{output_name: d(output)/d(this input)}`` filled in by outputs.
      * "output" — holds a computed tensor; assigning to ``tensor`` triggers an
        eager backward pass that caches gradients into every input variable.
        For each input variable already present in ``variables_registry`` a
        queryable is declared on ``grad/{input_name}/{name}`` serving the
        latest (or replayed) value/grad pair.
    """

    def __init__(self, name, value, mode, variables_registry, lock, clock, create_queryable_fn):
        """Initialize the variable and, for outputs, declare grad queryables.

        Args:
            name: Variable identifier, used in channel names.
            value: Initial scalar value (used only in "input" mode; outputs
                start with no tensor until one is assigned).
            mode: "input" or "output".
            variables_registry: Shared dict of all variables on this node.
            lock: Lock guarding grad reads/writes across threads.
            clock: Provides ``now()`` timestamps for computations.
            create_queryable_fn: Callback declaring a queryable on a channel.
        """
        self.name = name
        self.mode = mode
        self._variables_registry = variables_registry
        self._lock = lock
        self._clock = clock
        self._grads = {}  # input vars: {output_name: grad_value}

        if mode == "input":
            self._tensor = torch.tensor(value, requires_grad=True)
            self._history = {}          # {timestamp: float value} for replay
            self._replay_tensor = None  # last tensor handed out by at()
        else:
            self._tensor = None
            self._computation_ts = clock.now()
            self._replay_fn = None
            for inp_name, inp_var in variables_registry.items():
                if inp_var.mode == "input":
                    grad_channel = f"grad/{inp_name}/{name}"

                    # Factory binds the loop variables so each handler keeps
                    # its own (input, output) pair (avoids late-binding bug).
                    def _make_handler(iv, ov_name, reg, lk):
                        def handler(_req):
                            out_var = reg.get(ov_name)
                            # BUGFIX: guard out_var before dereferencing.
                            # The live-read branch below already tolerated a
                            # missing output var; the replay branch did not
                            # and would raise AttributeError on None.
                            if _req.timestamp != 0 and out_var is not None and out_var._replay_fn:
                                val, grad = out_var._replay_fn(_req.timestamp, iv.name, ov_name)
                                return Value(val=val, grad=grad, timestamp=_req.timestamp)
                            with lk:
                                val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0
                                grad = iv._grads.get(ov_name, 0.0)
                                ts = out_var._computation_ts if out_var else 0
                                return Value(val=val, grad=grad, timestamp=ts)
                        return handler

                    create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock))

    def snapshot(self, ts):
        """Record the current tensor value at clock timestamp ts.

        NOTE(review): only valid for "input" variables — outputs have no
        ``_history``; confirm callers never snapshot outputs.
        """
        self._history[ts] = float(self._tensor.detach())

    def at(self, ts):
        """Return a fresh requires_grad tensor from history at ts.

        Raises:
            KeyError: if no snapshot was recorded at ``ts``.
        """
        val = self._history[ts]
        self._replay_tensor = torch.tensor(val, requires_grad=True)
        return self._replay_tensor

    @property
    def tensor(self):
        """The underlying torch tensor (None for an unset output)."""
        return self._tensor

    @tensor.setter
    def tensor(self, value):
        if self.mode == "output":
            # Assigning an output tensor eagerly backprops into the inputs.
            self._tensor = value
            self._compute_and_store_grads()
        else:
            # In-place data update keeps the leaf tensor (and its grad slot).
            self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value)

    def _is_last_output(self):
        """True if this is the most recently registered output variable.

        Relies on dict insertion order of the shared registry.
        """
        output_names = [k for k, v in self._variables_registry.items() if v.mode == "output"]
        return output_names and output_names[-1] == self.name

    def _compute_and_store_grads(self):
        """Backprop this output and cache a scalar grad into each input var."""
        if self._tensor is None or not self._tensor.requires_grad:
            return
        with self._lock:
            # Zero stale input grads so backward() accumulation starts fresh.
            for var in self._variables_registry.values():
                if var.mode == "input" and var._tensor.grad is not None:
                    var._tensor.grad.zero_()
            # Keep the graph alive until the last registered output has used it.
            self._tensor.backward(retain_graph=not self._is_last_output())
            for var in self._variables_registry.values():
                if var.mode == "input":
                    grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0
                    var._grads[self.name] = grad
            self._computation_ts = self._clock.now()
45 changes: 42 additions & 3 deletions src/ark/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import time
import threading
import torch
import zenoh
from ark.time.clock import Clock
from ark.time.rate import Rate
Expand All @@ -10,6 +12,8 @@
from ark.comm.queriable import Queryable
from ark.data.data_collector import DataCollector
from ark.core.registerable import Registerable
from ark.diff.variable import Variable
from ark_msgs import VariableInfo


class BaseNode(Registerable):
Expand All @@ -22,7 +26,8 @@ def __init__(
sim: bool = False,
collect_data: bool = False,
):
self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
# self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
self._z_cfg = z_cfg
self._session = zenoh.open(self._z_cfg)
self._env_name = env_name
self._node_name = node_name
Expand All @@ -36,6 +41,9 @@ def __init__(
self._subs = {}
self._queriers = {}
self._queriables = {}
self._variables = {}
self._grad_lock = threading.Lock()
self._registry_pub = self.create_publisher("ark/vars/register")

self._session.declare_subscriber(f"{env_name}/reset", self._on_reset)

Expand Down Expand Up @@ -73,17 +81,19 @@ def create_subscriber(self, channel, callback) -> Subscriber:
self._subs[channel] = sub
return sub

def create_querier(self, channel, timeout=10.0) -> Querier:
def create_querier(self, channel, target, timeout=10.0) -> Querier:
    """Create, core-register, and cache a Querier on `channel`.

    NOTE(review): `timeout` is currently unused — it is commented out in the
    Querier constructor call below; confirm whether it should be removed from
    the signature or plumbed through to the querier.
    """
    querier = Querier(
        self._node_name,
        self._session,
        target,
        self._clock,
        channel,
        self._data_collector,
        # timeout,
    )
    querier.core_registration()
    self._queriers[channel] = querier
    # print session and channelinfo for debugging
    return querier

def create_queryable(self, channel, handler) -> Queryable:
Expand All @@ -99,6 +109,35 @@ def create_queryable(self, channel, handler) -> Queryable:
self._queriables[channel] = queryable
return queryable

def create_variable(self, name, value, mode="input"):
    """Declare a differentiable variable on this node.

    Inputs become leaf tensors. Outputs get one queryable per existing
    input variable on "grad/{input_name}/{name}"; assigning their tensor
    triggers an eager backward pass that caches gradients into each input.
    New outputs are also announced on the registry channel.

    Args:
        name: Variable identifier, used in channel names.
        value: Initial scalar value for the underlying tensor.
        mode: "input" or "output".
    """
    variable = Variable(
        name, value, mode, self._variables,
        self._grad_lock, self._clock, self.create_queryable,
    )
    self._variables[name] = variable

    if mode == "output":
        channels = []
        for inp_name, existing in self._variables.items():
            if existing.mode == "input":
                channels.append(f"grad/{inp_name}/{name}")
        info = VariableInfo(
            output_name=name,
            node_name=self._node_name,
            grad_channels=channels,
        )
        self._registry_pub.publish(info)

    return variable

def create_rate(self, hz: float):
rate = Rate(self._clock, hz)
self._rates.append(rate)
Expand Down
68 changes: 65 additions & 3 deletions src/ark/scripts/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,68 @@
import sys
import argparse
import time
import zenoh
from ark.node import BaseNode
from ark_msgs import VariableInfo


class RegistryNode(BaseNode):
    """Central registry: collects VariableInfo announcements published on
    "ark/vars/register" and serves each one back over a per-variable
    queryable at "ark/vars/{output_name}"."""

    def __init__(self, cfg):
        super().__init__("ark", "registry", cfg)
        self._var_registry: dict[str, VariableInfo] = {}
        self.create_subscriber("ark/vars/register", self._on_register)

    def _on_register(self, msg: VariableInfo):
        """Store the announcement; declare its lookup queryable on first sight."""
        name = msg.output_name
        self._var_registry[name] = msg
        channel = f"ark/vars/{name}"
        if channel not in self._queriables:
            # Bind `name` as a default argument so every queryable serves
            # its own registry entry (avoids late-binding closures).
            self.create_queryable(
                channel, lambda _req, _n=name: self._var_registry[_n]
            )
        print(f"Registered output variable '{name}' from node '{msg.node_name}' "
              f"with channels: {list(msg.grad_channels)}")

    def core_registration(self):
        # The registry IS the core; there is nothing upstream to register with.
        pass

    def close(self):
        super().close()


def main():
    """Entry point for the Ark central registry process.

    Parses zenoh connectivity flags (mode / connect / listen endpoints),
    builds a zenoh Config, and spins a RegistryNode until interrupted.
    """
    print(">>Ark core<<")
    print(sys.executable)
    parser = argparse.ArgumentParser(
        prog="ark-core", description="Ark central registry"
    )
    parser.add_argument("--mode", "-m", dest="mode",
                        choices=["peer", "client"], type=str)
    parser.add_argument("--connect", "-e", dest="connect",
                        metavar="ENDPOINT", action="append", type=str)
    parser.add_argument("--listen", "-l", dest="listen",
                        metavar="ENDPOINT", action="append", type=str)
    args = parser.parse_args()

    # zenoh config values must be JSON5 documents; hoisted single import
    # (was re-imported in each of the three branches below).
    import json

    cfg = zenoh.Config()
    if args.mode:
        cfg.insert_json5("mode", json.dumps(args.mode))
    if args.connect:
        cfg.insert_json5("connect/endpoints", json.dumps(args.connect))
    if args.listen:
        cfg.insert_json5("listen/endpoints", json.dumps(args.listen))

    node = RegistryNode(cfg)
    print("Ark registry running.")
    try:
        node.spin()
    except KeyboardInterrupt:
        print("Shutting down registry.")
    finally:
        # BUGFIX: close on every exit path — previously the zenoh session was
        # released only on KeyboardInterrupt, leaking on any other exception
        # or on a normal return from spin().
        node.close()


if __name__ == "__main__":
main()
Loading