Skip to content

Commit f4ea3b2

Browse files
committed
Add statement level query tag support by introducing it as a parameter on execute* methods
1 parent 61f8029 commit f4ea3b2

File tree

6 files changed

+201
-7
lines changed

6 files changed

+201
-7
lines changed

examples/query_tags_example.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,23 @@
77
Query Tags are key-value pairs that can be attached to SQL executions and will appear
88
in the system.query.history table for analytical purposes.
99
10-
Format: "key1:value1,key2:value2,key3:value3"
10+
There are two ways to set query tags:
11+
1. Session-level: Set in session_configuration (applies to all queries in the session)
12+
2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query)
13+
14+
Format: Dictionary with string keys and optional string values
15+
Example: {"team": "engineering", "application": "etl", "priority": "high"}
16+
17+
Special cases:
18+
- If a value is None, only the key is included (no colon or value)
19+
- Special characters (:, ,, \\) in values are automatically escaped
20+
- Keys are not escaped (should be controlled identifiers)
1121
"""
1222

1323
print("=== Query Tags Example ===\n")
1424

25+
# Example 1: Session-level query tags (old approach)
26+
print("Example 1: Session-level query tags")
1527
with sql.connect(
1628
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
1729
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
@@ -21,10 +33,64 @@
2133
'ansi_mode': False
2234
}
2335
) as connection:
24-
36+
2537
with connection.cursor() as cursor:
2638
cursor.execute("SELECT 1")
2739
result = cursor.fetchone()
2840
print(f" Result: {result[0]}")
2941

42+
print()
43+
44+
# Example 2: Per-query query tags (new approach)
45+
print("Example 2: Per-query query tags")
46+
with sql.connect(
47+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
48+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
49+
access_token=os.getenv("DATABRICKS_TOKEN"),
50+
) as connection:
51+
52+
with connection.cursor() as cursor:
53+
# Query 1: Tags for a critical ETL job
54+
cursor.execute(
55+
"SELECT 1",
56+
query_tags={"team": "data-eng", "application": "etl", "priority": "high"}
57+
)
58+
result = cursor.fetchone()
59+
print(f" ETL Query Result: {result[0]}")
60+
61+
# Query 2: Tags with None value (key-only tag)
62+
cursor.execute(
63+
"SELECT 2",
64+
query_tags={"team": "analytics", "experimental": None}
65+
)
66+
result = cursor.fetchone()
67+
print(f" Experimental Query Result: {result[0]}")
68+
69+
# Query 3: Tags with special characters (automatically escaped)
70+
cursor.execute(
71+
"SELECT 3",
72+
query_tags={"description": "test:with:colons,and,commas"}
73+
)
74+
result = cursor.fetchone()
75+
print(f" Special Chars Query Result: {result[0]}")
76+
77+
print()
78+
79+
# Example 3: Async execution with query tags
80+
print("Example 3: Async execution with query tags")
81+
with sql.connect(
82+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
83+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
84+
access_token=os.getenv("DATABRICKS_TOKEN"),
85+
) as connection:
86+
87+
with connection.cursor() as cursor:
88+
cursor.execute_async(
89+
"SELECT 4",
90+
query_tags={"team": "data-eng", "mode": "async"}
91+
)
92+
cursor.get_async_execution_result()
93+
result = cursor.fetchone()
94+
print(f" Async Query Result: {result[0]}")
95+
3096
print("\n=== Query Tags Example Complete ===")

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def execute_command(
8383
async_op: bool,
8484
enforce_embedded_schema_correctness: bool,
8585
row_limit: Optional[int] = None,
86+
query_tags: Optional[Dict[str, Optional[str]]] = None,
8687
) -> Union[ResultSet, None]:
8788
"""
8889
Executes a SQL command or query within the specified session.
@@ -102,6 +103,7 @@ def execute_command(
102103
async_op: Whether to execute the command asynchronously
103104
enforce_embedded_schema_correctness: Whether to enforce schema correctness
104105
row_limit: Maximum number of rows in the response.
106+
query_tags: Optional dictionary of query tags to apply for this query only.
105107
106108
Returns:
107109
If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/thrift_backend.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import time
77
import threading
8-
from typing import List, Optional, Union, Any, TYPE_CHECKING
8+
from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING
99
from uuid import UUID
1010

1111
from databricks.sql.common.unified_http_client import UnifiedHttpClient
@@ -53,6 +53,7 @@
5353
convert_arrow_based_set_to_arrow_table,
5454
convert_decimals_in_arrow_table,
5555
convert_column_based_set_to_arrow_table,
56+
serialize_query_tags,
5657
)
5758
from databricks.sql.types import SSLOptions
5859
from databricks.sql.backend.databricks_client import DatabricksClient
@@ -1003,6 +1004,7 @@ def execute_command(
10031004
async_op=False,
10041005
enforce_embedded_schema_correctness=False,
10051006
row_limit: Optional[int] = None,
1007+
query_tags: Optional[Dict[str, Optional[str]]] = None,
10061008
) -> Union["ResultSet", None]:
10071009
thrift_handle = session_id.to_thrift_handle()
10081010
if not thrift_handle:
@@ -1022,6 +1024,19 @@ def execute_command(
10221024
# DBR should be changed to use month_day_nano_interval
10231025
intervalTypesAsArrow=False,
10241026
)
1027+
1028+
# Build confOverlay with default configs and query_tags
1029+
merged_conf_overlay = {
1030+
# We want to receive proper Timestamp arrow types.
1031+
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
1032+
}
1033+
1034+
# Serialize and add query_tags to confOverlay if provided
1035+
if query_tags:
1036+
serialized_tags = serialize_query_tags(query_tags)
1037+
if serialized_tags:
1038+
merged_conf_overlay["query_tags"] = serialized_tags
1039+
10251040
req = ttypes.TExecuteStatementReq(
10261041
sessionHandle=thrift_handle,
10271042
statement=operation,
@@ -1036,10 +1051,7 @@ def execute_command(
10361051
canReadArrowResult=True if pyarrow else False,
10371052
canDecompressLZ4Result=lz4_compression,
10381053
canDownloadResult=use_cloud_fetch,
1039-
confOverlay={
1040-
# We want to receive proper Timestamp arrow types.
1041-
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
1042-
},
1054+
confOverlay=merged_conf_overlay,
10431055
useArrowNativeTypes=spark_arrow_types,
10441056
parameters=parameters,
10451057
enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,

src/databricks/sql/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ def execute(
12631263
parameters: Optional[TParameterCollection] = None,
12641264
enforce_embedded_schema_correctness=False,
12651265
input_stream: Optional[BinaryIO] = None,
1266+
query_tags: Optional[Dict[str, Optional[str]]] = None,
12661267
) -> "Cursor":
12671268
"""
12681269
Execute a query and wait for execution to complete.
@@ -1293,6 +1294,10 @@ def execute(
12931294
Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo'
12941295
being sent to the server
12951296
1297+
:param query_tags: Optional dictionary of query tags to apply for this query only.
1298+
Tags are key-value pairs that can be used to identify and categorize queries.
1299+
Example: {"team": "data-eng", "application": "etl"}
1300+
12961301
:returns self
12971302
"""
12981303

@@ -1333,6 +1338,7 @@ def execute(
13331338
async_op=False,
13341339
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
13351340
row_limit=self.row_limit,
1341+
query_tags=query_tags,
13361342
)
13371343

13381344
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -1349,13 +1355,17 @@ def execute_async(
13491355
operation: str,
13501356
parameters: Optional[TParameterCollection] = None,
13511357
enforce_embedded_schema_correctness=False,
1358+
query_tags: Optional[Dict[str, Optional[str]]] = None,
13521359
) -> "Cursor":
13531360
"""
13541361
13551362
Execute a query and do not wait for it to complete and just move ahead
13561363
13571364
:param operation:
13581365
:param parameters:
1366+
:param query_tags: Optional dictionary of query tags to apply for this query only.
1367+
Tags are key-value pairs that can be used to identify and categorize queries.
1368+
Example: {"team": "data-eng", "application": "etl"}
13591369
:return:
13601370
"""
13611371

@@ -1392,6 +1402,7 @@ def execute_async(
13921402
async_op=True,
13931403
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
13941404
row_limit=self.row_limit,
1405+
query_tags=query_tags,
13951406
)
13961407

13971408
return self

src/databricks/sql/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,46 @@ def concat_table_chunks(
898898
return pyarrow.concat_tables(table_chunks)
899899

900900

901+
def serialize_query_tags(query_tags: Optional[Dict[str, Optional[str]]]) -> Optional[str]:
902+
"""
903+
Serialize query_tags dictionary to a string format.
904+
905+
Format: "key1:value1,key2:value2"
906+
Special cases:
907+
- If value is None, omit the colon and value (e.g., "key1:value1,key2,key3:value3")
908+
- Escape special characters (:, ,, \\) in values with a leading backslash
909+
- Keys are not escaped (assumed to be controlled identifiers)
910+
911+
Args:
912+
query_tags: Dictionary of query tags where keys are strings and values are optional strings
913+
914+
Returns:
915+
Serialized string or None if query_tags is None or empty
916+
"""
917+
if not query_tags:
918+
return None
919+
920+
def escape_value(value: str) -> str:
921+
"""Escape special characters in tag values."""
922+
# Escape backslash first to avoid double-escaping
923+
value = value.replace("\\", "\\\\")
924+
# Escape colon and comma
925+
value = value.replace(":", "\\:")
926+
value = value.replace(",", "\\,")
927+
return value
928+
929+
serialized_parts = []
930+
for key, value in query_tags.items():
931+
if value is None:
932+
# No colon or value when value is None
933+
serialized_parts.append(key)
934+
else:
935+
escaped_value = escape_value(value)
936+
serialized_parts.append(f"{key}:{escaped_value}")
937+
938+
return ",".join(serialized_parts)
939+
940+
901941
def build_client_context(server_hostname: str, version: str, **kwargs):
902942
"""Build ClientContext for HTTP client configuration."""
903943
from databricks.sql.auth.common import ClientContext

tests/unit/test_util.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
convert_to_assigned_datatypes_in_column_table,
77
ColumnTable,
88
concat_table_chunks,
9+
serialize_query_tags,
910
)
1011

1112
try:
@@ -161,3 +162,65 @@ def test_concat_table_chunks__incorrect_column_names_error(self):
161162

162163
with pytest.raises(ValueError):
163164
concat_table_chunks([column_table1, column_table2])
165+
166+
def test_serialize_query_tags_basic(self):
167+
"""Test basic query tags serialization"""
168+
query_tags = {"team": "data-eng", "application": "etl"}
169+
result = serialize_query_tags(query_tags)
170+
assert result == "team:data-eng,application:etl"
171+
172+
def test_serialize_query_tags_with_none_value(self):
173+
"""Test query tags with None value (should omit colon and value)"""
174+
query_tags = {"key1": "value1", "key2": None, "key3": "value3"}
175+
result = serialize_query_tags(query_tags)
176+
assert result == "key1:value1,key2,key3:value3"
177+
178+
def test_serialize_query_tags_with_special_chars(self):
179+
"""Test query tags with special characters (colon, comma, backslash)"""
180+
query_tags = {
181+
"key1": "value:with:colons",
182+
"key2": "value,with,commas",
183+
"key3": "value\\with\\backslashes",
184+
}
185+
result = serialize_query_tags(query_tags)
186+
assert (
187+
result
188+
== "key1:value\\:with\\:colons,key2:value\\,with\\,commas,key3:value\\\\with\\\\backslashes"
189+
)
190+
191+
def test_serialize_query_tags_with_mixed_special_chars(self):
192+
"""Test query tags with mixed special characters"""
193+
query_tags = {"key1": "a:b,c\\d"}
194+
result = serialize_query_tags(query_tags)
195+
assert result == "key1:a\\:b\\,c\\\\d"
196+
197+
def test_serialize_query_tags_empty_dict(self):
198+
"""Test serialization with empty dictionary"""
199+
query_tags = {}
200+
result = serialize_query_tags(query_tags)
201+
assert result is None
202+
203+
def test_serialize_query_tags_none(self):
204+
"""Test serialization with None input"""
205+
result = serialize_query_tags(None)
206+
assert result is None
207+
208+
def test_serialize_query_tags_with_special_chars_in_key(self):
209+
"""Test query tags with special characters in keys (keys are not escaped)"""
210+
query_tags = {
211+
"key:with:colons": "value1",
212+
"key,with,commas": "value2",
213+
"key\\with\\backslashes": "value3",
214+
}
215+
result = serialize_query_tags(query_tags)
216+
# Keys are not escaped, only values are
217+
assert (
218+
result
219+
== "key:with:colons:value1,key,with,commas:value2,key\\with\\backslashes:value3"
220+
)
221+
222+
def test_serialize_query_tags_all_none_values(self):
223+
"""Test query tags where all values are None"""
224+
query_tags = {"key1": None, "key2": None, "key3": None}
225+
result = serialize_query_tags(query_tags)
226+
assert result == "key1,key2,key3"

0 commit comments

Comments
 (0)