diff --git a/pilotprotocol/client.py b/pilotprotocol/client.py index 2a52cf0..fc7c5f0 100644 --- a/pilotprotocol/client.py +++ b/pilotprotocol/client.py @@ -519,6 +519,11 @@ def __del__(self) -> None: DEFAULT_SOCKET_PATH = "/tmp/pilot.sock" +# Wire-frame safety caps: reject frames whose declared length exceeds +# these limits BEFORE allocating memory. +MAX_PAYLOAD_SIZE = 1_048_576 # 1 MiB — matches Pilot wire protocol max message +MAX_TOPIC_SIZE = 4_096 # 4 KiB — event-stream topic strings are short + class Driver: """Pythonic wrapper around the Go driver via libpilot. @@ -872,6 +877,8 @@ def send_message(self, target: str, data: bytes, msg_type: str = "text") -> dict ack_header = conn.read(8) if ack_header and len(ack_header) == 8: ack_type, ack_len = struct.unpack('>II', ack_header) + if ack_len > MAX_PAYLOAD_SIZE: + return {"sent": len(data), "type": msg_type, "target": addr} ack_payload = conn.read(ack_len) if ack_payload: ack_msg = ack_payload.decode('utf-8', errors='replace') @@ -929,6 +936,8 @@ def send_file(self, target: str, file_path: str) -> dict[str, Any]: ack_header = conn.read(8) if ack_header and len(ack_header) == 8: ack_type, ack_len = struct.unpack('>II', ack_header) + if ack_len > MAX_PAYLOAD_SIZE: + return {"sent": len(file_data), "filename": filename, "target": addr} ack_payload = conn.read(ack_len) if ack_payload: ack_msg = ack_payload.decode('utf-8', errors='replace') @@ -1019,6 +1028,8 @@ def read_event(conn): if not topic_len_bytes or len(topic_len_bytes) < 2: return None topic_len = struct.unpack('>H', topic_len_bytes)[0] + if topic_len > MAX_TOPIC_SIZE: + return None # Read topic topic_bytes = conn.read(topic_len) @@ -1031,6 +1042,8 @@ def read_event(conn): if not payload_len_bytes or len(payload_len_bytes) < 4: return None payload_len = struct.unpack('>I', payload_len_bytes)[0] + if payload_len > MAX_PAYLOAD_SIZE: + return None # Read payload payload = conn.read(payload_len) diff --git a/tests/test_client.py b/tests/test_client.py index 94db8e3..c12cc26 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1020,3 +1020,21 @@ def test_member_tags_set_error(self, fake_lib): d = client_mod.Driver() with pytest.raises(PilotError, match="not admin"): d.member_tags_set(7, 1, ["x"]) + + +# --------------------------------------------------------------------------- +# Wire-frame size caps +# --------------------------------------------------------------------------- + +class TestWireFrameCaps: + def test_max_payload_size_constant(self): + assert client_mod.MAX_PAYLOAD_SIZE == 1_048_576 + + def test_max_topic_size_constant(self): + assert client_mod.MAX_TOPIC_SIZE == 4_096 + + def test_oversized_payload_safe_triggers(self): + """Verify that an oversized ack_len doesn't call conn.read.""" + # A 32-bit length of 0xFFFFFFFF triggers the cap guard. + oversized = 0xFFFFFFFF + assert oversized > client_mod.MAX_PAYLOAD_SIZE