Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def generate_table(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
output_schema: str,
output_schema: Union[str, Mapping[str, str]],
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
Expand Down Expand Up @@ -642,8 +642,10 @@ def generate_table(
treated as the 'prompt' column. If a DataFrame is provided, it
must contain a 'prompt' column, or you must rename the column you
wish to generate table to 'prompt'.
output_schema (str):
A string defining the output schema (e.g., "col1 STRING, col2 INT64").
output_schema (str | Mapping[str, str]):
A string defining the output schema (e.g., "col1 STRING, col2 INT64"),
or a mapping value that specifies the schema of the output, in the form {field_name: data_type}.
Supported data types include `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
temperature (float, optional):
A FLOAT64 value that is used for sampling promiscuity. The value
must be in the range ``[0.0, 1.0]``.
Expand All @@ -666,8 +668,17 @@ def generate_table(
model_name, session = bq_utils.get_model_name_and_session(model, data)
table_sql = bq_utils.to_sql(data)

if isinstance(output_schema, Mapping):
output_schema_str = ", ".join(
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
)
# Validate user input
output_schemas.parse_sql_fields(output_schema_str)
else:
output_schema_str = output_schema

struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
"output_schema": output_schema
"output_schema": output_schema_str
}
if temperature is not None:
struct_fields_bq["temperature"] = temperature
Expand Down
17 changes: 17 additions & 0 deletions tests/system/large/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,20 @@ def test_generate_table(text_model):
assert "creator" in result.columns
# The model may not always return the exact number of rows requested.
assert len(result) > 0


def test_generate_table_with_mapping_schema(text_model):
df = bpd.DataFrame(
{"prompt": ["Generate a table of 2 programming languages and their creators."]}
)

result = ai.generate_table(
text_model,
df,
output_schema={"language": "STRING", "creator": "STRING"},
)

assert "language" in result.columns
assert "creator" in result.columns
# The model may not always return the exact number of rows requested.
assert len(result) > 0
26 changes: 26 additions & 0 deletions tests/unit/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ def test_generate_table_with_options(mock_dataframe, mock_session):
)


def test_generate_table_with_mapping_schema(mock_dataframe, mock_session):
model_name = "project.dataset.model"

bbq.ai.generate_table(
model_name,
mock_dataframe,
output_schema={"col1": "STRING", "col2": "INT64"},
)

mock_session.read_gbq_query.assert_called_once()
query = mock_session.read_gbq_query.call_args[0][0]

# Normalize whitespace for comparison
query = " ".join(query.split())

expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
expected_part_2 = f"MODEL `{model_name}`,"
expected_part_3 = "(SELECT * FROM my_table),"
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"

assert expected_part_1 in query
assert expected_part_2 in query
assert expected_part_3 in query
assert expected_part_4 in query


@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(
read_pandas_mock, mock_dataframe, mock_session
Expand Down
Loading