Skip to content

Commit 29883b6

Browse files
authored
Merge branch 'main' into main_chelsealin_enablesqlglot2
2 parents 5a146b3 + 543ce52 commit 29883b6

File tree

58 files changed

+1233
-561
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1233
-561
lines changed

.librarian/state.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator@sha256:e7cc6823efb073a8a26e7cefdd869f12ec228abfbd2a44aa9a7eacc284023677
1+
image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator@sha256:1a2a85ab507aea26d787c06cc7979decb117164c81dd78a745982dfda80d4f68
22
libraries:
33
- id: bigframes
44
version: 2.35.0

bigframes/bigquery/_operations/ai.py

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,10 @@ def generate_text(
522522
model (bigframes.ml.base.BaseEstimator or str):
523523
The model to use for text generation.
524524
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
525-
The data to generate embeddings for. If a Series is provided, it is
526-
treated as the 'content' column. If a DataFrame is provided, it
527-
must contain a 'content' column, or you must rename the column you
528-
wish to embed to 'content'.
525+
The data to generate text for. If a Series is provided, it is
526+
treated as the 'prompt' column. If a DataFrame is provided, it
527+
must contain a 'prompt' column, or you must rename the column you
528+
wish to generate text from to 'prompt'.
529529
temperature (float, optional):
530530
A FLOAT64 value that is used for sampling randomness. The value
531531
must be in the range ``[0.0, 1.0]``. A lower temperature works well
@@ -601,6 +601,101 @@ def generate_text(
601601
return session.read_gbq_query(query)
602602

603603

604+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
605+
def generate_table(
606+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
607+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
608+
*,
609+
output_schema: str,
610+
temperature: Optional[float] = None,
611+
top_p: Optional[float] = None,
612+
max_output_tokens: Optional[int] = None,
613+
stop_sequences: Optional[List[str]] = None,
614+
request_type: Optional[str] = None,
615+
) -> dataframe.DataFrame:
616+
"""
617+
Generates a table using a BigQuery ML model.
618+
619+
See the `AI.GENERATE_TABLE function syntax
620+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-table>`_
621+
for additional reference.
622+
623+
**Examples:**
624+
625+
>>> import bigframes.pandas as bpd
626+
>>> import bigframes.bigquery as bbq
627+
>>> # The user is responsible for constructing a DataFrame that contains
628+
>>> # the necessary columns for the model's prompt. For example, a
629+
>>> # DataFrame with a 'prompt' column for text classification.
630+
>>> df = bpd.DataFrame({'prompt': ["some text to classify"]})
631+
>>> result = bbq.ai.generate_table(
632+
... "project.dataset.model_name",
633+
... data=df,
634+
... output_schema="category STRING"
635+
... ) # doctest: +SKIP
636+
637+
Args:
638+
model (bigframes.ml.base.BaseEstimator or str):
639+
The model to use for table generation.
640+
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
641+
The data to generate a table for. If a Series is provided, it is
642+
treated as the 'prompt' column. If a DataFrame is provided, it
643+
must contain a 'prompt' column, or you must rename the column you
644+
wish to generate a table from to 'prompt'.
645+
output_schema (str):
646+
A string defining the output schema (e.g., "col1 STRING, col2 INT64").
647+
temperature (float, optional):
648+
A FLOAT64 value that is used for sampling randomness. The value
649+
must be in the range ``[0.0, 1.0]``.
650+
top_p (float, optional):
651+
A FLOAT64 value that changes how the model selects tokens for
652+
output.
653+
max_output_tokens (int, optional):
654+
An INT64 value that sets the maximum number of tokens in the
655+
generated table.
656+
stop_sequences (List[str], optional):
657+
An ARRAY<STRING> value that contains the stop sequences for the model.
658+
request_type (str, optional):
659+
A STRING value that contains the request type for the model.
660+
661+
Returns:
662+
bigframes.pandas.DataFrame:
663+
The generated table.
664+
"""
665+
data = _to_dataframe(data, series_rename="prompt")
666+
model_name, session = bq_utils.get_model_name_and_session(model, data)
667+
table_sql = bq_utils.to_sql(data)
668+
669+
struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
670+
"output_schema": output_schema
671+
}
672+
if temperature is not None:
673+
struct_fields_bq["temperature"] = temperature
674+
if top_p is not None:
675+
struct_fields_bq["top_p"] = top_p
676+
if max_output_tokens is not None:
677+
struct_fields_bq["max_output_tokens"] = max_output_tokens
678+
if stop_sequences is not None:
679+
struct_fields_bq["stop_sequences"] = stop_sequences
680+
if request_type is not None:
681+
struct_fields_bq["request_type"] = request_type
682+
683+
struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
684+
query = f"""
685+
SELECT *
686+
FROM AI.GENERATE_TABLE(
687+
MODEL `{model_name}`,
688+
({table_sql}),
689+
{struct_sql}
690+
)
691+
"""
692+
693+
if session is None:
694+
return bpd.read_gbq_query(query)
695+
else:
696+
return session.read_gbq_query(query)
697+
698+
604699
@log_adapter.method_logger(custom_base_name="bigquery_ai")
605700
def if_(
606701
prompt: PROMPT_TYPE,

bigframes/bigquery/ai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
generate_double,
2525
generate_embedding,
2626
generate_int,
27+
generate_table,
2728
generate_text,
2829
if_,
2930
score,
@@ -37,6 +38,7 @@
3738
"generate_double",
3839
"generate_embedding",
3940
"generate_int",
41+
"generate_table",
4042
"generate_text",
4143
"if_",
4244
"score",

bigframes/core/array_value.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
import datetime
1818
import functools
1919
import typing
20-
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
20+
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
2121

22-
import google.cloud.bigquery
2322
import pandas
2423
import pyarrow as pa
2524

@@ -91,7 +90,7 @@ def from_range(cls, start, end, step):
9190
@classmethod
9291
def from_table(
9392
cls,
94-
table: google.cloud.bigquery.Table,
93+
table: Union[bq_data.BiglakeIcebergTable, bq_data.GbqNativeTable],
9594
session: Session,
9695
*,
9796
columns: Optional[Sequence[str]] = None,
@@ -103,8 +102,6 @@ def from_table(
103102
):
104103
if offsets_col and primary_key:
105104
raise ValueError("must set at most one of 'offsets', 'primary_key'")
106-
# define data source only for needed columns, this makes row-hashing cheaper
107-
table_def = bq_data.GbqTable.from_table(table, columns=columns or ())
108105

109106
# create ordering from info
110107
ordering = None
@@ -115,7 +112,9 @@ def from_table(
115112
[ids.ColumnId(key_part) for key_part in primary_key]
116113
)
117114

118-
bf_schema = schemata.ArraySchema.from_bq_table(table, columns=columns)
115+
bf_schema = schemata.ArraySchema.from_bq_schema(
116+
table.physical_schema, columns=columns
117+
)
119118
# Scan all columns by default, we define this list as it can be pruned while preserving source_def
120119
scan_list = nodes.ScanList(
121120
tuple(
@@ -124,7 +123,7 @@ def from_table(
124123
)
125124
)
126125
source_def = bq_data.BigqueryDataSource(
127-
table=table_def,
126+
table=table,
128127
schema=bf_schema,
129128
at_time=at_time,
130129
sql_predicate=predicate,

0 commit comments

Comments
 (0)