Skip to content

Commit 4e223ff

Browse files
authored
Add deprecation warnings for Expr passed to confirmed literal-only function arguments (#1605)
* feat: add deprecation warning for Expr passed to literal-only args
  - Introduced shared `_warn_if_expr_for_literal_arg` in `functions/__init__.py`
  - Added `DeprecationWarning` for the following methods when `Expr` is passed as argument:
    - `encode(..., encoding=Expr)`
    - `decode(..., encoding=Expr)`
    - `digest(..., method=Expr)`
    - `arrow_cast(..., data_type=Expr)`
    - `arrow_try_cast(..., data_type=Expr)`
    - `arrow_metadata(..., key=Expr)`

  test: update tests to check for warnings
  - Implemented tests in `test_functions.py` to ensure:
    - Warning is raised for `Expr` form
    - No warning for native literal form
* fix(tests): resolve E501 line length issue in test_functions.py
* feat: consolidate warning helpers and update temporal function calls
  - Collapsed warning helpers into a single function `_warn_if_expr_for_literal_arg` in `python/datafusion/functions/__init__.py`.
  - Updated callers of temporal functions `_date_part` and `_date_trunc` in `python/datafusion/functions/__init__.py`.
  - Modified the digest behavior test in `python/tests/test_functions.py` to use native method strings.
  - Updated the encode/decode behavior test to use native "base64" in `python/tests/test_functions.py`.
* docs: update function examples in common-operations user guide
  - Replaced `string_literal("Float64")` with `"Float64"` in examples.
  - Replaced `str_lit("Int32")` with `"Int32"` in examples.
  - Removed unused `string_literal` and `str_lit` imports.
1 parent 5b7ad64 commit 4e223ff

3 files changed

Lines changed: 87 additions & 20 deletions

File tree

docs/source/user-guide/common-operations/functions.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ df = ctx.table("pokemon")
5252
DataFusion offers mathematical functions such as {py:func}`~datafusion.functions.pow` or {py:func}`~datafusion.functions.log`
5353

5454
```{code-cell} ipython3
55-
from datafusion import col, literal, string_literal, str_lit
55+
from datafusion import col, literal
5656
from datafusion import functions as f
5757
5858
df.select(
@@ -122,8 +122,8 @@ Casting expressions to different data types using {py:func}`~datafusion.function
122122

123123
```{code-cell} ipython3
124124
df.select(
125-
f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
126-
f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
125+
f.arrow_cast(col('"Total"'), "Float64").alias("total_as_float"),
126+
f.arrow_cast(col('"Total"'), "Int32").alias("total_as_int")
127127
)
128128
```
129129

python/datafusion/functions/__init__.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,16 @@
6464
from datafusion.functions import spark
6565

6666

67-
def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
68-
warnings.warn(
69-
f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; "
70-
"pass a Python literal instead.",
71-
DeprecationWarning,
72-
stacklevel=4,
73-
)
67+
def _warn_if_expr_for_literal_arg(
    value: Any, function_name: str, arg_name: str
) -> None:
    """Emit a ``DeprecationWarning`` when ``value`` is an ``Expr``.

    Does nothing for any other kind of value.

    Args:
        value: The argument as received by the public function.
        function_name: Name of the public function, used in the message.
        arg_name: Name of the literal-only argument, used in the message.
    """
    if not isinstance(value, Expr):
        return

    # Local import keeps this helper self-contained.
    import inspect

    # Attribute the warning to the first caller outside this module. A fixed
    # stacklevel cannot serve both direct callers (public function -> helper)
    # and indirect ones (public function -> private helper -> helper): it
    # would point at library code for one of them, and the default filters
    # hide DeprecationWarning raised from library modules.
    frame = inspect.currentframe()
    if frame is None:
        # No frame introspection available: assume a direct caller.
        stacklevel = 3
    else:
        stacklevel = 1
        while frame is not None and frame.f_globals.get("__name__") == __name__:
            frame = frame.f_back
            stacklevel += 1

    warnings.warn(
        f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; "
        "pass a Python literal instead.",
        DeprecationWarning,
        stacklevel=stacklevel,
    )
7477

7578

7679
__all__ = [
@@ -437,6 +440,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr:
437440
>>> result.collect_column("enc")[0].as_py()
438441
'aGVsbG8'
439442
"""
443+
_warn_if_expr_for_literal_arg(encoding, "encode", "encoding")
440444
encoding = coerce_to_expr(encoding)
441445
return Expr(f.encode(expr.expr, encoding.expr))
442446

@@ -452,6 +456,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr:
452456
>>> result.collect_column("dec")[0].as_py()
453457
b'hello'
454458
"""
459+
_warn_if_expr_for_literal_arg(encoding, "decode", "encoding")
455460
encoding = coerce_to_expr(encoding)
456461
return Expr(f.decode(expr.expr, encoding.expr))
457462

@@ -742,6 +747,7 @@ def digest(value: Expr, method: Expr | str) -> Expr:
742747
>>> len(result.collect_column("d")[0].as_py()) > 0
743748
True
744749
"""
750+
_warn_if_expr_for_literal_arg(method, "digest", "method")
745751
method = coerce_to_expr(method)
746752
return Expr(f.digest(value.expr, method.expr))
747753

@@ -2723,8 +2729,7 @@ def date_part(part: Expr | str, date: Expr) -> Expr:
27232729

27242730

27252731
def _date_part(part: Expr | str, date: Expr, function_name: str) -> Expr:
2726-
if isinstance(part, Expr):
2727-
_warn_expr_for_literal_arg(function_name, "part")
2732+
_warn_if_expr_for_literal_arg(part, function_name, "part")
27282733
part = coerce_to_expr(part)
27292734
return Expr(f.date_part(part.expr, date.expr))
27302735

@@ -2760,8 +2765,7 @@ def date_trunc(part: Expr | str, date: Expr) -> Expr:
27602765

27612766

27622767
def _date_trunc(part: Expr | str, date: Expr, function_name: str) -> Expr:
2763-
if isinstance(part, Expr):
2764-
_warn_expr_for_literal_arg(function_name, "part")
2768+
_warn_if_expr_for_literal_arg(part, function_name, "part")
27652769
part = coerce_to_expr(part)
27662770
return Expr(f.date_trunc(part.expr, date.expr))
27672771

@@ -3096,6 +3100,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
30963100
>>> result.collect_column("c")[0].as_py()
30973101
1.0
30983102
"""
3103+
_warn_if_expr_for_literal_arg(data_type, "arrow_cast", "data_type")
30993104
if isinstance(data_type, pa.DataType):
31003105
return expr.cast(data_type)
31013106
if isinstance(data_type, str):
@@ -3128,6 +3133,7 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
31283133
>>> result.collect_column("c")[0].as_py() is None
31293134
True
31303135
"""
3136+
_warn_if_expr_for_literal_arg(data_type, "arrow_try_cast", "data_type")
31313137
if isinstance(data_type, pa.DataType):
31323138
return expr.try_cast(data_type)
31333139
if isinstance(data_type, str):
@@ -3235,6 +3241,7 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
32353241
"""
32363242
if key is None:
32373243
return Expr(f.arrow_metadata(expr.expr))
3244+
_warn_if_expr_for_literal_arg(key, "arrow_metadata", "key")
32383245
if isinstance(key, str):
32393246
key = Expr.string_literal(key)
32403247
return Expr(f.arrow_metadata(expr.expr, key.expr))

python/tests/test_functions.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def test_string_functions(df, function, expected_result):
10141014

10151015
def test_hash_functions(df):
10161016
exprs = [
1017-
f.digest(column("a"), literal(m))
1017+
f.digest(column("a"), m)
10181018
for m in (
10191019
"md5",
10201020
"sha224",
@@ -1602,12 +1602,10 @@ def test_regr_funcs_df(func, expected):
16021602

16031603
def test_binary_string_functions(df):
16041604
df = df.select(
1605-
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
1605+
f.encode(column("a").cast(pa.string()), "base64"),
16061606
f.decode(
1607-
f.encode(
1608-
column("a").cast(pa.string()), literal("base64").cast(pa.string())
1609-
),
1610-
literal("base64").cast(pa.string()),
1607+
f.encode(column("a").cast(pa.string()), "base64"),
1608+
"base64",
16111609
),
16121610
)
16131611
result = df.collect()
@@ -2355,6 +2353,68 @@ def test_regexp_replace_native(self):
23552353
).collect()
23562354
assert result[0].column(0)[0].as_py() == "aX bX cX"
23572355

2356+
@pytest.mark.parametrize(
    ("func", "arg_name", "expr"),
    [
        pytest.param(
            f.encode,
            "encoding",
            lambda: f.encode(column("a"), literal("base64")),
            id="encode-encoding",
        ),
        pytest.param(
            f.decode,
            "encoding",
            lambda: f.decode(column("a"), literal("base64")),
            id="decode-encoding",
        ),
        pytest.param(
            f.digest,
            "method",
            lambda: f.digest(column("a"), literal("sha256")),
            id="digest-method",
        ),
        pytest.param(
            f.arrow_cast,
            "data_type",
            lambda: f.arrow_cast(column("a"), literal("Float64")),
            id="arrow-cast-data-type",
        ),
        pytest.param(
            f.arrow_try_cast,
            "data_type",
            lambda: f.arrow_try_cast(column("a"), literal("Float64")),
            id="arrow-try-cast-data-type",
        ),
        pytest.param(
            f.arrow_metadata,
            "key",
            lambda: f.arrow_metadata(column("a"), literal("k")),
            id="arrow-metadata-key",
        ),
    ],
)
def test_literal_only_expr_args_warn_deprecated(self, func, arg_name, expr):
    """Passing an ``Expr`` for a literal-only argument raises the warning.

    ``func`` supplies only the function name for the expected message;
    ``expr`` is a zero-argument callable that makes the deprecated call.
    """
    # Regex for the start of the deprecation message, with the function
    # name and argument name filled in for this case.
    expected_message = (
        rf"Passing Expr for {func.__name__}\(\) argument "
        rf"'{arg_name}' is deprecated"
    )
    with pytest.warns(DeprecationWarning, match=expected_message):
        built = expr()
    # The deprecated form must still return a result.
    assert built is not None
2407+
2408+
def test_literal_only_native_args_do_not_warn(self):
    """Native literal arguments must not trigger the deprecation warning.

    Both cast functions are checked with a string and a ``pa.DataType``,
    the two native literal forms they accept.
    """
    with warnings.catch_warnings():
        # Turn DeprecationWarning into an exception so any emitted warning
        # fails the test.
        warnings.simplefilter("error", DeprecationWarning)
        assert f.encode(column("a"), "base64") is not None
        assert f.decode(column("a"), "base64") is not None
        assert f.digest(column("a"), "sha256") is not None
        assert f.arrow_cast(column("a"), "Float64") is not None
        assert f.arrow_cast(column("a"), pa.float64()) is not None
        assert f.arrow_try_cast(column("a"), "Float64") is not None
        assert f.arrow_try_cast(column("a"), pa.float64()) is not None
        assert f.arrow_metadata(column("a"), "k") is not None
2417+
23582418
def test_backward_compat_with_lit(self):
23592419
"""Verify that existing code using lit() still works."""
23602420
ctx = SessionContext()

0 commit comments

Comments
 (0)