Skip to content

Commit 6124155

Browse files
committed
feat: add deprecation warning for Expr passed to literal-only args
- Introduced shared `_warn_if_expr_for_literal_arg` in `functions/__init__.py`
- Added `DeprecationWarning` for the following functions when an `Expr` is passed as the argument:
  - `encode(..., encoding=Expr)`
  - `decode(..., encoding=Expr)`
  - `digest(..., method=Expr)`
  - `arrow_cast(..., data_type=Expr)`
  - `arrow_try_cast(..., data_type=Expr)`
  - `arrow_metadata(..., key=Expr)`

test: update tests to check for warnings

- Implemented tests in `test_functions.py` to ensure:
  - Warning is raised for `Expr` form
  - No warning for native literal form
1 parent 5b7ad64 commit 6124155

2 files changed

Lines changed: 72 additions & 0 deletions

File tree

python/datafusion/functions/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
7373
)
7474

7575

76+
def _warn_if_expr_for_literal_arg(
77+
value: Any, function_name: str, arg_name: str
78+
) -> None:
79+
if isinstance(value, Expr):
80+
_warn_expr_for_literal_arg(function_name, arg_name)
81+
82+
7683
__all__ = [
7784
"abs",
7885
"acos",
@@ -437,6 +444,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr:
437444
>>> result.collect_column("enc")[0].as_py()
438445
'aGVsbG8'
439446
"""
447+
_warn_if_expr_for_literal_arg(encoding, "encode", "encoding")
440448
encoding = coerce_to_expr(encoding)
441449
return Expr(f.encode(expr.expr, encoding.expr))
442450

@@ -452,6 +460,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr:
452460
>>> result.collect_column("dec")[0].as_py()
453461
b'hello'
454462
"""
463+
_warn_if_expr_for_literal_arg(encoding, "decode", "encoding")
455464
encoding = coerce_to_expr(encoding)
456465
return Expr(f.decode(expr.expr, encoding.expr))
457466

@@ -742,6 +751,7 @@ def digest(value: Expr, method: Expr | str) -> Expr:
742751
>>> len(result.collect_column("d")[0].as_py()) > 0
743752
True
744753
"""
754+
_warn_if_expr_for_literal_arg(method, "digest", "method")
745755
method = coerce_to_expr(method)
746756
return Expr(f.digest(value.expr, method.expr))
747757

@@ -3096,6 +3106,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
30963106
>>> result.collect_column("c")[0].as_py()
30973107
1.0
30983108
"""
3109+
_warn_if_expr_for_literal_arg(data_type, "arrow_cast", "data_type")
30993110
if isinstance(data_type, pa.DataType):
31003111
return expr.cast(data_type)
31013112
if isinstance(data_type, str):
@@ -3128,6 +3139,7 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
31283139
>>> result.collect_column("c")[0].as_py() is None
31293140
True
31303141
"""
3142+
_warn_if_expr_for_literal_arg(data_type, "arrow_try_cast", "data_type")
31313143
if isinstance(data_type, pa.DataType):
31323144
return expr.try_cast(data_type)
31333145
if isinstance(data_type, str):
@@ -3235,6 +3247,7 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
32353247
"""
32363248
if key is None:
32373249
return Expr(f.arrow_metadata(expr.expr))
3250+
_warn_if_expr_for_literal_arg(key, "arrow_metadata", "key")
32383251
if isinstance(key, str):
32393252
key = Expr.string_literal(key)
32403253
return Expr(f.arrow_metadata(expr.expr, key.expr))

python/tests/test_functions.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,6 +2355,65 @@ def test_regexp_replace_native(self):
23552355
).collect()
23562356
assert result[0].column(0)[0].as_py() == "aX bX cX"
23572357

2358+
@pytest.mark.parametrize(
2359+
("func", "arg_name", "expr"),
2360+
[
2361+
pytest.param(
2362+
f.encode,
2363+
"encoding",
2364+
lambda: f.encode(column("a"), literal("base64")),
2365+
id="encode-encoding",
2366+
),
2367+
pytest.param(
2368+
f.decode,
2369+
"encoding",
2370+
lambda: f.decode(column("a"), literal("base64")),
2371+
id="decode-encoding",
2372+
),
2373+
pytest.param(
2374+
f.digest,
2375+
"method",
2376+
lambda: f.digest(column("a"), literal("sha256")),
2377+
id="digest-method",
2378+
),
2379+
pytest.param(
2380+
f.arrow_cast,
2381+
"data_type",
2382+
lambda: f.arrow_cast(column("a"), literal("Float64")),
2383+
id="arrow-cast-data-type",
2384+
),
2385+
pytest.param(
2386+
f.arrow_try_cast,
2387+
"data_type",
2388+
lambda: f.arrow_try_cast(column("a"), literal("Float64")),
2389+
id="arrow-try-cast-data-type",
2390+
),
2391+
pytest.param(
2392+
f.arrow_metadata,
2393+
"key",
2394+
lambda: f.arrow_metadata(column("a"), literal("k")),
2395+
id="arrow-metadata-key",
2396+
),
2397+
],
2398+
)
2399+
def test_literal_only_expr_args_warn_deprecated(self, func, arg_name, expr):
2400+
with pytest.warns(
2401+
DeprecationWarning,
2402+
match=rf"Passing Expr for {func.__name__}\(\) argument '{arg_name}' is deprecated",
2403+
):
2404+
result = expr()
2405+
assert result is not None
2406+
2407+
def test_literal_only_native_args_do_not_warn(self):
2408+
with warnings.catch_warnings():
2409+
warnings.simplefilter("error", DeprecationWarning)
2410+
assert f.encode(column("a"), "base64") is not None
2411+
assert f.decode(column("a"), "base64") is not None
2412+
assert f.digest(column("a"), "sha256") is not None
2413+
assert f.arrow_cast(column("a"), "Float64") is not None
2414+
assert f.arrow_try_cast(column("a"), pa.float64()) is not None
2415+
assert f.arrow_metadata(column("a"), "k") is not None
2416+
23582417
def test_backward_compat_with_lit(self):
23592418
"""Verify that existing code using lit() still works."""
23602419
ctx = SessionContext()

0 commit comments

Comments
 (0)