From 30361b7feec20a9ba1599351c88be7c36793c335 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 18 Feb 2026 22:38:57 +0000 Subject: [PATCH] fix: allow IsIn Ops with same dtypes regardless of nullability - Update Ibis isin_op_impl to compare types by name, allowing comparisons between columns and literals with different nullability. - Update SQLGlot IsInOp implementation to use dtypes.can_compare for more robust type compatibility checking. - Improve dtypes.can_compare to gracefully handle type coercion failures. - Migrate TPCH verification script to tests/system/large/test_tpch.py for better integration with the test suite. --- .../ibis_compiler/scalar_op_registry.py | 2 +- .../sqlglot/expressions/comparison_ops.py | 7 +- bigframes/dtypes.py | 7 +- scripts/tpch_result_verify.py | 128 ------------------ tests/system/large/test_tpch.py | 101 ++++++++++++++ 5 files changed, 108 insertions(+), 137 deletions(-) delete mode 100644 scripts/tpch_result_verify.py create mode 100644 tests/system/large/test_tpch.py diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 519b2c94426..9632e65e4d4 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -962,7 +962,7 @@ def isin_op_impl(x: ibis_types.Value, op: ops.IsInOp): # to actually cast it, as that could be lossy (eg float -> int) item_inferred_type = ibis_types.literal(item).type() if ( - x.type() == item_inferred_type + x.type().name == item_inferred_type.name or x.type().is_numeric() and item_inferred_type.is_numeric() ): diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 550a6c25be2..f767314be74 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -33,16 +33,11 @@ 
@register_unary_op(ops.IsInOp, pass_op=True) def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression: values = [] - is_numeric_expr = dtypes.is_numeric(expr.dtype, include_bool=False) for value in op.values: if _is_null(value): continue dtype = dtypes.bigframes_type(type(value)) - if ( - expr.dtype == dtype - or is_numeric_expr - and dtypes.is_numeric(dtype, include_bool=False) - ): + if dtypes.can_compare(expr.dtype, dtype): values.append(sge.convert(value)) if op.match_nulls: diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index 8caddcdb002..a2abe9b817a 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -370,8 +370,11 @@ def is_comparable(type_: ExpressionType) -> bool: def can_compare(type1: ExpressionType, type2: ExpressionType) -> bool: - coerced_type = coerce_to_common(type1, type2) - return is_comparable(coerced_type) + try: + coerced_type = coerce_to_common(type1, type2) + return is_comparable(coerced_type) + except TypeError: + return False def get_struct_fields(type_: ExpressionType) -> dict[str, Dtype]: diff --git a/scripts/tpch_result_verify.py b/scripts/tpch_result_verify.py deleted file mode 100644 index 0c932f6eac8..00000000000 --- a/scripts/tpch_result_verify.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os -import re - -from google.cloud import bigquery -import pandas as pd -from tqdm import tqdm - -import bigframes - -project_id = "bigframes-dev-perf" -dataset_id = "tpch_0001g" -dataset = { - "line_item_ds": f"bigframes-dev-perf.{dataset_id}.LINEITEM", - "region_ds": f"bigframes-dev-perf.{dataset_id}.REGION", - "nation_ds": f"bigframes-dev-perf.{dataset_id}.NATION", - "supplier_ds": f"bigframes-dev-perf.{dataset_id}.SUPPLIER", - "part_ds": f"bigframes-dev-perf.{dataset_id}.PART", - "part_supp_ds": f"bigframes-dev-perf.{dataset_id}.PARTSUPP", - "customer_ds": f"bigframes-dev-perf.{dataset_id}.CUSTOMER", - "orders_ds": f"bigframes-dev-perf.{dataset_id}.ORDERS", -} - - -def _execute_query(query): - client = bigquery.Client() - job_config = bigquery.QueryJobConfig(use_query_cache=False) - query_job = client.query(query, job_config=job_config) - query_job.result() - df = query_job.to_dataframe() - df.columns = df.columns.str.upper() - return df - - -def _initialize_session(ordered: bool): - context = bigframes.BigQueryOptions( - location="US", ordering_mode="strict" if ordered else "partial" - ) - session = bigframes.Session(context=context) - return session - - -def _verify_result(bigframes_query, sql_result): - exec_globals = {"_initialize_session": _initialize_session} - exec(bigframes_query, exec_globals) - bigframes_result = exec_globals.get("result") - if isinstance(bigframes_result, pd.DataFrame): - pd.testing.assert_frame_equal( - sql_result.reset_index(drop=True), - bigframes_result.reset_index(drop=True), - check_dtype=False, - ) - else: - assert sql_result.shape == (1, 1) - sql_scalar = sql_result.iloc[0, 0] - assert sql_scalar == bigframes_result - - -def verify(query_num=None): - range_iter = range(1, 23) if query_num is None else [query_num] - for i in tqdm(range_iter, desc="Processing queries"): - if query_num is not None and i != query_num: - continue - - # Execute SQL: - sql_file_path = 
f"third_party/bigframes_vendored/tpch/sql_queries/q{i}.sql" - with open(sql_file_path, "r") as f: - sql_query = f.read() - sql_query = sql_query.format(**dataset) - file_path = f"third_party/bigframes_vendored/tpch/queries/q{i}.py" - if os.path.exists(file_path): - with open(file_path, "r") as file: - file_content = file.read() - - file_content = re.sub( - r"next\((\w+)\.to_pandas_batches\((.*?)\)\)", - r"return \1.to_pandas()", - file_content, - ) - file_content = re.sub(r"_\s*=\s*(\w+)", r"return \1", file_content) - sql_result = _execute_query(sql_query) - - print(f"Checking {file_path} in ordered session") - bigframes_query = ( - file_content - + f"\nresult = q('{project_id}', '{dataset_id}', _initialize_session(ordered=True))" - ) - _verify_result(bigframes_query, sql_result) - - print(f"Checking {file_path} in unordered session") - bigframes_query = ( - file_content - + f"\nresult = q('{project_id}', '{dataset_id}', _initialize_session(ordered=False))" - ) - _verify_result(bigframes_query, sql_result) - - else: - raise FileNotFoundError(f"File {file_path} not found.") - - -if __name__ == "__main__": - """ - Runs verification of TPCH benchmark script outputs to ensure correctness for a specified query or all queries - with 1GB dataset. - - Example: - python scripts/tpch_result_verify.py -q 15 # Verifies TPCH query number 15 - python scripts/tpch_result_verify.py # Verifies all TPCH queries from 1 to 22 - """ - parser = argparse.ArgumentParser() - parser.add_argument("-q", "--query_number", type=int, default=None) - args = parser.parse_args() - - verify(args.query_number) diff --git a/tests/system/large/test_tpch.py b/tests/system/large/test_tpch.py new file mode 100644 index 00000000000..7cb243b0a39 --- /dev/null +++ b/tests/system/large/test_tpch.py @@ -0,0 +1,101 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from google.cloud import bigquery +import pandas as pd +import pytest + +TPCH_PATH = "third_party/bigframes_vendored/tpch" +PROJECT_ID = "bigframes-dev-perf" +DATASET_ID = "tpch_0001g" +DATASET = { + "line_item_ds": f"{PROJECT_ID}.{DATASET_ID}.LINEITEM", + "region_ds": f"{PROJECT_ID}.{DATASET_ID}.REGION", + "nation_ds": f"{PROJECT_ID}.{DATASET_ID}.NATION", + "supplier_ds": f"{PROJECT_ID}.{DATASET_ID}.SUPPLIER", + "part_ds": f"{PROJECT_ID}.{DATASET_ID}.PART", + "part_supp_ds": f"{PROJECT_ID}.{DATASET_ID}.PARTSUPP", + "customer_ds": f"{PROJECT_ID}.{DATASET_ID}.CUSTOMER", + "orders_ds": f"{PROJECT_ID}.{DATASET_ID}.ORDERS", +} + + +def _execute_sql_query(bigquery_client, sql_query): + sql_query = sql_query.format(**DATASET) + + job_config = bigquery.QueryJobConfig(use_query_cache=False) + query_job = bigquery_client.query(sql_query, job_config=job_config) + query_job.result() + df = query_job.to_dataframe() + df.columns = df.columns.str.upper() + return df + + +def _execute_bigframes_script(session, bigframes_script): + bigframes_script = re.sub( + r"next\((\w+)\.to_pandas_batches\((.*?)\)\)", + r"return \1.to_pandas()", + bigframes_script, + ) + bigframes_script = re.sub(r"_\s*=\s*(\w+)", r"return \1", bigframes_script) + + bigframes_script = ( + bigframes_script + + f"\nresult = q('{PROJECT_ID}', '{DATASET_ID}', _initialize_session)" + ) + exec_globals = {"_initialize_session": session} + exec(bigframes_script, exec_globals) + bigframes_result = exec_globals.get("result") + return bigframes_result + + +def 
_verify_result(bigframes_result, sql_result): + if isinstance(bigframes_result, pd.DataFrame): + pd.testing.assert_frame_equal( + sql_result.reset_index(drop=True), + bigframes_result.reset_index(drop=True), + check_dtype=False, + ) + else: + assert sql_result.shape == (1, 1) + sql_scalar = sql_result.iloc[0, 0] + assert sql_scalar == bigframes_result + + +@pytest.mark.parametrize("query_num", range(1, 23)) +@pytest.mark.parametrize("ordered", [True, False]) +def test_tpch_correctness(session, unordered_session, query_num, ordered): + """Runs verification of TPCH benchmark script outputs to ensure correctness.""" + # Execute SQL: + sql_file_path = f"{TPCH_PATH}/sql_queries/q{query_num}.sql" + assert os.path.exists(sql_file_path) + with open(sql_file_path, "r") as f: + sql_query = f.read() + + sql_result = _execute_sql_query(session.bqclient, sql_query) + + # Execute BigFrames: + file_path = f"{TPCH_PATH}/queries/q{query_num}.py" + assert os.path.exists(file_path) + with open(file_path, "r") as file: + bigframes_script = file.read() + + bigframes_result = _execute_bigframes_script( + session if ordered else unordered_session, bigframes_script + ) + + _verify_result(bigframes_result, sql_result)