From 3cd4357cb9b92a22a317239471a05ecaa08e7d1f Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Sat, 21 Feb 2026 02:14:09 +0000 Subject: [PATCH] feat: Support pd.col expressions with .loc and getitem --- bigframes/core/array_value.py | 7 ++++++- bigframes/core/indexers.py | 16 +++++++++++++++- bigframes/dataframe.py | 7 ++++++- tests/unit/test_col.py | 18 ++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index ccec1f9b954..b20c6561ea9 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -204,7 +204,12 @@ def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue return self.filter(predicate) def filter(self, predicate: ex.Expression): - return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate)) + if predicate.is_scalar_expr: + return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate)) + else: + arr, filter_ids = self.compute_general_expression([predicate]) + arr = arr.filter_by_id(filter_ids[0]) + return arr.drop_columns(filter_ids) def order_by( self, by: Sequence[OrderingExpression], is_total_order: bool = False diff --git a/bigframes/core/indexers.py b/bigframes/core/indexers.py index c60e40880b7..987edf2339c 100644 --- a/bigframes/core/indexers.py +++ b/bigframes/core/indexers.py @@ -23,6 +23,7 @@ import pandas as pd import bigframes.core.blocks +import bigframes.core.col import bigframes.core.expression as ex import bigframes.core.guid as guid import bigframes.core.indexes as indexes @@ -36,7 +37,11 @@ if typing.TYPE_CHECKING: LocSingleKey = Union[ - bigframes.series.Series, indexes.Index, slice, bigframes.core.scalar.Scalar + bigframes.series.Series, + indexes.Index, + slice, + bigframes.core.scalar.Scalar, + bigframes.core.col.Expression, ] @@ -309,6 +314,15 @@ def _loc_getitem_series_or_dataframe( raise NotImplementedError( f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}" ) + if isinstance(key, bigframes.core.col.Expression): + label_to_col_ref = { + label: ex.deref(id) + for id, label in series_or_dataframe._block.col_id_to_label.items() + } + resolved_expr = key._value.bind_variables(label_to_col_ref) + result = series_or_dataframe.copy() + result._set_block(series_or_dataframe._block.filter(resolved_expr)) + return result if callable(key): raise NotImplementedError( f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}" diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 2a22fc4487d..2c734f2943e 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -623,13 +623,18 @@ def __getitem__( ): # No return type annotations (like pandas) as type cannot always be determined statically # NOTE: This implements the operations described in # https://pandas.pydata.org/docs/getting_started/intro_tutorials/03_subset_data.html + import bigframes.core.col + import bigframes.pandas - if isinstance(key, bigframes.series.Series): + if isinstance(key, bigframes.pandas.Series): return self._getitem_bool_series(key) if isinstance(key, slice): return self.iloc[key] + if isinstance(key, bigframes.core.col.Expression): + return self.loc[key] + # TODO(tswast): Fix this pylance warning: Class overlaps "Hashable" # unsafely and could produce a match at runtime if isinstance(key, blocks.Label): diff --git a/tests/unit/test_col.py b/tests/unit/test_col.py index e01c25ddd2c..9c9088e037c 100644 --- a/tests/unit/test_col.py +++ b/tests/unit/test_col.py @@ -158,3 +158,21 @@ def test_pd_col_binary_bool_operators(scalars_dfs, op): pd_result = scalars_pandas_df.assign(**pd_kwargs) assert_frame_equal(bf_result, pd_result) + + +def test_loc_with_pd_col(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.loc[bpd.col("float64_col") > 4].to_pandas() + pd_result = scalars_pandas_df.loc[pd.col("float64_col") > 4] # type: ignore + + assert_frame_equal(bf_result, pd_result) + + +def test_getitem_with_pd_col(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df[bpd.col("float64_col") > 4].to_pandas() + pd_result = scalars_pandas_df[pd.col("float64_col") > 4] # type: ignore + + assert_frame_equal(bf_result, pd_result)