From 0d9923fa242e0771d1f30363b73e5c63d0896bf9 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Mon, 9 Feb 2026 18:03:12 +0000 Subject: [PATCH 01/19] adds diff test code --- test/diff_pub.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/diff_pub.py diff --git a/test/diff_pub.py b/test/diff_pub.py new file mode 100644 index 0000000..40ff663 --- /dev/null +++ b/test/diff_pub.py @@ -0,0 +1,27 @@ +from ark.node import BaseNode +from itertools import count +from common import z_cfg + + +class PublisherNode(BaseNode): + + def __init__(self): + super().__init__("env", "pub", z_cfg, sim=True) + self.pub = self.create_publisher("diff_sim") + self.rate = self.create_rate(1) # 1 Hz + + def spin(self): + for c in count(): + msg = f"Hello World {c}" + self.pub.publish(msg.encode("utf-8")) + print(f"Published: {msg}") + self.rate.sleep() + + +if __name__ == "__main__": + try: + node = PublisherNode() + node.spin() + except KeyboardInterrupt: + print("Shutting down publisher node.") + node.close() From 5214e1c4dac385c52778ef4ea002aa955d63c2cd Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 10 Feb 2026 14:17:14 +0000 Subject: [PATCH 02/19] demo code that visualises translation and its derivative --- test/diff_pub.py | 27 ---------------------- test/diff_publisher.py | 35 +++++++++++++++++++++++++++++ test/plotter_subsriber.py | 47 +++++++++++++++++++++++++++++++++++++++ test/simstep.py | 18 +++++++++++++++ 4 files changed, 100 insertions(+), 27 deletions(-) delete mode 100644 test/diff_pub.py create mode 100644 test/diff_publisher.py create mode 100644 test/plotter_subsriber.py create mode 100644 test/simstep.py diff --git a/test/diff_pub.py b/test/diff_pub.py deleted file mode 100644 index 40ff663..0000000 --- a/test/diff_pub.py +++ /dev/null @@ -1,27 +0,0 @@ -from ark.node import BaseNode -from itertools import count -from common import z_cfg - - -class PublisherNode(BaseNode): - - def __init__(self): - super().__init__("env", "pub", z_cfg, sim=True) - self.pub = self.create_publisher("diff_sim") - self.rate = self.create_rate(1) # 1 Hz - - def spin(self): - for c in count(): - msg = f"Hello World {c}" - self.pub.publish(msg.encode("utf-8")) - print(f"Published: {msg}") - self.rate.sleep() - - -if __name__ == "__main__": - try: - node = PublisherNode() - node.spin() - except KeyboardInterrupt: - print("Shutting down publisher node.") - node.close() diff --git a/test/diff_publisher.py b/test/diff_publisher.py new file mode 100644 index 0000000..5bb1a04 --- /dev/null +++ b/test/diff_publisher.py @@ -0,0 +1,35 @@ +import math +import time +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +# Lissajous parameters +A, B = 1.0, 1.0 +a, b = 3.0, 2.0 +delta = math.pi / 2 +HZ = 50 +DT = 1.0 / HZ +class DiffPublisherNode(BaseNode): + def __init__(self): + super().__init__("env", "diff_pub", z_cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.vel_pub = self.create_publisher("velocity") + self.rate = self.create_rate(HZ) + def spin(self): + t = 0.0 + while True: + x = A * math.sin(a * t + delta) + y = B * math.sin(b * t) + dx = A * a * math.cos(a * t + delta) + dy = B * b * math.cos(b * t) + self.pos_pub.publish(Translation(x=x, y=y, z=0.0)) + self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) + t += DT + self.rate.sleep() +if __name__ == "__main__": + try: + node = DiffPublisherNode() + node.spin() + except KeyboardInterrupt: + print("Shutting down diff publisher.") + node.close() diff --git 
a/test/plotter_subsriber.py b/test/plotter_subsriber.py new file mode 100644 index 0000000..2477962 --- /dev/null +++ b/test/plotter_subsriber.py @@ -0,0 +1,47 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +class SubscriberPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.vel_x, self.vel_y = [], [] + self.create_subscriber("position", self.on_position) + self.create_subscriber("velocity", self.on_velocity) + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def on_velocity(self, msg: dTranslation): + self.vel_x.append(msg.x) + self.vel_y.append(msg.y) +def main(): + node = SubscriberPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_vel) = plt.subplots(1, 2, figsize=(10, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-1.5, 1.5) + ax_pos.set_ylim(-1.5, 1.5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_vel.set_title("Velocity (dTranslation)") + ax_vel.set_xlabel("dx") + ax_vel.set_ylabel("dy") + ax_vel.set_xlim(-5, 5) + ax_vel.set_ylim(-5, 5) + ax_vel.set_aspect("equal") + (line_vel,) = ax_vel.plot([], [], "r-") + def update(frame): + line_pos.set_data(node.pos_x, node.pos_y) + line_vel.set_data(node.vel_x, node.vel_y) + return line_pos, line_vel + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/simstep.py b/test/simstep.py new file mode 100644 index 0000000..06da237 --- /dev/null +++ b/test/simstep.py @@ -0,0 +1,18 @@ +from ark.time.simtime import SimTime +from common import z_cfg +import json +import zenoh +import time + +def main(): + z_config = zenoh.Config.from_json5(json.dumps(z_cfg)) + with zenoh.open(z_config) as z: + sim_time = SimTime(z, "clock", 1000) + sim_time.reset() + while True: + current_time = time.time() + print(f"Simulated Time: {current_time:.2f} seconds") + sim_time.tick() + +if __name__ == "__main__": + main() From dc7fbc1231b0ba208b4bbb12b7d3e718530690c0 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Wed, 11 Feb 2026 18:02:27 +0000 Subject: [PATCH 03/19] [untested] basic gradient broadcasting via query --- test/autodiff.py | 54 +++++++++++++++++++++++++++ test/diff_publisher.py | 83 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 test/autodiff.py diff --git a/test/autodiff.py b/test/autodiff.py new file mode 100644 index 0000000..b66b6da --- /dev/null +++ b/test/autodiff.py @@ -0,0 +1,54 @@ +import numpy as np + +class Value: + def __init__(self, val, parents=(), backward=None, name=None): + self.val = np.asarray(val, dtype=float) + self.grad = np.zeros_like(self.val) + self._prev = parents + self._backward = backward or (lambda: None) + self.name = name + def backward(self, grad=None): + if grad is None: + grad = np.ones_like(self.val) + self.grad = self.grad + grad + topo = [] + visited = set() + def build(v): + if v not in visited: + visited.add(v) + for p in v._prev: + build(p) + topo.append(v) + build(self) + for v in reversed(topo): + v._backward() + def __add__(self, other): + out = Value(self.val + other.val, parents=(self, other)) + def _backward(): + 
self.grad = self.grad + out.grad + other.grad = other.grad + out.grad + out._backward = _backward + return out + def __sub__(self, other): + out = Value(self.val - other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + out.grad + other.grad = other.grad - out.grad + out._backward = _backward + return out + def __mul__(self, other): + out = Value(self.val * other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + other.val * out.grad + other.grad = other.grad + self.val * out.grad + out._backward = _backward + return out + def __neg__(self): + out = Value(-self.val, parents=(self,)) + def _backward(): + self.grad = self.grad - out.grad + out._backward = _backward + return out +def clear_grads(params): + for p in params: + p.grad = np.zeros_like(p.val) diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 5bb1a04..7d2e1b9 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -1,15 +1,17 @@ import math import time from ark.node import BaseNode -from ark_msgs import Translation, dTranslation +from ark_msgs import Translation, Value from common import z_cfg +import torch + # Lissajous parameters A, B = 1.0, 1.0 a, b = 3.0, 2.0 delta = math.pi / 2 HZ = 50 DT = 1.0 / HZ -class DiffPublisherNode(BaseNode): +class LissajousPublisherNode(BaseNode): def __init__(self): super().__init__("env", "diff_pub", z_cfg, sim=True) self.pos_pub = self.create_publisher("position") @@ -26,9 +28,84 @@ def spin(self): self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) t += DT self.rate.sleep() + +class LinePublisherNode(BaseNode): + + def __init__(self): + super().__init__("env", "line_pub", z_cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.rate = self.create_rate(HZ) + self.v = torch.tensor(1.0, requires_grad=True) + self.m = torch.tensor(0.5, requires_grad=True) + self.c = torch.tensor(0.0, requires_grad=True) + self.latest = { + "x": 0.0, + "y": 0.0, + "v_x": 0.0, + "v_y": 0.0, + "m_x": 0.0, + "m_y": 0.0, + "c_x": 0.0, + "c_y": 0.0, + } + self.create_queryable("grad/v/x", self._on_grad_v_x) + self.create_queryable("grad/v/y", self._on_grad_v_y) + self.create_queryable("grad/m/x", self._on_grad_m_x) + self.create_queryable("grad/m/y", self._on_grad_m_y) + self.create_queryable("grad/c/x", self._on_grad_c_x) + self.create_queryable("grad/c/y", self._on_grad_c_y) + + def _on_grad_v_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["v_x"]) + + def _on_grad_v_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["v_y"]) + + def _on_grad_m_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["m_x"]) + + def _on_grad_m_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["m_y"]) + + def _on_grad_c_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["c_x"]) + + def _on_grad_c_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["c_y"]) + + def spin(self): + t = 0.0 + while True: + t_val = torch.tensor(t, requires_grad=False) + x = self.v * t_val + y = self.m * x + self.c + self.pos_pub.publish(Translation(x=float(x), y=float(y), z=0.0)) + if self.v.grad is not None: + self.v.grad.zero_() + if self.m.grad is not None: + self.m.grad.zero_() + if self.c.grad is not None: + self.c.grad.zero_() + x.backward(retain_graph=True) + self.latest["v_x"] = float(self.v.grad) + self.latest["m_x"] = float(self.m.grad) + self.latest["c_x"] = float(self.c.grad) + self.v.grad.zero_() + self.m.grad.zero_() + self.c.grad.zero_() + y.backward() 
+ self.latest["v_y"] = float(self.v.grad) + self.latest["m_y"] = float(self.m.grad) + self.latest["c_y"] = float(self.c.grad) + self.latest["x"] = float(x) + self.latest["y"] = float(y) + t += DT + self.rate.sleep() + if __name__ == "__main__": try: - node = DiffPublisherNode() + # node = LissajousPublisherNode() + node = LinePublisherNode() node.spin() except KeyboardInterrupt: print("Shutting down diff publisher.") From 0602f55d3abce18a176e464c69718d3dbb9df1a2 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Thu, 12 Feb 2026 10:49:09 +0000 Subject: [PATCH 04/19] basic testing of passing gradients over messages --- src/ark/node.py | 2 +- test/ad_plotter_sub.py | 61 ++++++++++++++++++++++++++++++++++++++++++ test/diff_publisher.py | 13 +++++---- 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 test/ad_plotter_sub.py diff --git a/src/ark/node.py b/src/ark/node.py index 6ba6c6e..a1f3103 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -80,7 +80,7 @@ def create_querier(self, channel, timeout=10.0) -> Querier: self._clock, channel, self._data_collector, - timeout, + # timeout, ) querier.core_registration() self._queriers[channel] = querier diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py new file mode 100644 index 0000000..e06bbdc --- /dev/null +++ b/test/ad_plotter_sub.py @@ -0,0 +1,61 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, Value +from common import z_cfg + +class AutodiffPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "autodiff_plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.grad_vx, self.grad_my = [], [] + self.create_subscriber("position", self.on_position) + self.grad_vx_querier = self.create_querier("grad/v/x") + self.grad_my_querier = self.create_querier("grad/m/y") + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def fetch_grads(self): + req = Translation(x=0.0, y=0.0, z=0.0) + try: + resp_vx = self.grad_vx_querier.query(req) + if isinstance(resp_vx, Value): + self.grad_vx.append(resp_vx.grad) + except Exception: + pass + try: + resp_my = self.grad_my_querier.query(req) + if isinstance(resp_my, Value): + self.grad_my.append(resp_my.grad) + except Exception: + pass +def main(): + node = AutodiffPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-5, 5) + ax_pos.set_ylim(-5, 5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_grad.set_title("Gradients") + ax_grad.set_xlabel("t") + ax_grad.set_ylabel("grad") + (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") + (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") + ax_grad.legend() + def update(frame): + node.fetch_grads() + line_pos.set_data(node.pos_x, node.pos_y) + line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) + line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) + return line_pos, line_grad_vx, line_grad_my + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 7d2e1b9..ff102d5 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -79,7 +79,8 @@ def 
spin(self):
         t_val = torch.tensor(t, requires_grad=False)
         x = self.v * t_val
         y = self.m * x + self.c
-        self.pos_pub.publish(Translation(x=float(x), y=float(y), z=0.0))
+        self.pos_pub.publish(Translation(x=float(x.detach()),
+                                         y=float(y.detach()), z=0.0))
         if self.v.grad is not None:
             self.v.grad.zero_()
         if self.m.grad is not None:
@@ -88,17 +89,15 @@ def spin(self):
             self.c.grad.zero_()
         x.backward(retain_graph=True)
         self.latest["v_x"] = float(self.v.grad)
-        self.latest["m_x"] = float(self.m.grad)
-        self.latest["c_x"] = float(self.c.grad)
+        # self.latest["m_x"] = float(self.m.grad)
+        # self.latest["c_x"] = float(self.c.grad)
         self.v.grad.zero_()
-        self.m.grad.zero_()
-        self.c.grad.zero_()
         y.backward()
         self.latest["v_y"] = float(self.v.grad)
         self.latest["m_y"] = float(self.m.grad)
         self.latest["c_y"] = float(self.c.grad)
-        self.latest["x"] = float(x)
-        self.latest["y"] = float(y)
+        self.latest["x"] = float(x.detach())
+        self.latest["y"] = float(y.detach())
         t += DT
         self.rate.sleep()

From 06159f0bf52876bd21176cfe7eb2a30902d435ff Mon Sep 17 00:00:00 2001
From: kamiradi
Date: Thu, 12 Feb 2026 12:37:22 +0000
Subject: [PATCH 05/19] queryable not receiving queries

---
 src/ark/comm/queriable.py |  8 +++++++-
 src/ark/node.py           |  1 +
 test/ad_plotter_sub.py    |  1 +
 test/diff_publisher.py    | 19 +++++++++++++------
 4 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py
index 294346a..06da42f 100644
--- a/src/ark/comm/queriable.py
+++ b/src/ark/comm/queriable.py
@@ -22,18 +22,22 @@ def __init__(
         super().__init__(node_name, session, clock, channel, data_collector)
         self._handler = handler
         self._queryable = self._session.declare_queryable(self._channel, self._on_query)
+        print(f"Declared queryable on channel: {self._channel}")

     def core_registration(self):
         print("..todo: register with ark core..")

     def _on_query(self, query: zenoh.Query) -> None:
         # If we were closed, ignore queries
+        print("Received query, processing...")
         if not self._active:
+            print("Received query on closed Queryable, ignoring")
             return

         try:
             # Zenoh query may or may not include a payload.
# For your use-case, the request is always in query.value (bytes) + print("Parsing query") raw = bytes(query.value) if query.value is not None else b"" if not raw: return # nothing to do @@ -42,7 +46,8 @@ def _on_query(self, query: zenoh.Query) -> None: req_env.ParseFromString(raw) # Decode request protobuf - req_type = msgs.get(req_env.payload_msg_type) + # req_type = msgs.get(req_env.payload_msg_type) + req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) return @@ -73,4 +78,5 @@ def _on_query(self, query: zenoh.Query) -> None: except Exception: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired + print("Error processing query:") return diff --git a/src/ark/node.py b/src/ark/node.py index a1f3103..205f5a9 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -84,6 +84,7 @@ def create_querier(self, channel, timeout=10.0) -> Querier: ) querier.core_registration() self._queriers[channel] = querier + # print session and channelinfo for debugging return querier def create_queryable(self, channel, handler) -> Queryable: diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index e06bbdc..104948c 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -20,6 +20,7 @@ def fetch_grads(self): req = Translation(x=0.0, y=0.0, z=0.0) try: resp_vx = self.grad_vx_querier.query(req) + print(f"Queried grad_vx: {resp_vx.grad}") if isinstance(resp_vx, Value): self.grad_vx.append(resp_vx.grad) except Exception: diff --git a/test/diff_publisher.py b/test/diff_publisher.py index ff102d5..788d56f 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -48,14 +48,17 @@ def __init__(self): "c_x": 0.0, "c_y": 0.0, } - self.create_queryable("grad/v/x", self._on_grad_v_x) - self.create_queryable("grad/v/y", self._on_grad_v_y) - self.create_queryable("grad/m/x", self._on_grad_m_x) - self.create_queryable("grad/m/y", self._on_grad_m_y) - self.create_queryable("grad/c/x", self._on_grad_c_x) - self.create_queryable("grad/c/y", self._on_grad_c_y) + # declare and store all the queryables for gradients + self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + def _on_grad_v_x(self, _req): + print(f"Received query for grad_v_x, latest") return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_y(self, _req): @@ -65,6 +68,7 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): + print(f"Received query for grad_v_x, latest") return Value(val=self.latest["y"], grad=self.latest["m_y"]) def _on_grad_c_x(self, _req): @@ -89,12 +93,15 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) + # print(f"v grad {self.v.grad.item()}") # self.latest["m_x"] = float(self.m.grad) # self.latest["c_x"] = float(self.c.grad) self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) + # print(f"v grad {self.v.grad.item()}") self.latest["m_y"] = float(self.m.grad) + # print(f"m grad {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) 
self.latest["y"] = float(y.detach()) From 5d05ab3206b1fc4189800da920136ac327f7381c Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 14:16:55 +0000 Subject: [PATCH 06/19] simple gradient experiment --- src/ark/comm/queriable.py | 2 +- src/ark/comm/querier.py | 10 +++- src/ark/node.py | 6 ++- test/ad_plotter_sub.py | 96 ++++++++++++++++++++++++++++++++++++--- test/common.py | 18 +++++++- test/common_example.py | 83 +++++++++++++++++++++++++++++++++ test/diff_publisher.py | 71 ++++++++++++++++++++--------- 7 files changed, 252 insertions(+), 34 deletions(-) create mode 100644 test/common_example.py diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 06da42f..961c98b 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -69,7 +69,7 @@ def _on_query(self, query: zenoh.Query) -> None: self._seq_index += 1 resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(resp_env.SerializeToString()) + query.reply(self._channel, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c6d4586..c10d216 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -11,12 +11,16 @@ def __init__( self, node_name: str, session: zenoh.Session, + query_target, clock, channel: str, data_collector: DataCollector | None, ): super().__init__(node_name, session, clock, channel, data_collector) - self._querier = self._session.declare_querier(self._channel) + self._querier = self._session.declare_querier(self._channel, + target=query_target) + print(f"Declared querier on channel: {self._channel}") + self._query_selector = zenoh.Selector(self._channel) def core_registration(self): print("..todo: register with ark core..") @@ -48,7 +52,9 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout) + print(f"Sending query on channel '{self._channel}' with timeout {timeout}s") + replies = self._querier.get(parameters=self._query_selector.parameters, payload=req_env.SerializeToString(), timeout=timeout) + print(f"Received {len(replies)} replies for query on channel '{self._channel}'") for reply in replies: if reply.ok is None: diff --git a/src/ark/node.py b/src/ark/node.py index 205f5a9..24a8df0 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -22,7 +22,8 @@ def __init__( sim: bool = False, collect_data: bool = False, ): - self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + # self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + self._z_cfg = z_cfg self._session = zenoh.open(self._z_cfg) self._env_name = env_name self._node_name = node_name @@ -73,10 +74,11 @@ def create_subscriber(self, channel, callback) -> Subscriber: self._subs[channel] = sub return sub - def create_querier(self, channel, timeout=10.0) -> Querier: + def create_querier(self, channel, target, timeout=10.0) -> Querier: querier = Querier( self._node_name, self._session, + target, self._clock, channel, self._data_collector, diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 104948c..5a67adf 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -3,21 +3,38 @@ import matplotlib.animation as animation from ark.node import BaseNode from ark_msgs import Translation, Value -from common import z_cfg +# from common import connect_cfg, z_cfg +import argparse +import zenoh +import common_example as common + class 
AutodiffPlotterNode(BaseNode): - def __init__(self): - super().__init__("env", "autodiff_plotter", z_cfg, sim=True) + def __init__(self, cfg, target): + super().__init__("env", "autodiff_plotter", cfg, sim=True) self.pos_x, self.pos_y = [], [] self.grad_vx, self.grad_my = [], [] self.create_subscriber("position", self.on_position) - self.grad_vx_querier = self.create_querier("grad/v/x") - self.grad_my_querier = self.create_querier("grad/m/y") + # self.grad_vx_querier = self.create_querier("grad/v/x", target=target) + # self.grad_my_querier = self.create_querier("grad/m/y", target=target) + self.grad_vx_querier = self._session.declare_querier( + "grad/v/x", + target=target, + timeout=10.0, + ) + self.grad_my_querier = self._session.declare_querier( + "grad/m/y", + target=target, + timeout=10.0 + ) + def on_position(self, msg: Translation): self.pos_x.append(msg.x) self.pos_y.append(msg.y) + def fetch_grads(self): req = Translation(x=0.0, y=0.0, z=0.0) + print("fetching grads") try: resp_vx = self.grad_vx_querier.query(req) print(f"Queried grad_vx: {resp_vx.grad}") @@ -31,8 +48,71 @@ def fetch_grads(self): self.grad_my.append(resp_my.grad) except Exception: pass + + def fetch_grads_exp(self): + try: + resp_vx = self.grad_vx_querier.get() + for resp in resp_vx: + if resp.ok is None: + continue + v_value_str = bytes(resp.ok.payload).decode("utf-8") + v_value = float(v_value_str) + print(f"Queried grad_vx: {v_value}") + self.grad_vx.append(v_value) + except Exception: + pass + try: + resp_my = self.grad_my_querier.get() + for resp in resp_my: + if resp.ok is None: + continue + m_value_str = bytes(resp.ok.payload).decode("utf-8") + m_value = float(m_value_str) + print(f"Queried grad_my: {m_value}") + self.grad_my.append(m_value) + except Exception: + print("Failed to query grad_my") + pass def main(): - node = AutodiffPlotterNode() + parser = argparse.ArgumentParser(description="Autodiff Plotter Node") + common.add_config_arguments(parser) + parser.add_argument( + "--target", + "-t", + dest="target", + choices=["ALL", "BEST_MATCHING", "ALL_COMPLETE", "NONE"], + default="BEST_MATCHING", + type=str, + help="The target queryables of the query.", + ) + parser.add_argument( + "--timeout", + "-o", + dest="timeout", + default=10.0, + type=float, + help="The query timeout", + ) + parser.add_argument( + "--iter", dest="iter", type=int, help="How many gets to perform" + ) + parser.add_argument( + "--add-matching-listener", + default=False, + action="store_true", + help="Add matching listener", + ) + + args = parser.parse_args() + conf = common.get_config_from_args(args) + + target = { + "ALL": zenoh.QueryTarget.ALL, + "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, + "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, + }.get(args.target) + + node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) ax_pos.set_title("Position (Translation)") @@ -45,11 +125,13 @@ def main(): ax_grad.set_title("Gradients") ax_grad.set_xlabel("t") ax_grad.set_ylabel("grad") + ax_grad.set_xlim(-5, 5) + ax_grad.set_ylim(-5, 5) (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") ax_grad.legend() def update(frame): - node.fetch_grads() + node.fetch_grads_exp() line_pos.set_data(node.pos_x, node.pos_y) line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) diff --git a/test/common.py 
b/test/common.py index b26b213..0a5d7bc 100644 --- a/test/common.py +++ b/test/common.py @@ -1 +1,17 @@ -z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}} +listen_cfg = { + "mode": "peer", + "listen": { + "endpoints": ["tcp/0.0.0.0:7447"]}, +} +connect_cfg = { + "mode": "peer", + "connect": { + "endpoints": ["tcp/127.0.0.1:7447"] + } +} +z_cfg = { + "mode": "peer", + # "connect": { + # "endpoints":["udp/127.0.0.1:7447"] + # } +} diff --git a/test/common_example.py b/test/common_example.py new file mode 100644 index 0000000..0c1eea3 --- /dev/null +++ b/test/common_example.py @@ -0,0 +1,83 @@ +import argparse +import json + +import zenoh + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--mode", + "-m", + dest="mode", + choices=["peer", "client"], + type=str, + help="The zenoh session mode.", + ) + parser.add_argument( + "--connect", + "-e", + dest="connect", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to connect to.", + ) + parser.add_argument( + "--listen", + "-l", + dest="listen", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to listen on.", + ) + parser.add_argument( + "--config", + "-c", + dest="config", + metavar="FILE", + type=str, + help="A configuration file.", + ) + parser.add_argument( + "--no-multicast-scouting", + dest="no_multicast_scouting", + default=False, + action="store_true", + help="Disable multicast scouting.", + ) + parser.add_argument( + "--cfg", + dest="cfg", + metavar="CFG", + default=[], + action="append", + type=str, + help="Allows arbitrary configuration changes as column-separated KEY:VALUE pairs. Where KEY must be a valid config path and VALUE must be a valid JSON5 string that can be deserialized to the expected type for the KEY field. 
Example: --cfg='transport/unicast/max_links:2'.", + ) + + +def get_config_from_args(args) -> zenoh.Config: + conf = ( + zenoh.Config.from_file(args.config) + if args.config is not None + else zenoh.Config() + ) + if args.mode is not None: + conf.insert_json5("mode", json.dumps(args.mode)) + if args.connect is not None: + conf.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen is not None: + conf.insert_json5("listen/endpoints", json.dumps(args.listen)) + if args.no_multicast_scouting: + conf.insert_json5("scouting/multicast/enabled", json.dumps(False)) + + for c in args.cfg: + try: + [key, value] = c.split(":", 1) + except: + print(f"`--cfg` argument: expected KEY:VALUE pair, got {c}") + raise + conf.insert_json5(key, value) + + return conf diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 788d56f..d1dcf1a 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -2,7 +2,9 @@ import time from ark.node import BaseNode from ark_msgs import Translation, Value -from common import z_cfg +import argparse +# from common import listen_cfg, z_cfg +import common_example as common import torch # Lissajous parameters @@ -13,7 +15,7 @@ DT = 1.0 / HZ class LissajousPublisherNode(BaseNode): def __init__(self): - super().__init__("env", "diff_pub", z_cfg, sim=True) + super().__init__("env", "diff_pub", listen_cfg, sim=True) self.pos_pub = self.create_publisher("position") self.vel_pub = self.create_publisher("velocity") self.rate = self.create_rate(HZ) @@ -31,8 +33,8 @@ def spin(self): class LinePublisherNode(BaseNode): - def __init__(self): - super().__init__("env", "line_pub", z_cfg, sim=True) + def __init__(self, cfg): + super().__init__("env", "line_pub", cfg, sim=True) self.pos_pub = self.create_publisher("position") self.rate = self.create_rate(HZ) self.v = torch.tensor(1.0, requires_grad=True) @@ -49,17 +51,29 @@ def __init__(self): "c_y": 0.0, } # declare and store all the queryables for gradients - self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) - self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) - self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) - self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) - self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) - self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + # self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + # self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + # self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + # self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + # self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + # self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + self.grad_v_queryable = self._session.declare_queryable("grad/v/x", + self._on_grad_v_x, + complete=False) + self.grad_m_queryable = self._session.declare_queryable("grad/m/y", + self._on_grad_m_y, + complete=False) + # def _on_grad_v_x(self, _req): + # print(f"Received query for grad_v_x, latest") + # return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_x(self, _req): - print(f"Received query for grad_v_x, latest") - return Value(val=self.latest["x"], grad=self.latest["v_x"]) + v_value = self.latest["v_x"] + v_value_str = str(v_value) + payload = v_value_str.encode("utf-8") + _req.reply("grad/v/x", payload) + pass def _on_grad_v_y(self, _req): return 
Value(val=self.latest["y"], grad=self.latest["v_y"]) @@ -68,8 +82,11 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): - print(f"Received query for grad_v_x, latest") - return Value(val=self.latest["y"], grad=self.latest["m_y"]) + m_value = self.latest["m_y"] + m_value_str = str(m_value) + payload = m_value_str.encode("utf-8") + _req.reply("grad/m/y", payload) + pass def _on_grad_c_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["c_x"]) @@ -93,25 +110,37 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) - # print(f"v grad {self.v.grad.item()}") - # self.latest["m_x"] = float(self.m.grad) - # self.latest["c_x"] = float(self.c.grad) + print(f"v grad {self.v.grad.item()}") self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) - # print(f"v grad {self.v.grad.item()}") self.latest["m_y"] = float(self.m.grad) - # print(f"m grad {self.m.grad}") + print(f"m grad {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) self.latest["y"] = float(y.detach()) + # with self.grad_v_queryable.recv() as query: + # print(f"Received query for grad_v_x, latest") t += DT self.rate.sleep() if __name__ == "__main__": try: - # node = LissajousPublisherNode() - node = LinePublisherNode() + parser = argparse.ArgumentParser( + prog="z_queryable", description="zenoh queryable example" + ) + common.add_config_arguments(parser) + parser.add_argument( + "--complete", + dest="complete", + default=False, + action="store_true", + help="Declare the queryable as complete w.r.t. the key expression.", + ) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = LinePublisherNode(conf) node.spin() except KeyboardInterrupt: print("Shutting down diff publisher.") From 0537c41fea7bd4dbd7440789c492a74a75edfc71 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 18:12:54 +0000 Subject: [PATCH 07/19] gradient querying working with ark querier and queriablew --- src/ark/comm/queriable.py | 20 +++++++++---- src/ark/comm/querier.py | 17 +++++------ test/ad_plotter_sub.py | 62 +++++++++++++-------------------------- test/diff_publisher.py | 50 ++++++++++++------------------- 4 files changed, 61 insertions(+), 88 deletions(-) diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 961c98b..2e55a75 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -21,7 +21,9 @@ def __init__( ): super().__init__(node_name, session, clock, channel, data_collector) self._handler = handler - self._queryable = self._session.declare_queryable(self._channel, self._on_query) + self._queryable = self._session.declare_queryable(self._channel, + self._on_query, + complete=False) print(f"Declared queryable on channel: {self._channel}") def core_registration(self): @@ -29,7 +31,6 @@ def core_registration(self): def _on_query(self, query: zenoh.Query) -> None: # If we were closed, ignore queries - print("Received query, processing...") if not self._active: print("Received query on closed Queryable, ignoring") return @@ -37,9 +38,9 @@ def _on_query(self, query: zenoh.Query) -> None: try: # Zenoh query may or may not include a payload. 
# For your use-case, the request is always in query.value (bytes) - print("Parsing query") - raw = bytes(query.value) if query.value is not None else b"" + raw = bytes(query.payload) if query.payload is not None else b"" if not raw: + print("Received query with no payload, ignoring") return # nothing to do req_env = Envelope() @@ -50,6 +51,7 @@ def _on_query(self, query: zenoh.Query) -> None: req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) + print(f"Unknown message type '{req_env.msg_type}' in query, ignoring") return req_msg = req_type() @@ -65,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None: resp_env.sent_seq_index = self._seq_index resp_env.src_node_name = self._node_name resp_env.channel = self._channel + resp_env.msg_type = resp_msg.DESCRIPTOR.full_name + resp_env.payload = resp_msg.SerializeToString() self._seq_index += 1 - resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(self._channel, resp_env.SerializeToString()) + with query: + query.reply(query.key_expr, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) @@ -79,4 +83,8 @@ def _on_query(self, query: zenoh.Query) -> None: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired print("Error processing query:") + # write the traceback to stdout for debugging + import traceback + traceback.print_exc() + return diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c10d216..88a6e8a 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -3,6 +3,7 @@ from google.protobuf.message import Message from ark.data.data_collector import DataCollector from ark.comm.end_point import EndPoint +from ark_msgs.registry import msgs class Querier(EndPoint): @@ -52,20 +53,21 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - print(f"Sending query on channel '{self._channel}' with timeout {timeout}s") - replies = self._querier.get(parameters=self._query_selector.parameters, payload=req_env.SerializeToString(), timeout=timeout) - print(f"Received {len(replies)} replies for query on channel '{self._channel}'") + replies = self._querier.get(payload=req_env.SerializeToString()) for reply in replies: if reply.ok is None: continue resp_env = Envelope() - resp_env.ParseFromString(bytes(reply.ok)) + resp_env.ParseFromString(bytes(reply.ok.payload)) resp_env.dst_node_name = self._node_name resp_env.recv_timestamp = self._clock.now() - resp = resp_env.extract_message() + try: + resp = resp_env.extract_message() + except Exception as e: + continue self._seq_index += 1 @@ -75,11 +77,6 @@ def query( return resp - else: - raise TimeoutError( - f"No OK reply received for query on '{self._channel}' within {timeout}s" - ) - def close(self): super().close() self._querier.undeclare() diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 5a67adf..314f23a 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -3,6 +3,7 @@ import matplotlib.animation as animation from ark.node import BaseNode from ark_msgs import Translation, Value + # from common import connect_cfg, z_cfg import argparse import zenoh @@ -15,65 +16,36 @@ def __init__(self, cfg, target): self.pos_x, self.pos_y = [], [] self.grad_vx, self.grad_my = [], [] self.create_subscriber("position", self.on_position) - # self.grad_vx_querier = self.create_querier("grad/v/x", target=target) - # self.grad_my_querier = 
self.create_querier("grad/m/y", target=target) - self.grad_vx_querier = self._session.declare_querier( - "grad/v/x", - target=target, - timeout=10.0, - ) - self.grad_my_querier = self._session.declare_querier( - "grad/m/y", - target=target, - timeout=10.0 - ) + self.grad_vx_querier = self.create_querier("grad/v/x", target=target) + self.grad_my_querier = self.create_querier("grad/m/y", target=target) def on_position(self, msg: Translation): self.pos_x.append(msg.x) self.pos_y.append(msg.y) def fetch_grads(self): - req = Translation(x=0.0, y=0.0, z=0.0) - print("fetching grads") + req = Value() try: resp_vx = self.grad_vx_querier.query(req) - print(f"Queried grad_vx: {resp_vx.grad}") if isinstance(resp_vx, Value): + print(f"Received grad_vx: {resp_vx.grad}") self.grad_vx.append(resp_vx.grad) except Exception: pass try: resp_my = self.grad_my_querier.query(req) if isinstance(resp_my, Value): + print(f"Received grad_my: {resp_my.grad}") self.grad_my.append(resp_my.grad) except Exception: pass - def fetch_grads_exp(self): - try: - resp_vx = self.grad_vx_querier.get() - for resp in resp_vx: - if resp.ok is None: - continue - v_value_str = bytes(resp.ok.payload).decode("utf-8") - v_value = float(v_value_str) - print(f"Queried grad_vx: {v_value}") - self.grad_vx.append(v_value) - except Exception: - pass - try: - resp_my = self.grad_my_querier.get() - for resp in resp_my: - if resp.ok is None: - continue - m_value_str = bytes(resp.ok.payload).decode("utf-8") - m_value = float(m_value_str) - print(f"Queried grad_my: {m_value}") - self.grad_my.append(m_value) - except Exception: - print("Failed to query grad_my") - pass + def main(): + + # These are a few zenoh config related arguments that were taken from the + # examples, keeping them there until we have a better way to manage configs + # across examples parser = argparse.ArgumentParser(description="Autodiff Plotter Node") common.add_config_arguments(parser) parser.add_argument( @@ -102,18 +74,22 @@ def main(): action="store_true", help="Add matching listener", ) - + args = parser.parse_args() conf = common.get_config_from_args(args) + # These were required for the querier and queryable to find each other. 
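+    # (zenoh QueryTarget semantics, for reference: BEST_MATCHING prefers the
+    # nearest complete queryable, ALL fans out to every matching queryable,
+    # and ALL_COMPLETE to every complete one.)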
target = { "ALL": zenoh.QueryTarget.ALL, "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, }.get(args.target) + # Main subcription and querying loop node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() + + # Plotting trajectory and gradients fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) ax_pos.set_title("Position (Translation)") ax_pos.set_xlabel("x") @@ -130,15 +106,19 @@ def main(): (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") ax_grad.legend() + def update(frame): - node.fetch_grads_exp() + node.fetch_grads() line_pos.set_data(node.pos_x, node.pos_y) line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) return line_pos, line_grad_vx, line_grad_my + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) plt.tight_layout() plt.show() node.close() + + if __name__ == "__main__": main() diff --git a/test/diff_publisher.py b/test/diff_publisher.py index d1dcf1a..153ff92 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -3,6 +3,7 @@ from ark.node import BaseNode from ark_msgs import Translation, Value import argparse + # from common import listen_cfg, z_cfg import common_example as common import torch @@ -13,12 +14,15 @@ delta = math.pi / 2 HZ = 50 DT = 1.0 / HZ + + class LissajousPublisherNode(BaseNode): def __init__(self): super().__init__("env", "diff_pub", listen_cfg, sim=True) self.pos_pub = self.create_publisher("position") self.vel_pub = self.create_publisher("velocity") self.rate = self.create_rate(HZ) + def spin(self): t = 0.0 while True: @@ -31,6 +35,7 @@ def spin(self): t += DT self.rate.sleep() + class LinePublisherNode(BaseNode): def __init__(self, cfg): @@ -51,29 +56,15 @@ def __init__(self, cfg): "c_y": 0.0, } # declare and store all the queryables for gradients - # self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) - # self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) - # self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) - # self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) - # self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) - # self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) - self.grad_v_queryable = self._session.declare_queryable("grad/v/x", - self._on_grad_v_x, - complete=False) - self.grad_m_queryable = self._session.declare_queryable("grad/m/y", - self._on_grad_m_y, - complete=False) - - # def _on_grad_v_x(self, _req): - # print(f"Received query for grad_v_x, latest") - # return Value(val=self.latest["x"], grad=self.latest["v_x"]) + self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) def _on_grad_v_x(self, _req): - v_value = self.latest["v_x"] - v_value_str = str(v_value) - payload = v_value_str.encode("utf-8") - _req.reply("grad/v/x", payload) - pass + return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_y(self, _req): return Value(val=self.latest["y"], 
grad=self.latest["v_y"]) @@ -82,10 +73,7 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): - m_value = self.latest["m_y"] - m_value_str = str(m_value) - payload = m_value_str.encode("utf-8") - _req.reply("grad/m/y", payload) + return Value(val=self.latest["y"], grad=self.latest["m_y"]) pass def _on_grad_c_x(self, _req): @@ -100,8 +88,9 @@ def spin(self): t_val = torch.tensor(t, requires_grad=False) x = self.v * t_val y = self.m * x + self.c - self.pos_pub.publish(Translation(x=float(x.detach()), - y=float(y.detach()), z=0.0)) + self.pos_pub.publish( + Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) + ) if self.v.grad is not None: self.v.grad.zero_() if self.m.grad is not None: @@ -110,20 +99,19 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) - print(f"v grad {self.v.grad.item()}") + print(f"Grad v_x {self.v.grad.item()}") self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) self.latest["m_y"] = float(self.m.grad) - print(f"m grad {self.m.grad}") + print(f"Grad m_y {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) self.latest["y"] = float(y.detach()) - # with self.grad_v_queryable.recv() as query: - # print(f"Received query for grad_v_x, latest") t += DT self.rate.sleep() + if __name__ == "__main__": try: parser = argparse.ArgumentParser( From d76495e5732b19cc9ff5fbd4a398913edc5eeb6e Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 19:11:52 +0000 Subject: [PATCH 08/19] adds readme to run gradient experiment --- test/diff_publisher.py | 7 +++++++ test/gradient_exp.md | 47 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 test/gradient_exp.md diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 153ff92..aa74ee8 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -86,11 +86,18 @@ def spin(self): t = 0.0 while True: t_val = torch.tensor(t, requires_grad=False) + + # Computation graph + # line equation: y = m * x + c, where x = v * t x = self.v * t_val y = self.m * x + self.c + + # publish position self.pos_pub.publish( Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) ) + + # compute gradients if self.v.grad is not None: self.v.grad.zero_() if self.m.grad is not None: diff --git a/test/gradient_exp.md b/test/gradient_exp.md new file mode 100644 index 0000000..54387b0 --- /dev/null +++ b/test/gradient_exp.md @@ -0,0 +1,47 @@ +# Gradient Experiment + +Demonstrates differentiable simulation using ark framework. A `LinePublisherNode` publishes position on a line (`y = m*x + c`, `x = v*t`) along with autograd gradients (dx/dv, dy/dm, dy/dc), and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. + +## Prerequisites + +- Install ark framework and dependencies (`zenoh`, `torch`, `matplotlib`, `ark_msgs`) +- Run all commands from the `test/` directory + +## Running the Experiment + +Open three separate terminals. All commands are run from the `test/` directory. + +### Shell 1 — Sim Clock + +Drives simulated time for all sim-enabled nodes. + +```bash +cd test +python simstep.py +``` + +### Shell 2 — Diff Publisher + +Publishes position (`Translation`) and serves gradient queryables (`grad/v/x`, `grad/m/y`, etc.). 
+ +```bash +cd test +python diff_publisher.py +``` + +### Shell 3 — Autodiff Plotter + +Subscribes to position and queries gradients, then plots both in real time. + +```bash +cd test +python ad_plotter_sub.py +``` + +## What to Expect + +- **Shell 1** prints the simulated time advancing each tick. +- **Shell 2** prints computed gradients (`Grad v_x`, `Grad m_y`) each step. +- **Shell 3** opens a matplotlib window with two plots: + - **Left**: Position trajectory (x vs y). + - **Right**: Gradients over time (dx/dv in green, dy/dm in magenta). From 4317c117e2a7df43ea1145eec6b7f446621c6821 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 11:30:11 +0000 Subject: [PATCH 09/19] adds automatic gradient query handling, separates parameter nodes and computation graph nodes --- src/ark/node.py | 43 +++++++++++++++++++++ test/ad_plotter_sub.py | 20 ++++++---- test/diff_variable_pub.py | 81 +++++++++++++++++++++++++++++++++++++++ test/gradient_exp.md | 56 +++++++++++++++++++++------ test/param_publisher.py | 39 +++++++++++++++++++ test/simstep.py | 6 ++- 6 files changed, 223 insertions(+), 22 deletions(-) create mode 100644 test/diff_variable_pub.py create mode 100644 test/param_publisher.py diff --git a/src/ark/node.py b/src/ark/node.py index 24a8df0..508f4cf 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -1,5 +1,6 @@ import json import time +import torch import zenoh from ark.time.clock import Clock from ark.time.rate import Rate @@ -10,6 +11,7 @@ from ark.comm.queriable import Queryable from ark.data.data_collector import DataCollector from ark.core.registerable import Registerable +from ark_msgs import Value class BaseNode(Registerable): @@ -37,6 +39,7 @@ def __init__( self._subs = {} self._queriers = {} self._queriables = {} + self._variables = {} self._session.declare_subscriber(f"{env_name}/reset", self._on_reset) @@ -102,6 +105,46 @@ def create_queryable(self, channel, handler) -> Queryable: self._queriables[channel] = queryable return queryable + def create_variable(self, name, value, mode="input", fields=None): + tensor = torch.tensor(value, requires_grad=True) + var_entry = { + "tensor": tensor, + "mode": mode, + "fields": fields or [], + "gradients": {f: 0.0 for f in (fields or [])}, + "values": {f: 0.0 for f in (fields or [])}, + } + self._variables[name] = var_entry + + if mode == "input": + if fields: + for field in fields: + grad_channel = f"grad/{name}/{field}" + + def _make_handler(var_name, fld): + def handler(_req): + v = self._variables[var_name] + return Value( + val=v["values"].get(fld, 0.0), + grad=v["gradients"].get(fld, 0.0), + ) + return handler + + self.create_queryable(grad_channel, _make_handler(name, field)) + + def _make_sub_callback(var_name): + def callback(msg): + v = self._variables[var_name] + v["tensor"].data = torch.tensor(msg.val) + return callback + + self.create_subscriber(f"param/{name}", _make_sub_callback(name)) + + return tensor + + def update_variable(self, name, grad_dict): + self._variables[name]["gradients"].update(grad_dict) + def create_rate(self, hz: float): rate = Rate(self._clock, hz) self._rates.append(rate) diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 314f23a..1815e87 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -15,6 +15,7 @@ def __init__(self, cfg, target): super().__init__("env", "autodiff_plotter", cfg, sim=True) self.pos_x, self.pos_y = [], [] self.grad_vx, self.grad_my = [], [] + self.grad_times = [] self.create_subscriber("position", self.on_position) self.grad_vx_querier 
= self.create_querier("grad/v/x", target=target) self.grad_my_querier = self.create_querier("grad/m/y", target=target) @@ -25,6 +26,7 @@ def on_position(self, msg: Translation): def fetch_grads(self): req = Value() + sim_t = self._clock.now() / 1e9 try: resp_vx = self.grad_vx_querier.query(req) if isinstance(resp_vx, Value): @@ -39,6 +41,7 @@ def fetch_grads(self): self.grad_my.append(resp_my.grad) except Exception: pass + self.grad_times.append(sim_t) def main(): @@ -94,15 +97,11 @@ def main(): ax_pos.set_title("Position (Translation)") ax_pos.set_xlabel("x") ax_pos.set_ylabel("y") - ax_pos.set_xlim(-5, 5) - ax_pos.set_ylim(-5, 5) ax_pos.set_aspect("equal") (line_pos,) = ax_pos.plot([], [], "b-") ax_grad.set_title("Gradients") - ax_grad.set_xlabel("t") + ax_grad.set_xlabel("sim time (s)") ax_grad.set_ylabel("grad") - ax_grad.set_xlim(-5, 5) - ax_grad.set_ylim(-5, 5) (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") ax_grad.legend() @@ -110,11 +109,16 @@ def main(): def update(frame): node.fetch_grads() line_pos.set_data(node.pos_x, node.pos_y) - line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) - line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) + ax_pos.relim() + ax_pos.autoscale_view() + times = node.grad_times + line_grad_vx.set_data(times[:len(node.grad_vx)], node.grad_vx) + line_grad_my.set_data(times[:len(node.grad_my)], node.grad_my) + ax_grad.relim() + ax_grad.autoscale_view() return line_pos, line_grad_vx, line_grad_my - ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + ani = animation.FuncAnimation(fig, update, interval=50, blit=False) plt.tight_layout() plt.show() node.close() diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py new file mode 100644 index 0000000..1f458de --- /dev/null +++ b/test/diff_variable_pub.py @@ -0,0 +1,81 @@ +import math +import time +from ark.node import BaseNode +from ark_msgs import Translation, Value +import argparse +import common_example as common +import torch + +HZ = 50 +DT = 1.0 / HZ + + +class LineVariableNode(BaseNode): + + def __init__(self, cfg): + super().__init__("env", "line_var_pub", cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.rate = self.create_rate(HZ) + + # Create differentiable input variables — auto-creates grad queryables + # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y + self.v = self.create_variable("v", 0.0, mode="input", fields=["x", "y"]) + self.m = self.create_variable("m", 0.0, mode="input", fields=["x", "y"]) + self.c = self.create_variable("c", 0.0, mode="input", fields=["x", "y"]) + + def spin(self): + t = 0.0 + while True: + t_val = torch.tensor(t, requires_grad=False) + + # Forward: y = m * x + c, where x = v * t + x = self.v * t_val + y = self.m * x + self.c + + # Publish position + self.pos_pub.publish( + Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) + ) + + # Backward: compute gradients + if self.v.grad is not None: + self.v.grad.zero_() + if self.m.grad is not None: + self.m.grad.zero_() + if self.c.grad is not None: + self.c.grad.zero_() + + x.backward(retain_graph=True) + v_x = float(self.v.grad) + self.v.grad.zero_() + + y.backward() + v_y = float(self.v.grad) + m_y = float(self.m.grad) + c_y = float(self.c.grad) + + # Update variable gradients — served automatically by queryables + self.update_variable("v", {"x": v_x, "y": v_y}) + self.update_variable("m", {"x": 0.0, "y": m_y}) + self.update_variable("c", {"x": 0.0, "y": c_y}) + 
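+        # Analytic check for y = m*x + c with x = v*t: dx/dv = t and
+        # dy/dm = x, while dx/dm = dx/dc = 0 (hence the 0.0 "x" entries
+        # passed for m and c above).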
+ print(f"t={t:.2f} dx/dv={v_x:.3f} dy/dm={m_y:.3f}") + + t += DT + self.rate.sleep() + + +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser( + prog="diff_variable_pub", description="Differentiable variable publisher" + ) + common.add_config_arguments(parser) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = LineVariableNode(conf) + node.spin() + except KeyboardInterrupt: + print("Shutting down diff variable publisher.") + node.close() diff --git a/test/gradient_exp.md b/test/gradient_exp.md index 54387b0..c7f35e1 100644 --- a/test/gradient_exp.md +++ b/test/gradient_exp.md @@ -1,6 +1,28 @@ # Gradient Experiment -Demonstrates differentiable simulation using ark framework. A `LinePublisherNode` publishes position on a line (`y = m*x + c`, `x = v*t`) along with autograd gradients (dx/dv, dy/dm, dy/dc), and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. +Demonstrates differentiable simulation using ark framework with distributed parameter publishing. A `ParamPublisherNode` publishes parameter values (`v`, `m`, `c`), a `LineVariableNode` subscribes to those parameters, computes position on a line (`y = m*x + c`, `x = v*t`) with autograd gradients, and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. + +## Architecture + +``` +ParamPublisherNode LineVariableNode AutodiffPlotterNode + publishes: subscribes to: subscribes to: + param/v (Value) ──► param/v, param/m, param/c position + param/m (Value) computes: queries: + param/c (Value) x = v*t, y = m*x + c grad/v/x, grad/m/y + publishes: plots: + position (Translation) trajectory + gradients + serves queryables: vs sim time + grad/{v,m,c}/{x,y} +``` + +## Key Concepts + +- **`create_variable(name, value, mode="input", fields=...)`** on `BaseNode`: + - Creates a `torch.tensor` with `requires_grad=True` + - Auto-subscribes on `param/{name}` to receive values from other nodes + - Auto-creates gradient queryables at `grad/{name}/{field}` for each field +- **`update_variable(name, grad_dict)`**: Caches gradients after `backward()`, served by queryables ## Prerequisites @@ -9,7 +31,7 @@ Demonstrates differentiable simulation using ark framework. A `LinePublisherNode ## Running the Experiment -Open three separate terminals. All commands are run from the `test/` directory. +Open four separate terminals. All commands are run from the `test/` directory. ### Shell 1 — Sim Clock @@ -20,18 +42,27 @@ cd test python simstep.py ``` -### Shell 2 — Diff Publisher +### Shell 2 — Parameter Publisher + +Publishes fixed parameter values: `v=1.0`, `m=0.5`, `c=0.0`. + +```bash +cd test +python param_publisher.py +``` + +### Shell 3 — Diff Variable Publisher -Publishes position (`Translation`) and serves gradient queryables (`grad/v/x`, `grad/m/y`, etc.). +Subscribes to parameters, computes position and gradients, publishes position, serves gradient queryables. ```bash cd test -python diff_publisher.py +python diff_variable_pub.py ``` -### Shell 3 — Autodiff Plotter +### Shell 4 — Autodiff Plotter -Subscribes to position and queries gradients, then plots both in real time. +Subscribes to position and queries gradients, then plots both against simulation time. ```bash cd test @@ -40,8 +71,9 @@ python ad_plotter_sub.py ## What to Expect -- **Shell 1** prints the simulated time advancing each tick. -- **Shell 2** prints computed gradients (`Grad v_x`, `Grad m_y`) each step. 
-- **Shell 3** opens a matplotlib window with two plots: - - **Left**: Position trajectory (x vs y). - - **Right**: Gradients over time (dx/dv in green, dy/dm in magenta). +- **Shell 1** prints real elapsed time and sim time advancing each tick. +- **Shell 2** publishes parameter values at 10Hz (no output by default). +- **Shell 3** prints computed gradients (`dx/dv`, `dy/dm`) each step. +- **Shell 4** opens a matplotlib window with two plots: + - **Left**: Position trajectory (x vs y), autoscaling. + - **Right**: Gradients vs simulation time (dx/dv in green, dy/dm in magenta), autoscaling. diff --git a/test/param_publisher.py b/test/param_publisher.py new file mode 100644 index 0000000..ff439ea --- /dev/null +++ b/test/param_publisher.py @@ -0,0 +1,39 @@ +from ark.node import BaseNode +from ark_msgs import Value +import argparse +import common_example as common + +HZ = 10 + + +class ParamPublisherNode(BaseNode): + + def __init__(self, cfg): + super().__init__("env", "param_pub", cfg, sim=True) + self.pub_v = self.create_publisher("param/v") + self.pub_m = self.create_publisher("param/m") + self.pub_c = self.create_publisher("param/c") + self.rate = self.create_rate(HZ) + + def spin(self): + while True: + self.pub_v.publish(Value(val=1.0)) + self.pub_m.publish(Value(val=0.5)) + self.pub_c.publish(Value(val=0.0)) + self.rate.sleep() + + +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser( + prog="param_publisher", description="Publishes parameter values" + ) + common.add_config_arguments(parser) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = ParamPublisherNode(conf) + node.spin() + except KeyboardInterrupt: + print("Shutting down param publisher.") + node.close() diff --git a/test/simstep.py b/test/simstep.py index 06da237..c0ed35d 100644 --- a/test/simstep.py +++ b/test/simstep.py @@ -9,10 +9,12 @@ def main(): with zenoh.open(z_config) as z: sim_time = SimTime(z, "clock", 1000) sim_time.reset() + start_time = time.time() while True: - current_time = time.time() - print(f"Simulated Time: {current_time:.2f} seconds") sim_time.tick() + elapsed = time.time() - start_time + sim_elapsed = sim_time._sim_time_ns / 1e9 + print(f"Real: {elapsed:.2f} s | Sim: {sim_elapsed:.3f} s") if __name__ == "__main__": main() From fa41a249196134244f936ea246d97df8e0484b9e Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 11:31:28 +0000 Subject: [PATCH 10/19] formatting --- src/ark/node.py | 2 ++ test/ad_plotter_sub.py | 4 ++-- test/simstep.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ark/node.py b/src/ark/node.py index 508f4cf..a025296 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -128,6 +128,7 @@ def handler(_req): val=v["values"].get(fld, 0.0), grad=v["gradients"].get(fld, 0.0), ) + return handler self.create_queryable(grad_channel, _make_handler(name, field)) @@ -136,6 +137,7 @@ def _make_sub_callback(var_name): def callback(msg): v = self._variables[var_name] v["tensor"].data = torch.tensor(msg.val) + return callback self.create_subscriber(f"param/{name}", _make_sub_callback(name)) diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 1815e87..a5b32f2 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -112,8 +112,8 @@ def update(frame): ax_pos.relim() ax_pos.autoscale_view() times = node.grad_times - line_grad_vx.set_data(times[:len(node.grad_vx)], node.grad_vx) - line_grad_my.set_data(times[:len(node.grad_my)], node.grad_my) + line_grad_vx.set_data(times[: 
len(node.grad_vx)], node.grad_vx) + line_grad_my.set_data(times[: len(node.grad_my)], node.grad_my) ax_grad.relim() ax_grad.autoscale_view() return line_pos, line_grad_vx, line_grad_my diff --git a/test/simstep.py b/test/simstep.py index c0ed35d..48f1be2 100644 --- a/test/simstep.py +++ b/test/simstep.py @@ -4,6 +4,7 @@ import zenoh import time + def main(): z_config = zenoh.Config.from_json5(json.dumps(z_cfg)) with zenoh.open(z_config) as z: @@ -16,5 +17,6 @@ def main(): sim_elapsed = sim_time._sim_time_ns / 1e9 print(f"Real: {elapsed:.2f} s | Sim: {sim_elapsed:.3f} s") + if __name__ == "__main__": main() From 8293fbe51aa18abf11e63304c43e0d86a03509db Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 11:42:38 +0000 Subject: [PATCH 11/19] removes unused files --- test/autodiff.py | 54 ---------------- test/diff_publisher.py | 142 ----------------------------------------- 2 files changed, 196 deletions(-) delete mode 100644 test/autodiff.py delete mode 100644 test/diff_publisher.py diff --git a/test/autodiff.py b/test/autodiff.py deleted file mode 100644 index b66b6da..0000000 --- a/test/autodiff.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np - -class Value: - def __init__(self, val, parents=(), backward=None, name=None): - self.val = np.asarray(val, dtype=float) - self.grad = np.zeros_like(self.val) - self._prev = parents - self._backward = backward or (lambda: None) - self.name = name - def backward(self, grad=None): - if grad is None: - grad = np.ones_like(self.val) - self.grad = self.grad + grad - topo = [] - visited = set() - def build(v): - if v not in visited: - visited.add(v) - for p in v._prev: - build(p) - topo.append(v) - build(self) - for v in reversed(topo): - v._backward() - def __add__(self, other): - out = Value(self.val + other.val, parents=(self, other)) - def _backward(): - self.grad = self.grad + out.grad - other.grad = other.grad + out.grad - out._backward = _backward - return out - def __sub__(self, other): - out = Value(self.val - other.val, parents=(self, other)) - def _backward(): - self.grad = self.grad + out.grad - other.grad = other.grad - out.grad - out._backward = _backward - return out - def __mul__(self, other): - out = Value(self.val * other.val, parents=(self, other)) - def _backward(): - self.grad = self.grad + other.val * out.grad - other.grad = other.grad + self.val * out.grad - out._backward = _backward - return out - def __neg__(self): - out = Value(-self.val, parents=(self,)) - def _backward(): - self.grad = self.grad - out.grad - out._backward = _backward - return out -def clear_grads(params): - for p in params: - p.grad = np.zeros_like(p.val) diff --git a/test/diff_publisher.py b/test/diff_publisher.py deleted file mode 100644 index aa74ee8..0000000 --- a/test/diff_publisher.py +++ /dev/null @@ -1,142 +0,0 @@ -import math -import time -from ark.node import BaseNode -from ark_msgs import Translation, Value -import argparse - -# from common import listen_cfg, z_cfg -import common_example as common -import torch - -# Lissajous parameters -A, B = 1.0, 1.0 -a, b = 3.0, 2.0 -delta = math.pi / 2 -HZ = 50 -DT = 1.0 / HZ - - -class LissajousPublisherNode(BaseNode): - def __init__(self): - super().__init__("env", "diff_pub", listen_cfg, sim=True) - self.pos_pub = self.create_publisher("position") - self.vel_pub = self.create_publisher("velocity") - self.rate = self.create_rate(HZ) - - def spin(self): - t = 0.0 - while True: - x = A * math.sin(a * t + delta) - y = B * math.sin(b * t) - dx = A * a * math.cos(a * t + delta) - dy = B * b * 
math.cos(b * t) - self.pos_pub.publish(Translation(x=x, y=y, z=0.0)) - self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) - t += DT - self.rate.sleep() - - -class LinePublisherNode(BaseNode): - - def __init__(self, cfg): - super().__init__("env", "line_pub", cfg, sim=True) - self.pos_pub = self.create_publisher("position") - self.rate = self.create_rate(HZ) - self.v = torch.tensor(1.0, requires_grad=True) - self.m = torch.tensor(0.5, requires_grad=True) - self.c = torch.tensor(0.0, requires_grad=True) - self.latest = { - "x": 0.0, - "y": 0.0, - "v_x": 0.0, - "v_y": 0.0, - "m_x": 0.0, - "m_y": 0.0, - "c_x": 0.0, - "c_y": 0.0, - } - # declare and store all the queryables for gradients - self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) - self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) - self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) - self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) - self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) - self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) - - def _on_grad_v_x(self, _req): - return Value(val=self.latest["x"], grad=self.latest["v_x"]) - - def _on_grad_v_y(self, _req): - return Value(val=self.latest["y"], grad=self.latest["v_y"]) - - def _on_grad_m_x(self, _req): - return Value(val=self.latest["x"], grad=self.latest["m_x"]) - - def _on_grad_m_y(self, _req): - return Value(val=self.latest["y"], grad=self.latest["m_y"]) - pass - - def _on_grad_c_x(self, _req): - return Value(val=self.latest["x"], grad=self.latest["c_x"]) - - def _on_grad_c_y(self, _req): - return Value(val=self.latest["y"], grad=self.latest["c_y"]) - - def spin(self): - t = 0.0 - while True: - t_val = torch.tensor(t, requires_grad=False) - - # Computation graph - # line equation: y = m * x + c, where x = v * t - x = self.v * t_val - y = self.m * x + self.c - - # publish position - self.pos_pub.publish( - Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) - ) - - # compute gradients - if self.v.grad is not None: - self.v.grad.zero_() - if self.m.grad is not None: - self.m.grad.zero_() - if self.c.grad is not None: - self.c.grad.zero_() - x.backward(retain_graph=True) - self.latest["v_x"] = float(self.v.grad) - print(f"Grad v_x {self.v.grad.item()}") - self.v.grad.zero_() - y.backward() - self.latest["v_y"] = float(self.v.grad) - self.latest["m_y"] = float(self.m.grad) - print(f"Grad m_y {self.m.grad}") - self.latest["c_y"] = float(self.c.grad) - self.latest["x"] = float(x.detach()) - self.latest["y"] = float(y.detach()) - t += DT - self.rate.sleep() - - -if __name__ == "__main__": - try: - parser = argparse.ArgumentParser( - prog="z_queryable", description="zenoh queryable example" - ) - common.add_config_arguments(parser) - parser.add_argument( - "--complete", - dest="complete", - default=False, - action="store_true", - help="Declare the queryable as complete w.r.t. 
the key expression.", - ) - args = parser.parse_args() - conf = common.get_config_from_args(args) - - node = LinePublisherNode(conf) - node.spin() - except KeyboardInterrupt: - print("Shutting down diff publisher.") - node.close() From 7d9c55043d433676645453fdf31460b54c4cd8cb Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 11:57:17 +0000 Subject: [PATCH 12/19] creates Variable class to maintain values and gradients, modifies demo class --- test/diff_variable_pub.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index 1f458de..012e2bc 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -19,9 +19,9 @@ def __init__(self, cfg): # Create differentiable input variables — auto-creates grad queryables # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y - self.v = self.create_variable("v", 0.0, mode="input", fields=["x", "y"]) - self.m = self.create_variable("m", 0.0, mode="input", fields=["x", "y"]) - self.c = self.create_variable("c", 0.0, mode="input", fields=["x", "y"]) + self.v = self.create_variable("v", 0.0, mode="input", out_fields=["x", "y"]) + self.m = self.create_variable("m", 0.0, mode="input", out_fields=["x", "y"]) + self.c = self.create_variable("c", 0.0, mode="input", out_fields=["x", "y"]) def spin(self): t = 0.0 @@ -29,8 +29,8 @@ def spin(self): t_val = torch.tensor(t, requires_grad=False) # Forward: y = m * x + c, where x = v * t - x = self.v * t_val - y = self.m * x + self.c + x = self.v.tensor * t_val + y = self.m.tensor * x + self.c.tensor # Publish position self.pos_pub.publish( @@ -38,26 +38,26 @@ def spin(self): ) # Backward: compute gradients - if self.v.grad is not None: - self.v.grad.zero_() - if self.m.grad is not None: - self.m.grad.zero_() - if self.c.grad is not None: - self.c.grad.zero_() + if self.v.tensor.grad is not None: + self.v.tensor.grad.zero_() + if self.m.tensor.grad is not None: + self.m.tensor.grad.zero_() + if self.c.tensor.grad is not None: + self.c.tensor.grad.zero_() x.backward(retain_graph=True) - v_x = float(self.v.grad) - self.v.grad.zero_() + v_x = float(self.v.tensor.grad) + self.v.tensor.grad.zero_() y.backward() - v_y = float(self.v.grad) - m_y = float(self.m.grad) - c_y = float(self.c.grad) + v_y = float(self.v.tensor.grad) + m_y = float(self.m.tensor.grad) + c_y = float(self.c.tensor.grad) # Update variable gradients — served automatically by queryables - self.update_variable("v", {"x": v_x, "y": v_y}) - self.update_variable("m", {"x": 0.0, "y": m_y}) - self.update_variable("c", {"x": 0.0, "y": c_y}) + self.v.update_gradients({"x": v_x, "y": v_y}) + self.m.update_gradients({"x": 0.0, "y": m_y}) + self.c.update_gradients({"x": 0.0, "y": c_y}) print(f"t={t:.2f} dx/dv={v_x:.3f} dy/dm={m_y:.3f}") From eb7165b787de6ea9f0685e63bca0b77530d7b024 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 11:58:19 +0000 Subject: [PATCH 13/19] Variable class, modified demo, formatting --- src/ark/node.py | 52 +++++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/ark/node.py b/src/ark/node.py index a025296..18b3e06 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -14,6 +14,20 @@ from ark_msgs import Value +class Variable: + + def __init__(self, name, value, mode="input", out_fields=None): + self.name = name + self.mode = mode + self.out_fields = out_fields or [] + self.tensor = torch.tensor(value, requires_grad=True) + 
self.gradients = {f: 0.0 for f in self.out_fields} + self.values = {f: 0.0 for f in self.out_fields} + + def update_gradients(self, grad_dict): + self.gradients.update(grad_dict) + + class BaseNode(Registerable): def __init__( @@ -105,47 +119,35 @@ def create_queryable(self, channel, handler) -> Queryable: self._queriables[channel] = queryable return queryable - def create_variable(self, name, value, mode="input", fields=None): - tensor = torch.tensor(value, requires_grad=True) - var_entry = { - "tensor": tensor, - "mode": mode, - "fields": fields or [], - "gradients": {f: 0.0 for f in (fields or [])}, - "values": {f: 0.0 for f in (fields or [])}, - } - self._variables[name] = var_entry + def create_variable(self, name, value, mode="input", out_fields=None): + var = Variable(name, value, mode, out_fields) + self._variables[name] = var if mode == "input": - if fields: - for field in fields: + if var.out_fields: + for field in var.out_fields: grad_channel = f"grad/{name}/{field}" - def _make_handler(var_name, fld): + def _make_handler(v, fld): def handler(_req): - v = self._variables[var_name] return Value( - val=v["values"].get(fld, 0.0), - grad=v["gradients"].get(fld, 0.0), + val=v.values.get(fld, 0.0), + grad=v.gradients.get(fld, 0.0), ) return handler - self.create_queryable(grad_channel, _make_handler(name, field)) + self.create_queryable(grad_channel, _make_handler(var, field)) - def _make_sub_callback(var_name): + def _make_sub_callback(v): def callback(msg): - v = self._variables[var_name] - v["tensor"].data = torch.tensor(msg.val) + v.tensor.data = torch.tensor(msg.val) return callback - self.create_subscriber(f"param/{name}", _make_sub_callback(name)) - - return tensor + self.create_subscriber(f"param/{name}", _make_sub_callback(var)) - def update_variable(self, name, grad_dict): - self._variables[name]["gradients"].update(grad_dict) + return var def create_rate(self, hz: float): rate = Rate(self._clock, hz) From 54aa0b091298e19fd82c9c176b3a248be39c04ed Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 17 Feb 2026 14:12:32 +0000 Subject: [PATCH 14/19] modifies Variable class to handle computation of gradients on query --- src/ark/node.py | 54 +++++++++++++++++++++++++++++++++------ test/diff_variable_pub.py | 27 +++----------------- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/src/ark/node.py b/src/ark/node.py index 18b3e06..79a6094 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -1,5 +1,6 @@ import json import time +import threading import torch import zenoh from ark.time.clock import Clock @@ -13,6 +14,8 @@ from ark.core.registerable import Registerable from ark_msgs import Value +_BACKWARD_LOCK = threading.Lock() + class Variable: @@ -21,11 +24,27 @@ def __init__(self, name, value, mode="input", out_fields=None): self.mode = mode self.out_fields = out_fields or [] self.tensor = torch.tensor(value, requires_grad=True) - self.gradients = {f: 0.0 for f in self.out_fields} - self.values = {f: 0.0 for f in self.out_fields} + self._outputs = {} + + def set_output(self, field, tensor): + with _BACKWARD_LOCK: + self._outputs[field] = tensor + + def set_outputs(self, mapping): + with _BACKWARD_LOCK: + self._outputs.update(mapping) - def update_gradients(self, grad_dict): - self.gradients.update(grad_dict) + def compute_grad(self, field): + with _BACKWARD_LOCK: + out_tensor = self._outputs.get(field) + if out_tensor is None: + return 0.0, 0.0 + val = float(out_tensor.detach()) + if self.tensor.grad is not None: + self.tensor.grad.zero_() + 
out_tensor.backward(retain_graph=True) + grad = float(self.tensor.grad) if self.tensor.grad is not None else 0.0 + return val, grad class BaseNode(Registerable): @@ -120,25 +139,44 @@ def create_queryable(self, channel, handler) -> Queryable: return queryable def create_variable(self, name, value, mode="input", out_fields=None): + """Create a differentiable variable with automatic gradient queryables. + + For "input" mode variables with out_fields, this sets up: + - A queryable on "grad/{name}/{field}" for each field in out_fields. + Gradients are computed lazily via backward() when queried, using + output tensors registered by the user via Variable.set_outputs(). + - A subscriber on "param/{name}" that updates the variable's tensor + value when a new parameter is published. + + Args: + name: Variable identifier, used in channel names. + value: Initial scalar value for the underlying tensor. + mode: "input" creates queryables and a param subscriber. + out_fields: Output field names (e.g. ["x", "y"]) that this + variable contributes to. Each gets a gradient queryable. + """ var = Variable(name, value, mode, out_fields) self._variables[name] = var if mode == "input": + # Create a gradient queryable for each output field. + # On query, compute_grad() runs backward() on the registered + # output tensor and returns the gradient w.r.t. this variable. if var.out_fields: for field in var.out_fields: grad_channel = f"grad/{name}/{field}" def _make_handler(v, fld): def handler(_req): - return Value( - val=v.values.get(fld, 0.0), - grad=v.gradients.get(fld, 0.0), - ) + val, grad = v.compute_grad(fld) + return Value(val=val, grad=grad) return handler self.create_queryable(grad_channel, _make_handler(var, field)) + # Subscribe to parameter updates so external nodes can set + # this variable's value at runtime. 
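+            # Note: the callback below assigns to tensor.data rather than
+            # rebinding the tensor, so the leaf tensor and its requires_grad
+            # flag are preserved and graphs built from this variable keep
+            # pointing at the same leaf.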
def _make_sub_callback(v): def callback(msg): v.tensor.data = torch.tensor(msg.val) diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index 012e2bc..c826552 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -37,29 +37,10 @@ def spin(self): Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) ) - # Backward: compute gradients - if self.v.tensor.grad is not None: - self.v.tensor.grad.zero_() - if self.m.tensor.grad is not None: - self.m.tensor.grad.zero_() - if self.c.tensor.grad is not None: - self.c.tensor.grad.zero_() - - x.backward(retain_graph=True) - v_x = float(self.v.tensor.grad) - self.v.tensor.grad.zero_() - - y.backward() - v_y = float(self.v.tensor.grad) - m_y = float(self.m.tensor.grad) - c_y = float(self.c.tensor.grad) - - # Update variable gradients — served automatically by queryables - self.v.update_gradients({"x": v_x, "y": v_y}) - self.m.update_gradients({"x": 0.0, "y": m_y}) - self.c.update_gradients({"x": 0.0, "y": c_y}) - - print(f"t={t:.2f} dx/dv={v_x:.3f} dy/dm={m_y:.3f}") + # Register outputs — gradients computed lazily on query + self.v.set_outputs({"x": x, "y": y}) + self.m.set_outputs({"x": x, "y": y}) + self.c.set_outputs({"x": x, "y": y}) t += DT self.rate.sleep() From d51880b175ae164e70f886982fc636bdd881850b Mon Sep 17 00:00:00 2001 From: kamiradi Date: Wed, 18 Feb 2026 16:16:26 +0000 Subject: [PATCH 15/19] moves queryables into Variable class --- src/ark/node.py | 102 +++++++++++++++++++------------------- test/diff_variable_pub.py | 31 +++++------- 2 files changed, 65 insertions(+), 68 deletions(-) diff --git a/src/ark/node.py b/src/ark/node.py index 79a6094..b222132 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -19,32 +19,54 @@ class Variable: - def __init__(self, name, value, mode="input", out_fields=None): + def __init__(self, name, value, mode, variables_registry, create_queryable_fn): self.name = name self.mode = mode - self.out_fields = out_fields or [] - self.tensor = torch.tensor(value, requires_grad=True) - self._outputs = {} + self._variables_registry = variables_registry + self._grads = {} # input vars: {output_name: grad_value} - def set_output(self, field, tensor): - with _BACKWARD_LOCK: - self._outputs[field] = tensor + if mode == "input": + self._tensor = torch.tensor(value, requires_grad=True) + else: + self._tensor = None + for inp_name, inp_var in variables_registry.items(): + if inp_var.mode == "input": + grad_channel = f"grad/{inp_name}/{name}" + + def _make_handler(iv, ov_name, reg): + def handler(_req): + out_var = reg.get(ov_name) + val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 + grad = iv._grads.get(ov_name, 0.0) + return Value(val=val, grad=grad) + return handler - def set_outputs(self, mapping): - with _BACKWARD_LOCK: - self._outputs.update(mapping) + create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry)) + + @property + def tensor(self): + return self._tensor + + @tensor.setter + def tensor(self, value): + if self.mode == "output": + self._tensor = value + self._compute_and_store_grads() + else: + self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value) - def compute_grad(self, field): + def _compute_and_store_grads(self): + if self._tensor is None or not self._tensor.requires_grad: + return with _BACKWARD_LOCK: - out_tensor = self._outputs.get(field) - if out_tensor is None: - return 0.0, 0.0 - val = float(out_tensor.detach()) - if self.tensor.grad is not None: - 
self.tensor.grad.zero_() - out_tensor.backward(retain_graph=True) - grad = float(self.tensor.grad) if self.tensor.grad is not None else 0.0 - return val, grad + for var in self._variables_registry.values(): + if var.mode == "input" and var._tensor.grad is not None: + var._tensor.grad.zero_() + self._tensor.backward(retain_graph=True) + for var in self._variables_registry.values(): + if var.mode == "input": + grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0 + var._grads[self.name] = grad class BaseNode(Registerable): @@ -138,45 +160,25 @@ def create_queryable(self, channel, handler) -> Queryable: self._queriables[channel] = queryable return queryable - def create_variable(self, name, value, mode="input", out_fields=None): - """Create a differentiable variable with automatic gradient queryables. + def create_variable(self, name, value, mode="input"): + """Create a differentiable variable. - For "input" mode variables with out_fields, this sets up: - - A queryable on "grad/{name}/{field}" for each field in out_fields. - Gradients are computed lazily via backward() when queried, using - output tensors registered by the user via Variable.set_outputs(). - - A subscriber on "param/{name}" that updates the variable's tensor - value when a new parameter is published. + For "input" mode, a subscriber on "param/{name}" is created so that + external nodes can update the tensor value at runtime. + + For "output" mode, queryables are created on "grad/{input_name}/{name}" + for each existing input variable. Setting the tensor triggers an eager + backward pass that caches gradients into each input variable. Args: name: Variable identifier, used in channel names. value: Initial scalar value for the underlying tensor. - mode: "input" creates queryables and a param subscriber. - out_fields: Output field names (e.g. ["x", "y"]) that this - variable contributes to. Each gets a gradient queryable. + mode: "input" or "output". """ - var = Variable(name, value, mode, out_fields) + var = Variable(name, value, mode, self._variables, self.create_queryable) self._variables[name] = var if mode == "input": - # Create a gradient queryable for each output field. - # On query, compute_grad() runs backward() on the registered - # output tensor and returns the gradient w.r.t. this variable. - if var.out_fields: - for field in var.out_fields: - grad_channel = f"grad/{name}/{field}" - - def _make_handler(v, fld): - def handler(_req): - val, grad = v.compute_grad(fld) - return Value(val=val, grad=grad) - - return handler - - self.create_queryable(grad_channel, _make_handler(var, field)) - - # Subscribe to parameter updates so external nodes can set - # this variable's value at runtime. 
def _make_sub_callback(v): def callback(msg): v.tensor.data = torch.tensor(msg.val) diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index c826552..96d97f1 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -1,7 +1,5 @@ -import math -import time from ark.node import BaseNode -from ark_msgs import Translation, Value +from ark_msgs import Translation import argparse import common_example as common import torch @@ -17,31 +15,28 @@ def __init__(self, cfg): self.pos_pub = self.create_publisher("position") self.rate = self.create_rate(HZ) - # Create differentiable input variables — auto-creates grad queryables - # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y - self.v = self.create_variable("v", 0.0, mode="input", out_fields=["x", "y"]) - self.m = self.create_variable("m", 0.0, mode="input", out_fields=["x", "y"]) - self.c = self.create_variable("c", 0.0, mode="input", out_fields=["x", "y"]) + # Input variables get param subscribers; output variables auto-create + # grad queryables (grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y) + self.v = self.create_variable("v", 0.0, mode="input") + self.m = self.create_variable("m", 0.0, mode="input") + self.c = self.create_variable("c", 0.0, mode="input") + self.x = self.create_variable("x", 0.0, mode="output") + self.y = self.create_variable("y", 0.0, mode="output") def spin(self): t = 0.0 while True: t_val = torch.tensor(t, requires_grad=False) - # Forward: y = m * x + c, where x = v * t - x = self.v.tensor * t_val - y = self.m.tensor * x + self.c.tensor + # Forward: x = v * t, y = m * x + c + # Setting output tensors triggers eager backward and caches gradients + self.x.tensor = self.v.tensor * t_val + self.y.tensor = self.m.tensor * self.x.tensor + self.c.tensor - # Publish position self.pos_pub.publish( - Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) + Translation(x=float(self.x.tensor.detach()), y=float(self.y.tensor.detach()), z=0.0) ) - # Register outputs — gradients computed lazily on query - self.v.set_outputs({"x": x, "y": y}) - self.m.set_outputs({"x": x, "y": y}) - self.c.set_outputs({"x": x, "y": y}) - t += DT self.rate.sleep() From 5c1e5a27d3c41c2472514b1b8dda0fd8bd491bfd Mon Sep 17 00:00:00 2001 From: kamiradi Date: Wed, 18 Feb 2026 16:44:03 +0000 Subject: [PATCH 16/19] reorganisation, placed Variable within variable.py --- src/ark/diff/__init__.py | 1 + src/ark/diff/variable.py | 60 ++++++++++++++++++++++++++++++++ src/ark/node.py | 73 ++------------------------------------- test/diff_variable_pub.py | 8 +++-- 4 files changed, 70 insertions(+), 72 deletions(-) create mode 100644 src/ark/diff/__init__.py create mode 100644 src/ark/diff/variable.py diff --git a/src/ark/diff/__init__.py b/src/ark/diff/__init__.py new file mode 100644 index 0000000..6777145 --- /dev/null +++ b/src/ark/diff/__init__.py @@ -0,0 +1 @@ +from ark.diff.variable import Variable diff --git a/src/ark/diff/variable.py b/src/ark/diff/variable.py new file mode 100644 index 0000000..dfa2956 --- /dev/null +++ b/src/ark/diff/variable.py @@ -0,0 +1,60 @@ +import torch +from ark_msgs import Value + + +class Variable: + + def __init__(self, name, value, mode, variables_registry, lock, create_queryable_fn): + self.name = name + self.mode = mode + self._variables_registry = variables_registry + self._lock = lock + self._grads = {} # input vars: {output_name: grad_value} + + if mode == "input": + self._tensor = torch.tensor(value, requires_grad=True) + else: + self._tensor = None + for inp_name, 
inp_var in variables_registry.items(): + if inp_var.mode == "input": + grad_channel = f"grad/{inp_name}/{name}" + + def _make_handler(iv, ov_name, reg, lk): + def handler(_req): + out_var = reg.get(ov_name) + with lk: + val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 + grad = iv._grads.get(ov_name, 0.0) + return Value(val=val, grad=grad) + return handler + + create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock)) + + @property + def tensor(self): + return self._tensor + + @tensor.setter + def tensor(self, value): + if self.mode == "output": + self._tensor = value + self._compute_and_store_grads() + else: + self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value) + + def _is_last_output(self): + output_names = [k for k, v in self._variables_registry.items() if v.mode == "output"] + return output_names and output_names[-1] == self.name + + def _compute_and_store_grads(self): + if self._tensor is None or not self._tensor.requires_grad: + return + with self._lock: + for var in self._variables_registry.values(): + if var.mode == "input" and var._tensor.grad is not None: + var._tensor.grad.zero_() + self._tensor.backward(retain_graph=not self._is_last_output()) + for var in self._variables_registry.values(): + if var.mode == "input": + grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0 + var._grads[self.name] = grad diff --git a/src/ark/node.py b/src/ark/node.py index b222132..d955908 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -1,7 +1,6 @@ import json import time import threading -import torch import zenoh from ark.time.clock import Clock from ark.time.rate import Rate @@ -12,61 +11,7 @@ from ark.comm.queriable import Queryable from ark.data.data_collector import DataCollector from ark.core.registerable import Registerable -from ark_msgs import Value - -_BACKWARD_LOCK = threading.Lock() - - -class Variable: - - def __init__(self, name, value, mode, variables_registry, create_queryable_fn): - self.name = name - self.mode = mode - self._variables_registry = variables_registry - self._grads = {} # input vars: {output_name: grad_value} - - if mode == "input": - self._tensor = torch.tensor(value, requires_grad=True) - else: - self._tensor = None - for inp_name, inp_var in variables_registry.items(): - if inp_var.mode == "input": - grad_channel = f"grad/{inp_name}/{name}" - - def _make_handler(iv, ov_name, reg): - def handler(_req): - out_var = reg.get(ov_name) - val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 - grad = iv._grads.get(ov_name, 0.0) - return Value(val=val, grad=grad) - return handler - - create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry)) - - @property - def tensor(self): - return self._tensor - - @tensor.setter - def tensor(self, value): - if self.mode == "output": - self._tensor = value - self._compute_and_store_grads() - else: - self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value) - - def _compute_and_store_grads(self): - if self._tensor is None or not self._tensor.requires_grad: - return - with _BACKWARD_LOCK: - for var in self._variables_registry.values(): - if var.mode == "input" and var._tensor.grad is not None: - var._tensor.grad.zero_() - self._tensor.backward(retain_graph=True) - for var in self._variables_registry.values(): - if var.mode == "input": - grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0 
- var._grads[self.name] = grad +from ark.diff.variable import Variable class BaseNode(Registerable): @@ -95,6 +40,7 @@ def __init__( self._queriers = {} self._queriables = {} self._variables = {} + self._grad_lock = threading.Lock() self._session.declare_subscriber(f"{env_name}/reset", self._on_reset) @@ -163,9 +109,6 @@ def create_queryable(self, channel, handler) -> Queryable: def create_variable(self, name, value, mode="input"): """Create a differentiable variable. - For "input" mode, a subscriber on "param/{name}" is created so that - external nodes can update the tensor value at runtime. - For "output" mode, queryables are created on "grad/{input_name}/{name}" for each existing input variable. Setting the tensor triggers an eager backward pass that caches gradients into each input variable. @@ -175,18 +118,8 @@ def create_variable(self, name, value, mode="input"): value: Initial scalar value for the underlying tensor. mode: "input" or "output". """ - var = Variable(name, value, mode, self._variables, self.create_queryable) + var = Variable(name, value, mode, self._variables, self._grad_lock, self.create_queryable) self._variables[name] = var - - if mode == "input": - def _make_sub_callback(v): - def callback(msg): - v.tensor.data = torch.tensor(msg.val) - - return callback - - self.create_subscriber(f"param/{name}", _make_sub_callback(var)) - return var def create_rate(self, hz: float): diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index 96d97f1..6a82b16 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -15,14 +15,18 @@ def __init__(self, cfg): self.pos_pub = self.create_publisher("position") self.rate = self.create_rate(HZ) - # Input variables get param subscribers; output variables auto-create - # grad queryables (grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y) + # Output variables auto-create grad queryables: + # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y self.v = self.create_variable("v", 0.0, mode="input") self.m = self.create_variable("m", 0.0, mode="input") self.c = self.create_variable("c", 0.0, mode="input") self.x = self.create_variable("x", 0.0, mode="output") self.y = self.create_variable("y", 0.0, mode="output") + self.create_subscriber("param/v", lambda msg: self.v.tensor.data.fill_(msg.val)) + self.create_subscriber("param/m", lambda msg: self.m.tensor.data.fill_(msg.val)) + self.create_subscriber("param/c", lambda msg: self.c.tensor.data.fill_(msg.val)) + def spin(self): t = 0.0 while True: From cb10a47aaf5cb778f70264b9e2e15d049051c6d5 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Wed, 18 Feb 2026 18:17:56 +0000 Subject: [PATCH 17/19] Registers gradient query channel data on registry. 
Basic query discovery implementation --- src/ark/node.py | 15 ++++++ src/ark/scripts/core.py | 68 +++++++++++++++++++++++-- test/ad_plotter_sub.py | 101 +++++++++++++++++++------------------- test/diff_variable_pub.py | 10 ++-- 4 files changed, 136 insertions(+), 58 deletions(-) diff --git a/src/ark/node.py b/src/ark/node.py index d955908..4ae0c80 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -12,6 +12,7 @@ from ark.data.data_collector import DataCollector from ark.core.registerable import Registerable from ark.diff.variable import Variable +from ark_msgs import VariableInfo class BaseNode(Registerable): @@ -41,6 +42,7 @@ def __init__( self._queriables = {} self._variables = {} self._grad_lock = threading.Lock() + self._registry_pub = self.create_publisher("ark/vars/register") self._session.declare_subscriber(f"{env_name}/reset", self._on_reset) @@ -120,6 +122,19 @@ def create_variable(self, name, value, mode="input"): """ var = Variable(name, value, mode, self._variables, self._grad_lock, self.create_queryable) self._variables[name] = var + + if mode == "output": + grad_channels = [ + f"grad/{inp_name}/{name}" + for inp_name, v in self._variables.items() + if v.mode == "input" + ] + self._registry_pub.publish(VariableInfo( + output_name=name, + node_name=self._node_name, + grad_channels=grad_channels, + )) + return var def create_rate(self, hz: float): diff --git a/src/ark/scripts/core.py b/src/ark/scripts/core.py index 8760b94..638dc76 100644 --- a/src/ark/scripts/core.py +++ b/src/ark/scripts/core.py @@ -1,6 +1,68 @@ -import sys +import argparse +import time +import zenoh +from ark.node import BaseNode +from ark_msgs import VariableInfo + + +class RegistryNode(BaseNode): + + def __init__(self, cfg): + super().__init__("ark", "registry", cfg) + self._var_registry: dict[str, VariableInfo] = {} + self.create_subscriber("ark/vars/register", self._on_register) + + def _on_register(self, msg: VariableInfo): + name = msg.output_name + self._var_registry[name] = msg + channel = f"ark/vars/{name}" + if channel not in self._queriables: + def _make_handler(n): + def handler(_req): + return self._var_registry[n] + return handler + self.create_queryable(channel, _make_handler(name)) + print(f"Registered output variable '{name}' from node '{msg.node_name}' " + f"with channels: {list(msg.grad_channels)}") + + def core_registration(self): + pass + + def close(self): + super().close() def main(): - print(">>Ark core<<") - print(sys.executable) + parser = argparse.ArgumentParser( + prog="ark-core", description="Ark central registry" + ) + parser.add_argument("--mode", "-m", dest="mode", + choices=["peer", "client"], type=str) + parser.add_argument("--connect", "-e", dest="connect", + metavar="ENDPOINT", action="append", type=str) + parser.add_argument("--listen", "-l", dest="listen", + metavar="ENDPOINT", action="append", type=str) + args = parser.parse_args() + + cfg = zenoh.Config() + if args.mode: + import json + cfg.insert_json5("mode", json.dumps(args.mode)) + if args.connect: + import json + cfg.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen: + import json + cfg.insert_json5("listen/endpoints", json.dumps(args.listen)) + + node = RegistryNode(cfg) + print("Ark registry running.") + try: + node.spin() + except KeyboardInterrupt: + print("Shutting down registry.") + node.close() + + +if __name__ == "__main__": + main() diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index a5b32f2..6d89d38 100644 --- a/test/ad_plotter_sub.py +++ 
b/test/ad_plotter_sub.py @@ -1,10 +1,10 @@ +import time import threading import matplotlib.pyplot as plt import matplotlib.animation as animation from ark.node import BaseNode -from ark_msgs import Translation, Value +from ark_msgs import Value, VariableInfo -# from common import connect_cfg, z_cfg import argparse import zenoh import common_example as common @@ -14,41 +14,48 @@ class AutodiffPlotterNode(BaseNode): def __init__(self, cfg, target): super().__init__("env", "autodiff_plotter", cfg, sim=True) self.pos_x, self.pos_y = [], [] - self.grad_vx, self.grad_my = [], [] - self.grad_times = [] - self.create_subscriber("position", self.on_position) - self.grad_vx_querier = self.create_querier("grad/v/x", target=target) - self.grad_my_querier = self.create_querier("grad/m/y", target=target) - - def on_position(self, msg: Translation): - self.pos_x.append(msg.x) - self.pos_y.append(msg.y) + self._grad_queriers = {} # channel -> Querier + self._grad_data = {} # channel -> [float] + self._grad_times = [] + self.create_subscriber("x", self.on_x) + self.create_subscriber("y", self.on_y) + self._discover_grad_channels(["x", "y"], target) + + def _discover_grad_channels(self, output_names, target, timeout=5.0): + for out in output_names: + disc = self.create_querier(f"ark/vars/{out}", target=target) + deadline = time.time() + timeout + while time.time() < deadline: + try: + resp = disc.query(Value()) + if isinstance(resp, VariableInfo): + for ch in resp.grad_channels: + self._grad_queriers[ch] = self.create_querier(ch, target=target) + self._grad_data[ch] = [] + break + except Exception: + time.sleep(0.2) + + def on_x(self, msg: Value): + self.pos_x.append(msg.val) + + def on_y(self, msg: Value): + self.pos_y.append(msg.val) def fetch_grads(self): req = Value() sim_t = self._clock.now() / 1e9 - try: - resp_vx = self.grad_vx_querier.query(req) - if isinstance(resp_vx, Value): - print(f"Received grad_vx: {resp_vx.grad}") - self.grad_vx.append(resp_vx.grad) - except Exception: - pass - try: - resp_my = self.grad_my_querier.query(req) - if isinstance(resp_my, Value): - print(f"Received grad_my: {resp_my.grad}") - self.grad_my.append(resp_my.grad) - except Exception: - pass - self.grad_times.append(sim_t) + for ch, querier in self._grad_queriers.items(): + try: + resp = querier.query(req) + if isinstance(resp, Value): + self._grad_data[ch].append(resp.grad) + except Exception: + pass + self._grad_times.append(sim_t) def main(): - - # These are a few zenoh config related arguments that were taken from the - # examples, keeping them there until we have a better way to manage configs - # across examples parser = argparse.ArgumentParser(description="Autodiff Plotter Node") common.add_config_arguments(parser) parser.add_argument( @@ -68,33 +75,21 @@ def main(): type=float, help="The query timeout", ) - parser.add_argument( - "--iter", dest="iter", type=int, help="How many gets to perform" - ) - parser.add_argument( - "--add-matching-listener", - default=False, - action="store_true", - help="Add matching listener", - ) args = parser.parse_args() conf = common.get_config_from_args(args) - # These were required for the querier and queryable to find each other. 
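+    # Map the CLI flag onto a zenoh QueryTarget: BEST_MATCHING (the zenoh
+    # default) routes a query to the closest complete queryable, ALL fans it
+    # out to every matching queryable, ALL_COMPLETE to all complete ones.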
target = { "ALL": zenoh.QueryTarget.ALL, "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, }.get(args.target) - # Main subcription and querying loop node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() - # Plotting trajectory and gradients fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) - ax_pos.set_title("Position (Translation)") + ax_pos.set_title("Position") ax_pos.set_xlabel("x") ax_pos.set_ylabel("y") ax_pos.set_aspect("equal") @@ -102,21 +97,27 @@ def main(): ax_grad.set_title("Gradients") ax_grad.set_xlabel("sim time (s)") ax_grad.set_ylabel("grad") - (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") - (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") + + colors = plt.cm.tab10.colors + grad_lines = {} + for i, ch in enumerate(node._grad_queriers): + (line,) = ax_grad.plot([], [], color=colors[i % 10], label=ch) + grad_lines[ch] = line ax_grad.legend() def update(frame): node.fetch_grads() - line_pos.set_data(node.pos_x, node.pos_y) + n = min(len(node.pos_x), len(node.pos_y)) + line_pos.set_data(node.pos_x[:n], node.pos_y[:n]) ax_pos.relim() ax_pos.autoscale_view() - times = node.grad_times - line_grad_vx.set_data(times[: len(node.grad_vx)], node.grad_vx) - line_grad_my.set_data(times[: len(node.grad_my)], node.grad_my) + times = node._grad_times + for ch, line in grad_lines.items(): + data = node._grad_data[ch] + line.set_data(times[: len(data)], data) ax_grad.relim() ax_grad.autoscale_view() - return line_pos, line_grad_vx, line_grad_my + return line_pos, *grad_lines.values() ani = animation.FuncAnimation(fig, update, interval=50, blit=False) plt.tight_layout() diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index 6a82b16..4319f56 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -1,5 +1,5 @@ from ark.node import BaseNode -from ark_msgs import Translation +from ark_msgs import Value import argparse import common_example as common import torch @@ -12,7 +12,8 @@ class LineVariableNode(BaseNode): def __init__(self, cfg): super().__init__("env", "line_var_pub", cfg, sim=True) - self.pos_pub = self.create_publisher("position") + self.x_pub = self.create_publisher("x") + self.y_pub = self.create_publisher("y") self.rate = self.create_rate(HZ) # Output variables auto-create grad queryables: @@ -37,9 +38,8 @@ def spin(self): self.x.tensor = self.v.tensor * t_val self.y.tensor = self.m.tensor * self.x.tensor + self.c.tensor - self.pos_pub.publish( - Translation(x=float(self.x.tensor.detach()), y=float(self.y.tensor.detach()), z=0.0) - ) + self.x_pub.publish(Value(val=float(self.x.tensor.detach()))) + self.y_pub.publish(Value(val=float(self.y.tensor.detach()))) t += DT self.rate.sleep() From 04f6095074aa47933b61b52400c5a54e6558bff9 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Thu, 19 Feb 2026 12:05:35 +0000 Subject: [PATCH 18/19] impose temporal correlation between variable and gradients --- src/ark/diff/variable.py | 8 ++++++-- src/ark/node.py | 2 +- test/ad_plotter_sub.py | 13 ++++++++----- test/diff_variable_pub.py | 5 +++-- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/ark/diff/variable.py b/src/ark/diff/variable.py index dfa2956..46055f9 100644 --- a/src/ark/diff/variable.py +++ b/src/ark/diff/variable.py @@ -4,17 +4,19 @@ class Variable: - def __init__(self, name, value, mode, variables_registry, lock, create_queryable_fn): + def __init__(self, name, value, mode, variables_registry, lock, clock, 
create_queryable_fn): self.name = name self.mode = mode self._variables_registry = variables_registry self._lock = lock + self._clock = clock self._grads = {} # input vars: {output_name: grad_value} if mode == "input": self._tensor = torch.tensor(value, requires_grad=True) else: self._tensor = None + self._computation_ts = clock.now() for inp_name, inp_var in variables_registry.items(): if inp_var.mode == "input": grad_channel = f"grad/{inp_name}/{name}" @@ -25,7 +27,8 @@ def handler(_req): with lk: val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 grad = iv._grads.get(ov_name, 0.0) - return Value(val=val, grad=grad) + ts = out_var._computation_ts if out_var else 0 + return Value(val=val, grad=grad, timestamp=ts) return handler create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock)) @@ -58,3 +61,4 @@ def _compute_and_store_grads(self): if var.mode == "input": grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0 var._grads[self.name] = grad + self._computation_ts = self._clock.now() diff --git a/src/ark/node.py b/src/ark/node.py index 4ae0c80..f9188ee 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -120,7 +120,7 @@ def create_variable(self, name, value, mode="input"): value: Initial scalar value for the underlying tensor. mode: "input" or "output". """ - var = Variable(name, value, mode, self._variables, self._grad_lock, self.create_queryable) + var = Variable(name, value, mode, self._variables, self._grad_lock, self._clock, self.create_queryable) self._variables[name] = var if mode == "output": diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 6d89d38..8ea9e90 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -14,9 +14,10 @@ class AutodiffPlotterNode(BaseNode): def __init__(self, cfg, target): super().__init__("env", "autodiff_plotter", cfg, sim=True) self.pos_x, self.pos_y = [], [] + self.pos_x_ts, self.pos_y_ts = [], [] self._grad_queriers = {} # channel -> Querier self._grad_data = {} # channel -> [float] - self._grad_times = [] + self._grad_ts = {} # channel -> [int] self.create_subscriber("x", self.on_x) self.create_subscriber("y", self.on_y) self._discover_grad_channels(["x", "y"], target) @@ -27,32 +28,34 @@ def _discover_grad_channels(self, output_names, target, timeout=5.0): deadline = time.time() + timeout while time.time() < deadline: try: - resp = disc.query(Value()) + resp = disc.query(VariableInfo()) if isinstance(resp, VariableInfo): for ch in resp.grad_channels: self._grad_queriers[ch] = self.create_querier(ch, target=target) self._grad_data[ch] = [] + self._grad_ts[ch] = [] break except Exception: time.sleep(0.2) def on_x(self, msg: Value): self.pos_x.append(msg.val) + self.pos_x_ts.append(msg.timestamp) def on_y(self, msg: Value): self.pos_y.append(msg.val) + self.pos_y_ts.append(msg.timestamp) def fetch_grads(self): req = Value() - sim_t = self._clock.now() / 1e9 for ch, querier in self._grad_queriers.items(): try: resp = querier.query(req) if isinstance(resp, Value): self._grad_data[ch].append(resp.grad) + self._grad_ts[ch].append(resp.timestamp) except Exception: pass - self._grad_times.append(sim_t) def main(): @@ -111,9 +114,9 @@ def update(frame): line_pos.set_data(node.pos_x[:n], node.pos_y[:n]) ax_pos.relim() ax_pos.autoscale_view() - times = node._grad_times for ch, line in grad_lines.items(): data = node._grad_data[ch] + times = [t / 1e9 for t in node._grad_ts[ch]] line.set_data(times[: len(data)], data) ax_grad.relim() 
ax_grad.autoscale_view() diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index 4319f56..eb6934e 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -38,8 +38,9 @@ def spin(self): self.x.tensor = self.v.tensor * t_val self.y.tensor = self.m.tensor * self.x.tensor + self.c.tensor - self.x_pub.publish(Value(val=float(self.x.tensor.detach()))) - self.y_pub.publish(Value(val=float(self.y.tensor.detach()))) + ts = self._clock.now() + self.x_pub.publish(Value(val=float(self.x.tensor.detach()), timestamp=ts)) + self.y_pub.publish(Value(val=float(self.y.tensor.detach()), timestamp=ts)) t += DT self.rate.sleep() From e5334e76e660902d588adef30770c4148f95a3be Mon Sep 17 00:00:00 2001 From: kamiradi Date: Thu, 19 Feb 2026 17:32:55 +0000 Subject: [PATCH 19/19] adds functionality to query gradient at a time step --- src/ark/diff/variable.py | 16 ++++++++++++ src/ark/node.py | 1 + test/ad_plotter_sub.py | 44 ++++++++++++++++++++++++++++--- test/diff_variable_pub.py | 54 +++++++++++++++++++++++++++++---------- 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/src/ark/diff/variable.py b/src/ark/diff/variable.py index 46055f9..3330c97 100644 --- a/src/ark/diff/variable.py +++ b/src/ark/diff/variable.py @@ -14,9 +14,12 @@ def __init__(self, name, value, mode, variables_registry, lock, clock, create_qu if mode == "input": self._tensor = torch.tensor(value, requires_grad=True) + self._history = {} + self._replay_tensor = None else: self._tensor = None self._computation_ts = clock.now() + self._replay_fn = None for inp_name, inp_var in variables_registry.items(): if inp_var.mode == "input": grad_channel = f"grad/{inp_name}/{name}" @@ -24,6 +27,9 @@ def __init__(self, name, value, mode, variables_registry, lock, clock, create_qu def _make_handler(iv, ov_name, reg, lk): def handler(_req): out_var = reg.get(ov_name) + if _req.timestamp != 0 and out_var._replay_fn: + val, grad = out_var._replay_fn(_req.timestamp, iv.name, ov_name) + return Value(val=val, grad=grad, timestamp=_req.timestamp) with lk: val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 grad = iv._grads.get(ov_name, 0.0) @@ -33,6 +39,16 @@ def handler(_req): create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock)) + def snapshot(self, ts): + """Record current tensor value at clock timestamp ts.""" + self._history[ts] = float(self._tensor.detach()) + + def at(self, ts): + """Return a fresh requires_grad tensor from history at ts.""" + val = self._history[ts] + self._replay_tensor = torch.tensor(val, requires_grad=True) + return self._replay_tensor + @property def tensor(self): return self._tensor diff --git a/src/ark/node.py b/src/ark/node.py index f9188ee..1a01706 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -1,6 +1,7 @@ import json import time import threading +import torch import zenoh from ark.time.clock import Clock from ark.time.rate import Rate diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 8ea9e90..e8b3777 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -57,6 +57,18 @@ def fetch_grads(self): except Exception: pass + def fetch_grads_at(self, ts): + req = Value(timestamp=ts) + results = {} + for ch, querier in self._grad_queriers.items(): + try: + resp = querier.query(req) + if isinstance(resp, Value): + results[ch] = (resp.val, resp.grad) + except Exception: + pass + return results + def main(): parser = argparse.ArgumentParser(description="Autodiff Plotter Node") @@ 
-91,13 +103,13 @@ def main(): node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() - fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) + fig, (ax_pos, ax_grad, ax_replay) = plt.subplots(1, 3, figsize=(18, 5)) ax_pos.set_title("Position") ax_pos.set_xlabel("x") ax_pos.set_ylabel("y") ax_pos.set_aspect("equal") (line_pos,) = ax_pos.plot([], [], "b-") - ax_grad.set_title("Gradients") + ax_grad.set_title("Gradients (live)") ax_grad.set_xlabel("sim time (s)") ax_grad.set_ylabel("grad") @@ -108,8 +120,28 @@ def main(): grad_lines[ch] = line ax_grad.legend() + ax_replay.set_title("Gradients (replay)") + ax_replay.set_xlabel("sim time (s)") + ax_replay.set_ylabel("grad") + replay_data = {ch: [] for ch in node._grad_queriers} + replay_ts = {ch: [] for ch in node._grad_queriers} + replay_lines = {} + for i, ch in enumerate(node._grad_queriers): + (line,) = ax_replay.plot([], [], color=colors[i % 10], label=ch) + replay_lines[ch] = line + ax_replay.legend() + def update(frame): node.fetch_grads() + + # Replay: query gradient at a historical timestamp + if len(node.pos_x_ts) > 10: + historical_ts = node.pos_x_ts[-10] + results = node.fetch_grads_at(historical_ts) + for ch, (val, grad) in results.items(): + replay_data[ch].append(grad) + replay_ts[ch].append(historical_ts) + n = min(len(node.pos_x), len(node.pos_y)) line_pos.set_data(node.pos_x[:n], node.pos_y[:n]) ax_pos.relim() @@ -120,7 +152,13 @@ def update(frame): line.set_data(times[: len(data)], data) ax_grad.relim() ax_grad.autoscale_view() - return line_pos, *grad_lines.values() + for ch, line in replay_lines.items(): + data = replay_data[ch] + times = [t / 1e9 for t in replay_ts[ch]] + line.set_data(times[: len(data)], data) + ax_replay.relim() + ax_replay.autoscale_view() + return line_pos, *grad_lines.values(), *replay_lines.values() ani = animation.FuncAnimation(fig, update, interval=50, blit=False) plt.tight_layout() diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py index eb6934e..de18f63 100644 --- a/test/diff_variable_pub.py +++ b/test/diff_variable_pub.py @@ -14,7 +14,6 @@ def __init__(self, cfg): super().__init__("env", "line_var_pub", cfg, sim=True) self.x_pub = self.create_publisher("x") self.y_pub = self.create_publisher("y") - self.rate = self.create_rate(HZ) # Output variables auto-create grad queryables: # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y @@ -28,22 +27,49 @@ def __init__(self, cfg): self.create_subscriber("param/m", lambda msg: self.m.tensor.data.fill_(msg.val)) self.create_subscriber("param/c", lambda msg: self.c.tensor.data.fill_(msg.val)) - def spin(self): - t = 0.0 - while True: - t_val = torch.tensor(t, requires_grad=False) + self.x._replay_fn = self._replay_grad + self.y._replay_fn = self._replay_grad - # Forward: x = v * t, y = m * x + c - # Setting output tensors triggers eager backward and caches gradients - self.x.tensor = self.v.tensor * t_val - self.y.tensor = self.m.tensor * self.x.tensor + self.c.tensor + self.create_stepper(HZ, self.step) - ts = self._clock.now() - self.x_pub.publish(Value(val=float(self.x.tensor.detach()), timestamp=ts)) - self.y_pub.publish(Value(val=float(self.y.tensor.detach()), timestamp=ts)) + def forward(self, ts, replay=False): + """Compute outputs from inputs at a given timestamp. - t += DT - self.rate.sleep() + Builds the computation graph parameterised by ts so that + gradients can later be evaluated at arbitrary times. + When replay=True, uses historical input values at ts. 
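+
+        ts is a clock timestamp in nanoseconds; it is converted to seconds
+        (ts / 1e9) before entering the graph, so the live and replay paths
+        build the same function of time.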
+        """
+        if replay:
+            v, m, c = self.v.at(ts), self.m.at(ts), self.c.at(ts)
+        else:
+            v, m, c = self.v.tensor, self.m.tensor, self.c.tensor
+
+        t_val = torch.tensor(ts / 1e9, requires_grad=False)
+        x = v * t_val
+        y = m * x + c
+        return x, y
+
+    def _replay_grad(self, ts, input_name, output_name):
+        x, y = self.forward(ts, replay=True)
+        outputs = {"x": x, "y": y}
+        inp_var = self._variables[input_name]
+        (grad,) = torch.autograd.grad(
+            outputs[output_name],
+            inp_var._replay_tensor,
+            retain_graph=True,
+            allow_unused=True,
+        )
+        val = float(outputs[output_name].detach())
+        return val, float(grad) if grad is not None else 0.0
+
+    def step(self, ts):
+        x, y = self.forward(ts)
+
+        # Setting output tensors triggers eager backward and caches gradients
+        self.x.tensor = x
+        self.y.tensor = y
+
+        # Snapshot input values at this timestamp
+        self.v.snapshot(ts)
+        self.m.snapshot(ts)
+        self.c.snapshot(ts)
+
+        self.x_pub.publish(Value(val=float(self.x.tensor.detach()), timestamp=ts))
+        self.y_pub.publish(Value(val=float(self.y.tensor.detach()), timestamp=ts))

if __name__ == "__main__":
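The replay path above rebuilds the computation graph from snapshotted inputs and differentiates it at a past timestep. The same mechanics in plain PyTorch, stripped of the messaging layer (the `history` dict and the helper names are illustrative, not part of the series):

```python
import torch

history = {}  # ts (ns) -> input values recorded by the live loop


def snapshot(ts, v, m, c):
    history[ts] = {"v": float(v), "m": float(m), "c": float(c)}


def replay_grad(ts, wrt, out):
    # Rebuild fresh leaf tensors from the values recorded at ts...
    leaves = {
        k: torch.tensor(val, requires_grad=True) for k, val in history[ts].items()
    }
    t = ts / 1e9  # nanoseconds -> seconds, as in forward()
    # ...re-run the forward pass...
    x = leaves["v"] * t
    y = leaves["m"] * x + leaves["c"]
    outs = {"x": x, "y": y}
    # ...and differentiate the requested output w.r.t. the requested input.
    (g,) = torch.autograd.grad(outs[out], leaves[wrt], allow_unused=True)
    return float(outs[out].detach()), float(g) if g is not None else 0.0


snapshot(2_000_000_000, v=1.0, m=0.5, c=0.0)  # state at sim time t = 2 s
val, grad = replay_grad(2_000_000_000, wrt="m", out="y")
assert (val, grad) == (1.0, 2.0)  # y = 0.5 * 2.0 = 1.0, dy/dm = x = 2.0
```

As in `_replay_grad`, `allow_unused=True` covers inputs that do not appear in the requested output (for example, `m` never influences `x`); those gradients come back as `None` and are reported as 0.0.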