Skip to content

Commit f7fd189

Browse files
authored
feat: Update bigquery.ai.generate_table output_schema to allow Mapping type (#2463)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent ca9fb13 commit f7fd189

File tree

3 files changed

+58
-4
lines changed

3 files changed

+58
-4
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def generate_table(
606606
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
607607
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
608608
*,
609-
output_schema: str,
609+
output_schema: Union[str, Mapping[str, str]],
610610
temperature: Optional[float] = None,
611611
top_p: Optional[float] = None,
612612
max_output_tokens: Optional[int] = None,
@@ -642,8 +642,10 @@ def generate_table(
642642
treated as the 'prompt' column. If a DataFrame is provided, it
643643
must contain a 'prompt' column, or you must rename the column you
644644
wish to generate table to 'prompt'.
645-
output_schema (str):
646-
A string defining the output schema (e.g., "col1 STRING, col2 INT64").
645+
output_schema (str | Mapping[str, str]):
646+
A string defining the output schema (e.g., "col1 STRING, col2 INT64"),
647+
or a mapping value that specifies the schema of the output, in the form {field_name: data_type}.
648+
Supported data types include `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
647649
temperature (float, optional):
648650
A FLOAT64 value that is used for sampling promiscuity. The value
649651
must be in the range ``[0.0, 1.0]``.
@@ -666,8 +668,17 @@ def generate_table(
666668
model_name, session = bq_utils.get_model_name_and_session(model, data)
667669
table_sql = bq_utils.to_sql(data)
668670

671+
if isinstance(output_schema, Mapping):
672+
output_schema_str = ", ".join(
673+
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
674+
)
675+
# Validate user input
676+
output_schemas.parse_sql_fields(output_schema_str)
677+
else:
678+
output_schema_str = output_schema
679+
669680
struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
670-
"output_schema": output_schema
681+
"output_schema": output_schema_str
671682
}
672683
if temperature is not None:
673684
struct_fields_bq["temperature"] = temperature

tests/system/large/bigquery/test_ai.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,20 @@ def test_generate_table(text_model):
111111
assert "creator" in result.columns
112112
# The model may not always return the exact number of rows requested.
113113
assert len(result) > 0
114+
115+
116+
def test_generate_table_with_mapping_schema(text_model):
117+
df = bpd.DataFrame(
118+
{"prompt": ["Generate a table of 2 programming languages and their creators."]}
119+
)
120+
121+
result = ai.generate_table(
122+
text_model,
123+
df,
124+
output_schema={"language": "STRING", "creator": "STRING"},
125+
)
126+
127+
assert "language" in result.columns
128+
assert "creator" in result.columns
129+
# The model may not always return the exact number of rows requested.
130+
assert len(result) > 0

tests/unit/bigquery/test_ai.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,32 @@ def test_generate_table_with_options(mock_dataframe, mock_session):
269269
)
270270

271271

272+
def test_generate_table_with_mapping_schema(mock_dataframe, mock_session):
273+
model_name = "project.dataset.model"
274+
275+
bbq.ai.generate_table(
276+
model_name,
277+
mock_dataframe,
278+
output_schema={"col1": "STRING", "col2": "INT64"},
279+
)
280+
281+
mock_session.read_gbq_query.assert_called_once()
282+
query = mock_session.read_gbq_query.call_args[0][0]
283+
284+
# Normalize whitespace for comparison
285+
query = " ".join(query.split())
286+
287+
expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
288+
expected_part_2 = f"MODEL `{model_name}`,"
289+
expected_part_3 = "(SELECT * FROM my_table),"
290+
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"
291+
292+
assert expected_part_1 in query
293+
assert expected_part_2 in query
294+
assert expected_part_3 in query
295+
assert expected_part_4 in query
296+
297+
272298
@mock.patch("bigframes.pandas.read_pandas")
273299
def test_generate_text_with_pandas_dataframe(
274300
read_pandas_mock, mock_dataframe, mock_session

0 commit comments

Comments
 (0)