diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 5fe9f306d5..477ca91366 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -606,7 +606,7 @@ def generate_table( model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], *, - output_schema: str, + output_schema: Union[str, Mapping[str, str]], temperature: Optional[float] = None, top_p: Optional[float] = None, max_output_tokens: Optional[int] = None, @@ -642,8 +642,10 @@ def generate_table( treated as the 'prompt' column. If a DataFrame is provided, it must contain a 'prompt' column, or you must rename the column you wish to generate table to 'prompt'. - output_schema (str): - A string defining the output schema (e.g., "col1 STRING, col2 INT64"). + output_schema (str | Mapping[str, str]): + A string defining the output schema (e.g., "col1 STRING, col2 INT64"), + or a mapping value that specifies the schema of the output, in the form {field_name: data_type}. + Supported data types include `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`. temperature (float, optional): A FLOAT64 value that is used for sampling promiscuity. The value must be in the range ``[0.0, 1.0]``. @@ -666,8 +668,17 @@ def generate_table( model_name, session = bq_utils.get_model_name_and_session(model, data) table_sql = bq_utils.to_sql(data) + if isinstance(output_schema, Mapping): + output_schema_str = ", ".join( + [f"{name} {sql_type}" for name, sql_type in output_schema.items()] + ) + # Validate user input + output_schemas.parse_sql_fields(output_schema_str) + else: + output_schema_str = output_schema + struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = { - "output_schema": output_schema + "output_schema": output_schema_str } if temperature is not None: struct_fields_bq["temperature"] = temperature diff --git a/tests/system/large/bigquery/test_ai.py b/tests/system/large/bigquery/test_ai.py index 86cf4d7f00..668581c627 100644 --- a/tests/system/large/bigquery/test_ai.py +++ b/tests/system/large/bigquery/test_ai.py @@ -111,3 +111,20 @@ def test_generate_table(text_model): assert "creator" in result.columns # The model may not always return the exact number of rows requested. assert len(result) > 0 + + +def test_generate_table_with_mapping_schema(text_model): + df = bpd.DataFrame( + {"prompt": ["Generate a table of 2 programming languages and their creators."]} + ) + + result = ai.generate_table( + text_model, + df, + output_schema={"language": "STRING", "creator": "STRING"}, + ) + + assert "language" in result.columns + assert "creator" in result.columns + # The model may not always return the exact number of rows requested. + assert len(result) > 0 diff --git a/tests/unit/bigquery/test_ai.py b/tests/unit/bigquery/test_ai.py index 796e86f924..c73e63b9db 100644 --- a/tests/unit/bigquery/test_ai.py +++ b/tests/unit/bigquery/test_ai.py @@ -269,6 +269,32 @@ def test_generate_table_with_options(mock_dataframe, mock_session): ) +def test_generate_table_with_mapping_schema(mock_dataframe, mock_session): + model_name = "project.dataset.model" + + bbq.ai.generate_table( + model_name, + mock_dataframe, + output_schema={"col1": "STRING", "col2": "INT64"}, + ) + + mock_session.read_gbq_query.assert_called_once() + query = mock_session.read_gbq_query.call_args[0][0] + + # Normalize whitespace for comparison + query = " ".join(query.split()) + + expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE(" + expected_part_2 = f"MODEL `{model_name}`," + expected_part_3 = "(SELECT * FROM my_table)," + expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)" + + assert expected_part_1 in query + assert expected_part_2 in query + assert expected_part_3 in query + assert expected_part_4 in query + + @mock.patch("bigframes.pandas.read_pandas") def test_generate_text_with_pandas_dataframe( read_pandas_mock, mock_dataframe, mock_session