-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsqlite.py
More file actions
423 lines (344 loc) · 15.6 KB
/
sqlite.py
File metadata and controls
423 lines (344 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
# coding=utf-8
"""
Client for an external SQLite service reached over a TCP socket.

The module provides :class:`Sqlite`, which sends SQL text to the server and returns
each reply as a :class:`QueryResult` of :class:`Row` objects, plus
:func:`decode_blob_literal` for turning BLOB values (sent as ``X'..'`` hex strings)
back into ``bytes``.

Wire format: every message is UTF-8 JSON prefixed with a 4-byte unsigned
little-endian length header.
"""
import json
import math
import re
import socket
import struct
from typing import Optional, Any, List
# --- Connection defaults ------------------------------------------------------
# Fallback values for Sqlite.__init__ when the caller does not pass them.
_SQLITE_IP: str = "127.0.0.1"  # server host
_SQLITE_PORT: int = 3333  # server TCP port
_SQLITE_AUTH: str = ""  # an empty password means the auth handshake is skipped

# BLOB values travel as the JSON string X'<hex>' because raw bytes cannot be put
# into JSON (per the server's RequestHandler.cpp). It is the same literal form that
# Sqlite._serialize_sql_value emits for bytes parameters.
_BLOB_LITERAL_RE = re.compile(r"^[Xx]'([0-9A-Fa-f]*)'$")
def decode_blob_literal(value: Any) -> Optional[bytes]:
    """
    Decodes a BLOB value returned by the server into ``bytes``.

    BLOB columns arrive as a hex string literal ``X'..'`` (the server cannot put raw
    binary into JSON). This reverses that encoding.

    Decoding must be explicit because the response carries no type information: a TEXT
    column literally containing ``X'00ff'`` is indistinguishable from a BLOB, so callers
    opt in per column rather than risk corrupting text values.

    :param value: ``None``, raw ``bytes``/``bytearray``/``memoryview``, or an ``X'..'``
        string literal.
    :return: The decoded ``bytes``, or ``None`` for a NULL value.
    :raises ValueError: If ``value`` is any other type, or a string that is not exactly
        a valid ``X'..'`` literal (trailing characters such as a newline, or an odd
        number of hex digits, are rejected).
    """
    if value is None:
        return None
    if isinstance(value, (bytes, bytearray, memoryview)):
        # Already binary: just normalise the container type to bytes.
        return bytes(value)
    if not isinstance(value, str):
        raise ValueError(f"Not a BLOB literal: {value!r}")
    # fullmatch, not match: the pattern ends in '$', which also matches just before a
    # trailing newline, so match() would accept "X'00'\n" as a literal.
    match = _BLOB_LITERAL_RE.fullmatch(value)
    if not match:
        raise ValueError(f"Not a BLOB literal: {value!r}")
    hex_digits = match.group(1)
    if len(hex_digits) % 2:
        # bytes.fromhex would reject this too, but with a far less helpful message.
        raise ValueError(f"BLOB literal has an odd number of hex digits: {value!r}")
    return bytes.fromhex(hex_digits)
class Row(dict):
    """
    A single result row.

    Behaves exactly like a ``dict`` (so ``row["col"]``, ``.get()``, ``in`` and JSON
    serialization all keep working) but additionally exposes columns as attributes::

        row["name"] == row.name

    Attribute *reads* fall back to the columns only when normal attribute lookup fails.
    A column whose name collides with a ``dict`` method or with :meth:`blob`
    (``keys``, ``items``, ``get``, ``copy``, ...) must therefore be read as
    ``row["keys"]``. Attribute *assignment* and *deletion* always act on the columns.

    BLOB columns arrive as ``X'..'`` hex strings; use :meth:`blob` to decode one to bytes.
    """

    def __getattr__(self, name: str) -> Any:
        # Only called after regular lookup (instance, class, dict methods) has failed.
        try:
            return self[name]
        except KeyError:
            # ``from None`` drops the internal KeyError from the traceback, so a missing
            # column looks like any other missing attribute (and getattr() defaults work).
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        # Store attributes as columns, so row.x = 1 is visible as row["x"].
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

    def blob(self, key: str) -> Optional[bytes]:
        """
        Returns column ``key`` decoded from its ``X'..'`` hex literal into ``bytes``.

        Returns ``None`` when the column is NULL. Raises ``ValueError`` if the column does
        not hold a BLOB literal (i.e. it was a plain TEXT/number value), and ``KeyError``
        if the row has no such column.
        """
        return decode_blob_literal(self[key])
class QueryResult:
    """
    Read-only view over one query response.

    The server answers with a JSON object such as ``{"data": [{col: val, ...}, ...]}``.
    Instances present the ``"data"`` rows as an iterable, sized, truthy sequence of
    :class:`Row` objects and keep the older ``result["data"]`` spelling working.

    Construction never fails. ``None``, a bare list of rows, or any other unexpected
    payload is normalised, so there is always a (possibly empty) result to inspect.
    Note that ``in`` tests the keys of the raw payload (e.g. ``"data" in result``), not
    the rows.

    Example:
        result = db.query("SELECT COUNT(*) AS count FROM users")
        if result:                    # False when no rows came back
            total = result.scalar()   # first column of the first row
        for row in result:            # rows iterate directly...
            print(row.name)           # ...and expose columns as attributes
    """

    __slots__ = ("_payload",)

    def __init__(self, payload: Any = None):
        # Step 1: coerce whatever arrived into a dict.
        if isinstance(payload, dict):
            body = payload
        elif isinstance(payload, list):
            body = {"data": payload}
        else:
            body = {}

        # Step 2: if there is a row list, keep only dict entries, promoting plain
        # dicts to Row (existing Row instances are reused as-is).
        raw_rows = body.get("data")
        if isinstance(raw_rows, list):
            wrapped = []
            for item in raw_rows:
                if isinstance(item, Row):
                    wrapped.append(item)
                elif isinstance(item, dict):
                    wrapped.append(Row(item))
            body = dict(body, data=wrapped)

        self._payload = body

    @property
    def rows(self) -> List[Row]:
        """Rows as :class:`Row` objects; an empty list (never ``None``) when absent."""
        found = self._payload.get("data")
        if isinstance(found, list):
            return found
        return []

    @property
    def columns(self) -> List[str]:
        """
        Column names in ``SELECT`` order; an empty list when the payload has none.

        Prefer this over a row's ``keys()`` for ordering. The server serialises each row
        as a JSON object whose keys come back alphabetically sorted, so row keys do not
        reflect the query's order. ``columns`` is available even for an empty result.
        """
        found = self._payload.get("columns")
        if isinstance(found, list):
            return found
        return []

    # --- Sequence behaviour over the rows --------------------------------

    def __iter__(self):
        return iter(self.rows)

    def __len__(self) -> int:
        return len(self.rows)

    def __bool__(self) -> bool:
        return bool(self.rows)

    def __getitem__(self, key):
        # result["data"] always yields the row list, even if the payload lacks it.
        if key == "data":
            return self.rows
        # Non-string keys (ints, slices) index into the rows...
        if not isinstance(key, str):
            return self.rows[key]
        # ...while any other string reads the raw payload (KeyError if missing).
        return self._payload[key]

    def __contains__(self, key) -> bool:
        return key in self._payload

    def __repr__(self) -> str:
        return f"QueryResult(rows={len(self)})"

    # --- Shortcuts --------------------------------------------------------

    def first(self) -> Optional[Row]:
        """Returns the first row, or ``None`` for an empty result."""
        for row in self.rows:
            return row
        return None

    def scalar(self, default: Any = None) -> Any:
        """
        Returns the first column of the first row, or ``default`` if there is none.

        Handy for single-value queries such as ``COUNT(*)`` or ``MAX(...)``. The column
        is taken from ``columns[0]`` when known, because a row's own key order is
        alphabetical rather than ``SELECT`` order; otherwise the row's first key is used.
        """
        leading = self.first()
        if not leading:
            return default
        names = self.columns
        if names:
            key = names[0]
        else:
            key = next(iter(leading), None)
        if key is None:
            return default
        return leading.get(key, default)

    def column(self, name: str) -> List[Any]:
        """Collects column ``name`` from every row that has it, in row order."""
        return [entry[name] for entry in self.rows if name in entry]

    def get(self, key, default=None):
        """Reads ``key`` from the raw payload (e.g. metadata), like ``dict.get``."""
        return self._payload.get(key, default)
class Sqlite:
    """
    A client for connecting to and executing queries against an external SQLite server via TCP sockets.

    Each request and each reply is a UTF-8 JSON document preceded by a 4-byte unsigned
    little-endian length header. Query parameters are not sent separately. They are
    escaped and substituted into the SQL text on the client side before sending.

    Supports context management (with-statement) for safe resource cleanup.

    Example:
        with Sqlite("my_database") as db:
            result = db.query("SELECT * FROM users WHERE id = ?", [1])
    """

    # Tokenizer for _client_side_prepare. String literals, quoted identifiers and
    # comments are captured in the "verbatim" group and copied through unchanged, so a
    # '?' inside them is never mistaken for a placeholder. Any other match is a '?' or
    # '?N' placeholder, with N (if present) in the "index" group.
    _PLACEHOLDER_RE = re.compile(
        r"(?P<verbatim>"
        r"'(?:[^']|'')*'"  # 'string literal' ('' is an escaped quote)
        r'|"(?:[^"]|"")*"'  # "quoted identifier"
        r"|`(?:[^`]|``)*`"  # `quoted identifier`
        r"|\[[^\]]*\]"  # [quoted identifier]
        r"|--[^\n]*"  # -- line comment
        r"|/\*.*?(?:\*/|\Z)"  # /* block comment */ (possibly unterminated)
        r")"
        r"|\?(?P<index>\d+)?",  # ? or ?N placeholder
        re.DOTALL,
    )

    def __init__(
        self,
        database: str,
        ip: str = _SQLITE_IP,
        port: int = _SQLITE_PORT,
        auth: str = _SQLITE_AUTH,
        timeout: Optional[float] = 120.0,
    ):
        """
        Initializes the Sqlite client and establishes a TCP connection.

        If ``auth`` is non-empty, the ``{"auth": ...}`` handshake is sent immediately
        after connecting (required when the server is started with ``--auth``).

        :param database: The name or path of the database file on the server.
        :param ip: Server IP/host to connect to (defaults to ``_SQLITE_IP``).
        :param port: Server TCP port (defaults to ``_SQLITE_PORT``).
        :param auth: Authentication password; an empty string skips the handshake
                     (defaults to ``_SQLITE_AUTH``).
        :param timeout: Seconds allowed for connecting and for each send/receive before
                        ``socket.timeout`` is raised; ``None`` blocks indefinitely
                        (defaults to 120).
        :raises OSError: If connecting or the auth handshake fails. The socket is closed
                         before the error propagates.
        """
        self._database = database
        self._auth = auth
        # create_connection resolves host names (IPv4 or IPv6), applies the timeout to
        # connect() as well, and closes its own socket if every connection attempt fails.
        self._sock = socket.create_connection((ip, port), timeout=timeout)
        if auth:
            try:
                self._send_auth()
            except BaseException:
                # Do not leak the connected socket when the handshake fails.
                self.close()
                raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self) -> None:
        """
        Closes the socket connection safely.

        Both directions are shut down before the socket is closed, for a graceful
        disconnect. Errors from an already closed or reset connection are ignored, so
        calling this more than once is harmless.
        """
        sock = getattr(self, "_sock", None)
        if sock is None:
            return
        try:
            # Shutdown tells the other end we are done sending/receiving.
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Connection might already be closed or reset.
            pass
        finally:
            try:
                sock.close()
            except OSError:
                pass

    def query(self, query: str, params: Optional[List[Any]] = None) -> QueryResult:
        """
        Executes a SQL query on the server and returns the result.

        If parameters are provided, they are safely escaped client-side and injected into
        the query string before transmission. The server responds with JSON-encoded data.

        Errors are reported with ``print`` rather than raised. A socket error (including
        a timeout) also closes the connection, because the request or reply may have been
        only partly transferred. A late reply would otherwise be read as the answer to
        the next query. Later calls on this client then fail and are reported as well.

        :param query: The SQL query string, optionally containing '?' or '?N' placeholders.
        :param params: A list of parameter values to bind to the placeholders.
        :return: A :class:`QueryResult` wrapping the response. It is always returned
                 (never None); an error or empty response yields an empty result.
        """
        try:
            if params:
                query = self._client_side_prepare(query, params)
            data = self._send_query(query)
            payload = json.loads(data) if data else None
            # The server reports failures as a JSON object with no "data" key (e.g.
            # {"error_code", "error_message"} or {"generic_error"}). Surface it instead of
            # silently degrading to an empty result set.
            if isinstance(payload, dict) and (
                "error_message" in payload or "error_code" in payload or "generic_error" in payload
            ):
                print(f"Sqlite.query server error: {payload.get('error_message', payload)}")
            return QueryResult(payload)
        except OSError as e:
            print(f"Sqlite.query error: {e}")
            # The byte stream may now be out of sync with our requests; drop it.
            self.close()
        except Exception as e:
            print(f"Sqlite.query error: {e}")
        return QueryResult()

    def send_query(self, query: str, params: Optional[List[Any]] = None) -> None:
        """
        Executes a SQL query without expecting a return value.

        Useful for INSERT, UPDATE, or DELETE operations where the result set is not
        needed. Errors are handled (printed) exactly as in :meth:`query`.

        :param query: The SQL query string.
        :param params: A list of parameter values to bind.
        """
        self.query(query, params)

    def _client_side_prepare(self, query: str, params: List[Any]) -> str:
        """
        Replaces '?' and '?N' placeholders with sanitized parameter values.

        The query is scanned once, left to right, so a substituted value is never
        re-scanned and a '?' inside a parameter value stays escaped. String literals,
        quoted identifiers and comments in the query are copied verbatim, so a '?'
        written inside them is not treated as a placeholder.

        - '?N' (positional) binds to ``params[N - 1]``; an out-of-range N is left as-is.
        - '?' (standard) binds to the next not-yet-consumed standard parameter,
          independently of any '?N' placeholders. This differs from SQLite's own
          numbering when the two styles are mixed. Once params are exhausted, '?' is
          left as-is.

        Named placeholders (``:name``, ``@name``, ``$name``) are not recognised.

        :param query: The raw SQL query.
        :param params: The parameter values.
        :return: The formatted query string with parameters safely injected.
        """
        # Standard '?' placeholders are consumed in order, independently of positional ones.
        param_iter = iter(params)

        def replacement(match):
            if match.group("verbatim") is not None:
                return match.group(0)  # literal, quoted identifier or comment
            digits = match.group("index")
            if digits is not None:
                # Positional parameter (?N): SQL ?1 maps to index 0.
                index = int(digits) - 1
                if not 0 <= index < len(params):
                    return match.group(0)
                text = self._serialize_sql_value(params[index])
            else:
                # Standard parameter (?): take the next available value.
                try:
                    text = self._serialize_sql_value(next(param_iter))
                except StopIteration:
                    return "?"  # No more params left
            # "10-?" with -5 would become "10--5", and "--" starts a SQL comment.
            # Separate the two minus signs with a space.
            if text.startswith("-") and query[match.start() - 1:match.start()] == "-":
                text = " " + text
            return text

        return self._PLACEHOLDER_RE.sub(replacement, query)

    @staticmethod
    def _serialize_sql_value(value: Any) -> str:
        """
        Sanitizes Python types into SQL-safe string literals to prevent injection.

        - ``None`` becomes ``NULL``, and ``bool`` becomes ``1``/``0``.
        - ``int``/``float`` (including subclasses) become plain numeric literals. NaN
          becomes ``NULL`` and +/-infinity becomes ``9e999``/``-9e999``.
        - ``bytes``/``bytearray``/``memoryview`` become a BLOB literal ``X'..'``.
        - Anything else goes through ``str()`` as a single-quoted string with quotes doubled.

        :param value: The Python value to serialize.
        :return: The SQL string literal.
        """
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "1" if value else "0"
        if isinstance(value, int):
            # int() drops subclasses such as IntEnum, whose str() on Python < 3.11 is
            # the member name (e.g. "Color.RED") rather than a number.
            return str(int(value))
        if isinstance(value, float):
            # SQLite has no NaN/Infinity literals: NaN -> NULL, +/-Inf -> +/-9e999.
            if math.isnan(value):
                return "NULL"
            if math.isinf(value):
                return "9e999" if value > 0 else "-9e999"
            return repr(float(value))
        if isinstance(value, (bytes, bytearray, memoryview)):
            # Emit a BLOB literal X'..' so binary data round-trips without corruption.
            return f"X'{value.hex()}'"
        # Escape single quotes by doubling them for security.
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _send_auth(self) -> Optional[str]:
        """
        Sends the ``{"auth": <password>}`` handshake and returns the raw reply.

        :return: The server's reply string, or None if the connection closed early.
        """
        # NOTE(review): the reply is not inspected, so a rejected password is not
        # detected here. Confirm what the server sends on failure before checking it.
        payload = {"auth": self._auth}
        self._send_data(json.dumps(payload))
        return self._recv_data()

    def _send_query(self, query: str) -> Optional[str]:
        """
        Constructs the JSON payload and sends it over the socket.

        :param query: The final, prepared SQL query string.
        :return: The raw string response from the server.
        """
        payload = {
            "db": self._database,
            "cmd": "QUERY",
            "query": query
        }
        self._send_data(json.dumps(payload))
        return self._recv_data()

    def _send_data(self, data: str) -> None:
        """
        Encodes the string data and prepends a 4-byte little-endian length header before sending.

        :param data: The JSON payload string.
        :raises struct.error: If the encoded payload does not fit in 32 bits.
        """
        encoded_data = data.encode("utf-8")
        # Header is a 4-byte unsigned little-endian length, matching the server's uint32_t.
        header = struct.pack("<I", len(encoded_data))
        self._sock.sendall(header + encoded_data)

    def _recv_data(self) -> Optional[str]:
        """
        Reads the 4-byte header to determine payload size, then reads the full payload.

        :return: The decoded string response, or None if the connection closed early
                 (or the payload was empty).
        """
        # 1. Read 4-byte header
        header = self._read_n_bytes(4)
        if not header:
            return None
        size, = struct.unpack("<I", header)

        # 2. Read the full payload based on header size
        payload = self._read_n_bytes(size)
        return payload.decode("utf-8") if payload else None

    def _read_n_bytes(self, n: int) -> Optional[bytes]:
        """
        Receives exactly ``n`` bytes from the socket.

        The buffer is allocated once and filled in place with ``recv_into``. This avoids
        re-copying everything received so far on every chunk (``data += chunk``) and
        requesting a full-size receive buffer on each call.

        :param n: The number of bytes to read.
        :return: The bytes read, or None if the peer closed the connection before ``n``
                 bytes arrived.
        """
        buffer = bytearray(n)
        view = memoryview(buffer)
        received = 0
        while received < n:
            count = self._sock.recv_into(view[received:], n - received)
            if count == 0:  # Connection closed
                return None
            received += count
        return bytes(buffer)