diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py
index 492858fab6..02344620d3 100644
--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py
+++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py
@@ -32,21 +32,90 @@ from genkit.core.registry import ActionKind
 from genkit.plugins.google_genai.models.embedder import (
     Embedder,
-    GeminiEmbeddingModels,
-    VertexEmbeddingModels,
     default_embedder_info,
 )
 from genkit.plugins.google_genai.models.gemini import (
     SUPPORTED_MODELS,
     GeminiConfigSchema,
     GeminiModel,
+    get_model_config_schema,
     google_model_info,
 )
 from genkit.plugins.google_genai.models.imagen import (
     SUPPORTED_MODELS as IMAGE_SUPPORTED_MODELS,
+    ImagenConfigSchema,
     ImagenModel,
     vertexai_image_model_info,
 )
+from genkit.plugins.google_genai.models.veo import (
+    VeoConfigSchema,
+    VeoModel,
+    veo_model_info,
+)
+
+
+class GenaiModels:
+    """Container for model names discovered from the API."""
+
+    gemini: list[str]
+    imagen: list[str]
+    embedders: list[str]
+    veo: list[str]
+
+    def __init__(self) -> None:
+        """Initialize empty model-name lists."""
+        self.gemini = []
+        self.imagen = []
+        self.embedders = []
+        self.veo = []
+
+
+def _list_genai_models(client: genai.Client, is_vertex: bool) -> GenaiModels:
+    """List supported models and embedders from the Google GenAI SDK.
+
+    Mirrors the logic of the Go plugin's listGenaiModels.
+    """
+    models = GenaiModels()
+
+    for m in client.models.list():
+        name = m.name
+        if not name:
+            continue
+
+        # Strip the provider prefix from the model name.
+        if is_vertex:
+            if name.startswith('publishers/google/models/'):
+                name = name[25:]
+        elif name.startswith('models/'):
+            name = name[7:]
+
+        description = (m.description or '').lower()
+        if 'deprecated' in description:
+            continue
+
+        if not m.supported_actions:
+            continue
+
+        # Embedders
+        if 'embedContent' in m.supported_actions:
+            models.embedders.append(name)
+
+        # Imagen (mostly Vertex AI)
+        if 'predict' in m.supported_actions and 'imagen' in name.lower():
+            models.imagen.append(name)
+
+        # Veo
+        if 'generateVideos' in m.supported_actions or 'veo' in name.lower():
+            models.veo.append(name)
+
+        # Gemini / Gemma
+        if 'generateContent' in m.supported_actions:
+            lower_name = name.lower()
+            if 'gemini' in lower_name or 'gemma' in lower_name:
+                models.gemini.append(name)
+
+    return models
+
 
 GOOGLEAI_PLUGIN_NAME = 'googleai'
 VERTEXAI_PLUGIN_NAME = 'vertexai'
@@ -136,44 +205,40 @@ async def init(self) -> list[Action]:
         Returns:
             List of Action objects for known/supported models.
         """
-        return [
-            *self._list_known_models(),
-            *self._list_known_embedders(),
-        ]
+        genai_models = _list_genai_models(self._client, is_vertex=False)
+
+        actions = []
+        # Gemini Models
+        for name in genai_models.gemini:
+            actions.append(self._resolve_model(googleai_name(name)))
+
+        # Embedders
+        for name in genai_models.embedders:
+            actions.append(self._resolve_embedder(googleai_name(name)))
+
+        return actions
 
     def _list_known_models(self) -> list[Action]:
         """List known models as Action objects.
 
-        Returns:
-            List of Action objects for known Gemini models.
+        Deprecated: init() is now the source of truth for model discovery.
+        This helper is kept only for backward compatibility and delegates
+        to the same dynamic listing logic used by init().
""" - known_model_names = [ - 'gemini-3-flash-preview', - 'gemini-3-pro-preview', - 'gemini-2.5-pro', - 'gemini-2.5-flash', - 'gemini-2.5-flash-lite', - 'gemini-2.0-flash', - 'gemini-2.0-flash-lite', - ] + # Re-use init logic synchronously? init is async. + # Let's implementation just mimic init logic but sync call to client.models.list is fine (it is iterator) + genai_models = _list_genai_models(self._client, is_vertex=False) actions = [] - for model_name in known_model_names: - actions.append(self._resolve_model(googleai_name(model_name))) + for name in genai_models.gemini: + actions.append(self._resolve_model(googleai_name(name))) return actions def _list_known_embedders(self) -> list[Action]: - """List known embedders as Action objects. - - Returns: - List of Action objects for known embedders. - """ - known_embedders = [ - GeminiEmbeddingModels.TEXT_EMBEDDING_004, - GeminiEmbeddingModels.GEMINI_EMBEDDING_001, - ] + """List known embedders as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=False) actions = [] - for embedder_name in known_embedders: - actions.append(self._resolve_embedder(googleai_name(embedder_name.value))) + for name in genai_models.embedders: + actions.append(self._resolve_embedder(googleai_name(name))) return actions async def resolve(self, action_type: ActionKind, name: str) -> Action | None: @@ -259,32 +324,30 @@ async def list_actions(self) -> list[ActionMetadata]: - info (dict): The metadata dictionary describing the model configuration and properties. - config_schema (type): The schema class used for validating the model's configuration. """ - actions_list = list() - for m in self._client.models.list(): - model_name = m.name - if not model_name: - continue - name = model_name.replace('models/', '') - if m.supported_actions and 'generateContent' in m.supported_actions: - actions_list.append( - model_action_metadata( - name=googleai_name(name), - info=google_model_info(name).model_dump(), - ), + genai_models = _list_genai_models(self._client, is_vertex=False) + actions_list = [] + + for name in genai_models.gemini: + actions_list.append( + model_action_metadata( + name=googleai_name(name), + info=google_model_info(name).model_dump(by_alias=True), + config_schema=get_model_config_schema(name), ) + ) - if m.supported_actions and 'embedContent' in m.supported_actions: - embed_info = default_embedder_info(name) - actions_list.append( - embedder_action_metadata( - name=googleai_name(name), - options=EmbedderOptions( - label=embed_info.get('label'), - supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), - dimensions=embed_info.get('dimensions'), - ), - ) + for name in genai_models.embedders: + embed_info = default_embedder_info(name) + actions_list.append( + embedder_action_metadata( + name=googleai_name(name), + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), ) + ) return actions_list @@ -349,47 +412,41 @@ async def init(self) -> list[Action]: Returns: List of Action objects for known/supported models. """ - return [ - *self._list_known_models(), - *self._list_known_embedders(), - ] + genai_models = _list_genai_models(self._client, is_vertex=True) + actions = [] - def _list_known_models(self) -> list[Action]: - """List known models as Action objects. 
+ for name in genai_models.gemini: + actions.append(self._resolve_model(vertexai_name(name))) - Returns: - List of Action objects for known Gemini and Imagen models. - """ - known_model_names = [ - 'gemini-2.5-flash-lite', - 'gemini-2.5-pro', - 'gemini-2.5-flash', - 'gemini-2.0-flash-001', - 'gemini-2.0-flash', - 'gemini-2.0-flash-lite', - 'gemini-2.0-flash-lite-001', - 'imagen-4.0-generate-001', - ] + for name in genai_models.imagen: + actions.append(self._resolve_model(vertexai_name(name))) + + for name in genai_models.veo: + actions.append(self._resolve_model(vertexai_name(name))) + + for name in genai_models.embedders: + actions.append(self._resolve_embedder(vertexai_name(name))) + + return actions + + def _list_known_models(self) -> list[Action]: + """List known models as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=True) actions = [] - for model_name in known_model_names: - actions.append(self._resolve_model(vertexai_name(model_name))) + for name in genai_models.gemini: + actions.append(self._resolve_model(vertexai_name(name))) + for name in genai_models.imagen: + actions.append(self._resolve_model(vertexai_name(name))) + for name in genai_models.veo: + actions.append(self._resolve_model(vertexai_name(name))) return actions def _list_known_embedders(self) -> list[Action]: - """List known embedders as Action objects. - - Returns: - List of Action objects for known embedders. - """ - known_embedders = [ - VertexEmbeddingModels.TEXT_EMBEDDING_005_ENG, - VertexEmbeddingModels.TEXT_EMBEDDING_002_MULTILINGUAL, - # Note: multimodalembedding@001 requires different API structure (not yet implemented) - VertexEmbeddingModels.GEMINI_EMBEDDING_001, - ] + """List known embedders as Action objects.""" + genai_models = _list_genai_models(self._client, is_vertex=True) actions = [] - for embedder_name in known_embedders: - actions.append(self._resolve_embedder(vertexai_name(embedder_name.value))) + for name in genai_models.embedders: + actions.append(self._resolve_embedder(vertexai_name(name))) return actions async def resolve(self, action_type: ActionKind, name: str) -> Action | None: @@ -424,6 +481,9 @@ def _resolve_model(self, name: str) -> Action: model_ref = vertexai_image_model_info(_clean_name) model = ImagenModel(_clean_name, self._client) IMAGE_SUPPORTED_MODELS[_clean_name] = model_ref + elif _clean_name.lower().startswith('veo'): + model_ref = veo_model_info(_clean_name) + model = VeoModel(_clean_name, self._client) else: model_ref = google_model_info(_clean_name) model = GeminiModel(_clean_name, self._client) @@ -481,31 +541,47 @@ async def list_actions(self) -> list[ActionMetadata]: - info (dict): The metadata dictionary describing the model configuration and properties. - config_schema (type): The schema class used for validating the model's configuration. 
""" - actions_list = list() - for m in self._client.models.list(): - model_name = m.name - if not model_name: - continue - name = model_name.replace('publishers/google/models/', '') - if 'embed' in name.lower(): - embed_info = default_embedder_info(name) - actions_list.append( - embedder_action_metadata( - name=vertexai_name(name), - options=EmbedderOptions( - label=embed_info.get('label'), - supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), - dimensions=embed_info.get('dimensions'), - ), - ) + genai_models = _list_genai_models(self._client, is_vertex=True) + actions_list = [] + + for name in genai_models.gemini: + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=google_model_info(name).model_dump(by_alias=True), + config_schema=get_model_config_schema(name), ) - # List all the vertexai models for generate actions + ) + + for name in genai_models.imagen: actions_list.append( model_action_metadata( name=vertexai_name(name), - info=google_model_info(name).model_dump(), - config_schema=GeminiConfigSchema, - ), + info=vertexai_image_model_info(name).model_dump(by_alias=True), + config_schema=ImagenConfigSchema, + ) + ) + + for name in genai_models.veo: + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=veo_model_info(name).model_dump(), + config_schema=VeoConfigSchema, + ) + ) + + for name in genai_models.embedders: + embed_info = default_embedder_info(name) + actions_list.append( + embedder_action_metadata( + name=vertexai_name(name), + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + ), + ) ) return actions_list @@ -517,19 +593,15 @@ def _inject_attribution_headers( api_version: str | None = None, ) -> HttpOptions: """Adds genkit client info to the appropriate http headers.""" - # Normalize to HttpOptions instance - opts: HttpOptions - if http_options is None: + if not http_options: opts = HttpOptions() elif isinstance(http_options, HttpOptions): opts = http_options else: - # HttpOptionsDict or other dict-like - use model_validate for proper type conversion opts = HttpOptions.model_validate(http_options) if base_url: opts.base_url = base_url - if api_version: opts.api_version = api_version diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 16f1a2f0db..c951c589a8 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -143,11 +143,11 @@ from enum import StrEnum from functools import cached_property -from typing import Any, cast +from typing import Annotated, Any, cast from google import genai from google.genai import types as genai_types -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema from genkit.ai import ( ActionRunContext, @@ -156,8 +156,6 @@ from genkit.codec import dump_dict, dump_json from genkit.core.tracing import tracer from genkit.lang.deprecations import ( - DeprecationInfo, - DeprecationStatus, deprecated_enum_metafactory, ) from genkit.plugins.google_genai.models.utils import PartConverter @@ -177,18 +175,308 @@ ) -class GeminiConfigSchema(genai_types.GenerateContentConfig): +class HarmCategory(StrEnum): + """Harm categories.""" + + HARM_CATEGORY_UNSPECIFIED = 
'HARM_CATEGORY_UNSPECIFIED' + HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH' + HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT' + HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT' + HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT' + + +class HarmBlockThreshold(StrEnum): + """Harm block thresholds.""" + + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE' + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE' + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH' + BLOCK_NONE = 'BLOCK_NONE' + + +class SafetySettingsSchema(BaseModel): + """Safety settings schema.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + category: HarmCategory + threshold: HarmBlockThreshold + + +class PrebuiltVoiceConfig(BaseModel): + """Prebuilt voice config.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + voice_name: str | None = Field(None, alias='voiceName') + + +class FunctionCallingMode(StrEnum): + """Function calling mode.""" + + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED' + AUTO = 'AUTO' + ANY = 'ANY' + NONE = 'NONE' + + +class FunctionCallingConfig(BaseModel): + """Function calling config.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + mode: FunctionCallingMode | None = None + allowed_function_names: list[str] | None = Field(None, alias='allowedFunctionNames') + + +class ThinkingLevel(StrEnum): + """Thinking level.""" + + MINIMAL = 'MINIMAL' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + + +class ThinkingConfigSchema(BaseModel): + """Thinking config schema.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + include_thoughts: bool | None = Field(None, alias='includeThoughts') + thinking_budget: int | None = Field(None, alias='thinkingBudget') + thinking_level: ThinkingLevel | None = Field(None, alias='thinkingLevel') + + +class FileSearchConfigSchema(BaseModel): + """File search config schema.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + file_search_store_names: list[str] | None = Field(None, alias='fileSearchStoreNames') + metadata_filter: str | None = Field(None, alias='metadataFilter') + top_k: int | None = Field(None, alias='topK') + + +class ImageAspectRatio(StrEnum): + """Image aspect ratio.""" + + RATIO_1_1 = '1:1' + RATIO_2_3 = '2:3' + RATIO_3_2 = '3:2' + RATIO_3_4 = '3:4' + RATIO_4_3 = '4:3' + RATIO_4_5 = '4:5' + RATIO_5_4 = '5:4' + RATIO_9_16 = '9:16' + RATIO_16_9 = '16:9' + RATIO_21_9 = '21:9' + + +class ImageSize(StrEnum): + """Image size.""" + + SIZE_1K = '1K' + SIZE_2K = '2K' + SIZE_4K = '4K' + + +class ImageConfigSchema(BaseModel): + """Image config schema.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + aspect_ratio: ImageAspectRatio | None = Field(None, alias='aspectRatio') + image_size: ImageSize | None = Field(None, alias='imageSize') + + +class VoiceConfigSchema(BaseModel): + """Voice config schema.""" + + model_config = ConfigDict(extra='allow', populate_by_name=True) + prebuilt_voice_config: PrebuiltVoiceConfig | None = Field(None, alias='prebuiltVoiceConfig') + + +class GeminiConfigSchema(GenerationCommonConfig): """Gemini Config Schema.""" + model_config = ConfigDict(extra='allow', populate_by_name=True) + + safety_settings: Annotated[ + list[SafetySettingsSchema] | None, + WithJsonSchema({ + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'category': {'type': 'string', 'enum': [e.value for e in HarmCategory]}, + 'threshold': {'type': 'string', 'enum': [e.value for e in HarmBlockThreshold]}, + 
}, + 'required': ['category', 'threshold'], + 'additionalProperties': True, + }, + 'description': ( + 'Adjust how likely you are to see responses that could be harmful. ' + 'Content is blocked based on the probability that it is harmful.' + ), + }), + ] = Field( + None, + alias='safetySettings', + ) + # Gemini specific model_config = ConfigDict(extra='allow') - code_execution: bool | None = None - response_modalities: list[str] | None = None - # pyrefly: ignore[bad-override] - intentionally widen type to accept dict before conversion - thinking_config: dict[str, object] | None = None - file_search: dict[str, object] | None = None - url_context: dict[str, object] | None = None - api_version: str | None = None + # inherited from GenerationCommonConfig: + # version, temperature, max_output_tokens, top_k, top_p, stop_sequences + + temperature: float | None = Field( + default=None, + description='Controls the randomness of the output. Values can range over [0.0, 2.0].', + ) + + top_p: float | None = Field( + default=None, + alias='topP', + description=( + 'The maximum cumulative probability of tokens to consider when sampling. Values can range over [0.0, 1.0].' + ), + ) + top_k: int | None = Field( # pyrefly: ignore[bad-override] + default=None, + alias='topK', + description=('The maximum number of tokens to consider when sampling. Values can range over [1, 40].'), + ) + candidate_count: int | None = Field( + default=None, description='Number of generated responses to return.', alias='candidateCount' + ) + max_output_tokens: int | None = Field( # pyrefly: ignore[bad-override] + default=None, alias='maxOutputTokens', description='Maximum number of tokens to generate.' + ) + stop_sequences: list[str] | None = Field(default=None, alias='stopSequences', description='Stop sequences.') + presence_penalty: float | None = Field(default=None, description='Presence penalty.', alias='presencePenalty') + frequency_penalty: float | None = Field(default=None, description='Frequency penalty.', alias='frequencyPenalty') + response_mime_type: str | None = Field(default=None, description='Response MIME type.', alias='responseMimeType') + response_schema: dict[str, Any] | None = Field(default=None, description='Response schema.', alias='responseSchema') + + code_execution: bool | dict[str, Any] | None = Field( + None, description='Enables the model to generate and run code.', alias='codeExecution' + ) + response_modalities: list[str] | None = Field( + None, + description=( + "The modalities to be used in response. Only supported for 'gemini-2.0-flash-exp' model at present." + ), + alias='responseModalities', + ) + + thinking_config: Annotated[ + ThinkingConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'includeThoughts': { + 'type': 'boolean', + 'description': ( + 'Indicates whether to include thoughts in the response. If true, thoughts are returned only if ' + 'the model supports thought and thoughts are available.' + ), + }, + 'thinkingBudget': { + 'type': 'integer', + 'description': ( + 'For Gemini 2.5 - Indicates the thinking budget in tokens. 0 is DISABLED. -1 is AUTOMATIC. ' + 'The default values and allowed ranges are model dependent. The thinking budget parameter ' + 'gives the model guidance on the number of thinking tokens it can use when generating a ' + 'response. A greater number of tokens is typically associated with more detailed thinking, ' + 'which is needed for solving more complex tasks.' 
+ ), + }, + 'thinkingLevel': { + 'type': 'string', + 'enum': [e.value for e in ThinkingLevel], + 'description': ( + 'For Gemini 3.0 - Indicates the thinking level. A higher level is associated with more ' + 'detailed thinking, which is needed for solving more complex tasks.' + ), + }, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='thinkingConfig') + + file_search: Annotated[ + FileSearchConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'fileSearchStoreNames': { + 'type': 'array', + 'items': {'type': 'string'}, + 'description': ( + 'The names of the fileSearchStores to retrieve from. ' + 'Example: fileSearchStores/my-file-search-store-123' + ), + }, + 'metadataFilter': { + 'type': 'string', + 'description': 'Metadata filter to apply to the semantic retrieval documents and chunks.', + }, + 'topK': { + 'type': 'integer', + 'description': 'The number of semantic retrieval chunks to retrieve.', + }, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='fileSearch') + + url_context: bool | dict[str, Any] | None = Field( + None, description='Return grounding metadata from links included in the query', alias='urlContext' + ) + google_search_retrieval: bool | dict[str, Any] | None = Field( + None, + description='Retrieve public web data for grounding, powered by Google Search.', + alias='googleSearchRetrieval', + ) + function_calling_config: Annotated[ + FunctionCallingConfig | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'mode': {'type': 'string', 'enum': [e.value for e in FunctionCallingMode]}, + 'allowedFunctionNames': {'type': 'array', 'items': {'type': 'string'}}, + }, + 'description': ( + 'Controls how the model uses the provided tools (function declarations). With AUTO (Default) ' + 'mode, the model decides whether to generate a natural language response or suggest a function ' + 'call based on the prompt and context. With ANY, the model is constrained to always predict a ' + 'function call and guarantee function schema adherence. With NONE, the model is prohibited ' + 'from making function calls.' + ), + 'additionalProperties': True, + }), + ] = Field( + None, + alias='functionCallingConfig', + ) + + api_version: str | None = Field( + None, description='Overrides the plugin-configured or default apiVersion, if specified.', alias='apiVersion' + ) + base_url: str | None = Field( + None, description='Overrides the plugin-configured or default baseUrl, if specified.', alias='baseUrl' + ) + api_key: str | None = Field( + None, description='Overrides the plugin-configured API key, if specified.', alias='apiKey', exclude=True + ) + context_cache: bool | None = Field( + None, + description=( + 'Context caching allows you to save and reuse precomputed input tokens that you wish to use repeatedly.' 
+ ), + alias='contextCache', + ) + + +class SpeechConfigSchema(BaseModel): + """Speech config schema.""" + + voice_config: VoiceConfigSchema | None = Field(None, alias='voiceConfig') http_options: Any | None = Field(None, exclude=True) tools: Any | None = Field(None, exclude=True) @@ -200,20 +488,29 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): class GeminiTtsConfigSchema(GeminiConfigSchema): """Gemini TTS Config Schema.""" - # pyrefly: ignore[bad-override] - intentionally widen type to accept dict before conversion - speech_config: dict[str, object] | None = None + speech_config: SpeechConfigSchema | None = Field(None, alias='speechConfig') class GeminiImageConfigSchema(GeminiConfigSchema): """Gemini Image Config Schema.""" - # pyrefly: ignore[bad-override] - intentionally widen type to accept dict before conversion - image_config: dict[str, object] | None = None + image_config: Annotated[ + ImageConfigSchema | None, + WithJsonSchema({ + 'type': 'object', + 'properties': { + 'aspectRatio': {'type': 'string', 'enum': [e.value for e in ImageAspectRatio]}, + 'imageSize': {'type': 'string', 'enum': [e.value for e in ImageSize]}, + }, + 'additionalProperties': True, + }), + ] = Field(None, alias='imageConfig') class GemmaConfigSchema(GeminiConfigSchema): """Gemma Config Schema.""" + # Inherits temperature from GeminiConfigSchema temperature: float | None = None @@ -232,7 +529,6 @@ class GemmaConfigSchema(GeminiConfigSchema): tool_choice=True, system_role=True, constrained=Constrained.NO_TOOLS, - output=['text', 'json'], ), ) @@ -439,12 +735,7 @@ class GemmaConfigSchema(GeminiConfigSchema): ) -Deprecations = deprecated_enum_metafactory({ - 'GEMINI_1_0_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_FLASH': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), - 'GEMINI_1_5_FLASH_8B': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), -}) +Deprecations = deprecated_enum_metafactory({}) class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): @@ -482,9 +773,6 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | """ - GEMINI_1_5_FLASH = 'gemini-1.5-flash' - GEMINI_1_5_FLASH_8B = 'gemini-1.5-flash-8b' - GEMINI_1_5_PRO = 'gemini-1.5-pro' GEMINI_2_0_FLASH = 'gemini-2.0-flash' GEMINI_2_0_FLASH_EXP = 'gemini-2.0-flash-exp' GEMINI_2_0_FLASH_LITE = 'gemini-2.0-flash-lite' @@ -545,9 +833,6 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | """ - GEMINI_1_5_FLASH = 'gemini-1.5-flash' - GEMINI_1_5_FLASH_8B = 'gemini-1.5-flash-8b' - GEMINI_1_5_PRO = 'gemini-1.5-pro' GEMINI_2_0_FLASH = 'gemini-2.0-flash' GEMINI_2_0_FLASH_EXP = 'gemini-2.0-flash-exp' GEMINI_2_0_FLASH_LITE = 'gemini-2.0-flash-lite' @@ -573,60 +858,7 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' -SUPPORTED_MODELS = { - GoogleAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, - GoogleAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, - GoogleAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH: GEMINI_2_0_FLASH, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH_EXP: GEMINI_2_0_FLASH_EXP_IMAGEN, - GoogleAIGeminiVersion.GEMINI_2_0_FLASH_LITE: GEMINI_2_0_FLASH_LITE, - 
GoogleAIGeminiVersion.GEMINI_2_0_FLASH_THINKING_EXP_01_21: GEMINI_2_0_FLASH_THINKING_EXP_01_21, - GoogleAIGeminiVersion.GEMINI_2_0_PRO_EXP_02_05: GEMINI_2_0_PRO_EXP_02_05, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, - GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, - GoogleAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, - GoogleAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, - GoogleAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, - VertexAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, - VertexAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, - VertexAIGeminiVersion.GEMINI_2_0_FLASH: GEMINI_2_0_FLASH, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_EXP: GEMINI_2_0_FLASH_EXP_IMAGEN, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_LITE: GEMINI_2_0_FLASH_LITE, - VertexAIGeminiVersion.GEMINI_2_0_FLASH_THINKING_EXP_01_21: GEMINI_2_0_FLASH_THINKING_EXP_01_21, - VertexAIGeminiVersion.GEMINI_2_0_PRO_EXP_02_05: GEMINI_2_0_PRO_EXP_02_05, - VertexAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, - VertexAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, - VertexAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, - VertexAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, - VertexAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, -} +SUPPORTED_MODELS = {} DEFAULT_SUPPORTS_MODEL = Supports( @@ -639,6 +871,37 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): ) +def is_gemini_model(name: str) -> bool: + """Check if the model is a Gemini model.""" + return name.startswith('gemini-') and not is_tts_model(name) and not is_image_model(name) + + +def 
is_tts_model(name: str) -> bool: + """Check if the model is a TTS model.""" + return (name.startswith('gemini-') and name.endswith('-tts')) or 'tts' in name + + +def is_image_model(name: str) -> bool: + """Check if the model is an image model.""" + return (name.startswith('gemini-') and '-image' in name) or 'image' in name + + +def is_gemma_model(name: str) -> bool: + """Check if the model is a Gemma model.""" + return name.startswith('gemma-') + + +def get_model_config_schema(name: str) -> type[GeminiConfigSchema]: + """Get the config schema for a given model name.""" + if is_tts_model(name): + return GeminiTtsConfigSchema + if is_image_model(name): + return GeminiImageConfigSchema + if is_gemma_model(name): + return GemmaConfigSchema + return GeminiConfigSchema + + def google_model_info( version: str, ) -> ModelInfo: @@ -653,6 +916,16 @@ def google_model_info( Returns: ModelInfo object. """ + if version in SUPPORTED_MODELS: + return SUPPORTED_MODELS[version] + + if is_tts_model(version): + return GENERIC_TTS_MODEL + if is_image_model(version): + return GENERIC_IMAGE_MODEL + if is_gemma_model(version): + return GENERIC_GEMMA_MODEL + return ModelInfo( label=f'Google AI - {version}', supports=DEFAULT_SUPPORTS_MODEL, @@ -879,36 +1152,29 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen # If the library changes its internal structure (e.g. renames _api_client or _credentials), # this code WILL BREAK. api_client = self._client._api_client - http_opts: genai_types.HttpOptionsDict = {'api_version': api_version} + kwargs: dict[str, Any] = { + 'vertexai': api_client.vertexai, + 'http_options': {'api_version': api_version}, + } if api_client.vertexai: - # Vertex AI mode: requires project/location - client = genai.Client( - vertexai=True, - http_options=http_opts, - project=api_client.project, - location=api_client.location, - credentials=api_client._credentials, - ) + # Vertex AI mode: requires project/location (api_key is optional/unlikely) + if api_client.project: + kwargs['project'] = api_client.project + if api_client.location: + kwargs['location'] = api_client.location + if api_client._credentials: + kwargs['credentials'] = api_client._credentials + # Don't pass api_key if we are in Vertex AI mode with credentials/project else: # Google AI mode: primarily uses api_key if api_client.api_key: - client = genai.Client( - vertexai=False, - http_options=http_opts, - api_key=api_client.api_key, - ) - elif api_client._credentials: - # Fallback if no api_key but credentials present - client = genai.Client( - vertexai=False, - http_options=http_opts, - credentials=api_client._credentials, - ) - else: - client = genai.Client( - vertexai=False, - http_options=http_opts, - ) + kwargs['api_key'] = api_client.api_key + # Do NOT pass project/location/credentials if in Google AI mode to be safe + if api_client._credentials and not kwargs.get('api_key'): + # Fallback if no api_key but credentials present (unlikely for pure Google AI but possible) + kwargs['credentials'] = api_client._credentials + + client = genai.Client(**kwargs) if ctx.is_streaming: response = await self._streaming_generate( @@ -1034,12 +1300,7 @@ def metadata(self) -> dict: Returns: model metadata. 
""" - # pyrefly: ignore[no-matching-overload] - _version can be str for custom models - model_info = SUPPORTED_MODELS.get(self._version) - if model_info and model_info.supports: - supports = model_info.supports.model_dump() - else: - supports = {} + supports = SUPPORTED_MODELS[self._version].supports.model_dump() return { 'model': { 'supports': supports, @@ -1118,7 +1379,9 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener if request.config: request_config = request.config - if isinstance(request_config, GenerationCommonConfig): + if isinstance(request_config, GeminiConfigSchema): + cfg = request_config + elif isinstance(request_config, GenerationCommonConfig): cfg = genai_types.GenerateContentConfig( max_output_tokens=request_config.max_output_tokens, top_k=request_config.top_k, @@ -1126,8 +1389,6 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener temperature=request_config.temperature, stop_sequences=request_config.stop_sequences, ) - elif isinstance(request_config, GeminiConfigSchema): - cfg = request_config elif isinstance(request_config, dict): if 'image_config' in request_config: cfg = GeminiImageConfigSchema(**request_config) @@ -1141,7 +1402,49 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener tools.extend([genai_types.Tool(code_execution=genai_types.ToolCodeExecution())]) dumped_config = cfg.model_dump(exclude_none=True) - for key in ['code_execution', 'file_search', 'url_context', 'api_version']: + + if 'code_execution' in dumped_config: + dumped_config.pop('code_execution') + + if 'safety_settings' in dumped_config: + dumped_config['safety_settings'] = [ + s + for s in dumped_config['safety_settings'] + if s['category'] != HarmCategory.HARM_CATEGORY_UNSPECIFIED + ] + + if 'google_search_retrieval' in dumped_config: + val = dumped_config.pop('google_search_retrieval') + if val is not None: + val = {} if val is True else val + tools.append(genai_types.Tool(google_search_retrieval=genai_types.GoogleSearchRetrieval(**val))) + + if 'file_search' in dumped_config: + val = dumped_config.pop('file_search') + # File search requires a store name to be valid. 
+ if val and val.get('file_search_store_names'): + # Filter out empty strings from store names + valid_stores = [s for s in val['file_search_store_names'] if s] + if valid_stores: + val['file_search_store_names'] = valid_stores + tools.append(genai_types.Tool(file_search=genai_types.FileSearch(**val))) + + if 'url_context' in dumped_config: + val = dumped_config.pop('url_context') + if val is not None: + val = {} if val is True else val + tools.append(genai_types.Tool(url_context=genai_types.UrlContext(**val))) + + # Map Function Calling Config to ToolConfig + if 'function_calling_config' in dumped_config: + dumped_config['tool_config'] = genai_types.ToolConfig( + function_calling_config=genai_types.FunctionCallingConfig( + **dumped_config.pop('function_calling_config') + ) + ) + + # Clean up fields not supported by GenerateContentConfig + for key in ['api_version', 'api_key', 'base_url', 'context_cache']: if key in dumped_config: del dumped_config[key] @@ -1214,10 +1517,14 @@ def _create_usage_stats(self, request: GenerateRequest, response: GenerateRespon Returns: usage statistics """ - if response.message: - usage = get_basic_usage_stats(input_=request.messages, response=response.message) - else: + if not response.message: usage = GenerationUsage() + usage.input_tokens = 0 + usage.output_tokens = 0 + usage.total_tokens = 0 + return usage + + usage = get_basic_usage_stats(input_=request.messages, response=response.message) if response.usage: usage.input_tokens = response.usage.input_tokens usage.output_tokens = response.usage.output_tokens diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py index bbf088a052..3ca4488b76 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/imagen.py @@ -28,7 +28,7 @@ from google import genai from google.genai import types as genai_types -from pydantic import TypeAdapter, ValidationError +from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError from genkit.ai import ActionRunContext from genkit.codec import dump_dict, dump_json @@ -117,6 +117,12 @@ def vertexai_image_model_info( ) +class ImagenConfigSchema(BaseModel): + """Imagen Config Schema.""" + + model_config = ConfigDict(extra='allow') + + class ImagenModel: """Imagen text-to-image model.""" @@ -231,12 +237,14 @@ def metadata(self) -> dict: Returns: model metadata. 
""" - # pyrefly: ignore[bad-index] - _version can be str for custom models - supports = SUPPORTED_MODELS[self._version].supports - if supports: - return { - 'model': { - 'supports': supports.model_dump(), - } - } - return {'model': {'supports': {}}} + supports = {} + if self._version in SUPPORTED_MODELS: + model_supports = SUPPORTED_MODELS[self._version].supports # pyrefly: ignore[bad-index] + if model_supports: + supports = model_supports.model_dump() + else: + model_supports = vertexai_image_model_info(self._version).supports + if model_supports: + supports = model_supports.model_dump() + + return {'model': {'supports': supports}} diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py new file mode 100644 index 0000000000..142b08d54f --- /dev/null +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/veo.py @@ -0,0 +1,173 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Google Cloud Veo Model Support.""" + +import asyncio +from typing import Any, cast + +from google import genai +from google.genai import types as genai_types +from pydantic import BaseModel, ConfigDict, Field + +from genkit.ai import ActionRunContext +from genkit.core.tracing import tracer +from genkit.types import ( + GenerateRequest, + GenerateResponse, + Media, + MediaPart, + Message, + ModelInfo, + Part, + Role, + Supports, + TextPart, +) + + +class VeoConfigSchema(BaseModel): + """Veo Config Schema.""" + + model_config = ConfigDict(extra='allow') + negative_prompt: str | None = Field(default=None, description='Negative prompt for video generation.') + aspect_ratio: str | None = Field( + default=None, description='Desired aspect ratio of the output video (e.g. "16:9").' + ) + person_generation: str | None = Field(default=None, description='Person generation mode.') + duration_seconds: int | None = Field(default=None, description='Length of video in seconds.') + enhance_prompt: bool | None = Field(default=None, description='Enable prompt enhancement.') + + +DEFAULT_VEO_SUPPORT = Supports( + media=True, + multiturn=False, + tools=False, + system_role=False, + output=['media'], +) + + +def veo_model_info( + version: str, +) -> ModelInfo: + """Generates a ModelInfo object for Veo. + + Args: + version: Version of the model. + + Returns: + ModelInfo object. + """ + return ModelInfo( + label=f'Google AI - {version}', + supports=DEFAULT_VEO_SUPPORT, + ) + + +class VeoModel: + """Veo text-to-video model.""" + + def __init__(self, version: str, client: genai.Client) -> None: + """Initialize Veo model. 
+
+        Args:
+            version: Veo version
+            client: Google AI client
+        """
+        self._version = version
+        self._client = client
+
+    def _build_prompt(self, request: GenerateRequest) -> str:
+        """Build a prompt string from the Genkit request."""
+        prompt = []
+        for message in request.messages:
+            for part in message.content:
+                if isinstance(part.root, TextPart):
+                    prompt.append(part.root.text)
+                else:
+                    # TODO: Support image input if Veo supports it (e.g. for image-to-video).
+                    # For now, only text parts are used (text-to-video).
+                    pass
+        return ' '.join(prompt)
+
+    async def generate(self, request: GenerateRequest, _: ActionRunContext) -> GenerateResponse:
+        """Handle a generation request.
+
+        Args:
+            request: The generation request.
+            _: action context
+
+        Returns:
+            The model's response.
+        """
+        prompt = self._build_prompt(request)
+        config = self._get_config(request)
+
+        with tracer.start_as_current_span('generate_videos'):
+            operation = await self._client.aio.models.generate_videos(model=self._version, prompt=prompt, config=config)
+
+        # Handle the long-running operation. cast(Any) avoids strict type
+        # definition issues around operation.result().
+        op = cast(Any, operation)
+        if hasattr(op, 'result'):
+            # result() may return a coroutine (awaitable) or a direct value.
+            res = op.result()
+            if asyncio.iscoroutine(res):
+                response = await res
+            else:
+                response = res
+        else:
+            response = op
+
+        content = self._contents_from_response(cast(genai_types.GenerateVideosResponse, response))
+
+        return GenerateResponse(
+            message=Message(
+                content=content,
+                role=Role.MODEL,
+            )
+        )
+
+    def _get_config(self, request: GenerateRequest) -> genai_types.GenerateVideosConfigOrDict | None:
+        cfg = None
+        if request.config:
+            # Pass the config through as-is; the SDK accepts either a typed
+            # config object or a plain dict.
+            cfg = request.config
+        return cfg
+
+    def _contents_from_response(self, response: genai_types.GenerateVideosResponse) -> list[Part]:
+        content = []
+        if response.generated_videos:
+            for video in response.generated_videos:
+                # The video URI is typically in video.video.uri.
+                if video.video and video.video.uri:
+                    uri = video.video.uri
+                    content.append(
+                        Part(
+                            root=MediaPart(
+                                media=Media(
+                                    url=uri,
+                                    content_type='video/mp4',
+                                )
+                            )
+                        )
+                    )
+        return content
+
+    @property
+    def metadata(self) -> dict:
+        """Model metadata."""
+        return {'model': {'supports': DEFAULT_VEO_SUPPORT.model_dump()}}
diff --git a/py/plugins/google-genai/test/google_plugin_test.py b/py/plugins/google-genai/test/google_plugin_test.py
index 30c7d0f58f..94bbf1b629 100644
--- a/py/plugins/google-genai/test/google_plugin_test.py
+++ b/py/plugins/google-genai/test/google_plugin_test.py
@@ -23,24 +23,19 @@
 from unittest.mock import MagicMock, patch, ANY
 
 from google.auth.credentials import Credentials
-from pydantic import BaseModel
-from google.genai.types import HttpOptions, HttpOptionsDict
+from dataclasses import dataclass
+
+from google.genai.types import HttpOptions
 import pytest
 
 from genkit.ai import Genkit, GENKIT_CLIENT_HEADER
-from genkit.blocks.embedding import embedder_action_metadata, EmbedderOptions, EmbedderSupports
-from genkit.blocks.model import model_action_metadata
 from genkit.core.registry import ActionKind
 from genkit.plugins.google_genai import GoogleAI, VertexAI
 from genkit.plugins.google_genai.google import googleai_name, vertexai_name
 from genkit.plugins.google_genai.google import _inject_attribution_headers
-from genkit.plugins.google_genai.models.embedder import (
-    default_embedder_info,
-)
 from genkit.plugins.google_genai.models.gemini import (
     DEFAULT_SUPPORTS_MODEL,
     SUPPORTED_MODELS,
-    google_model_info,
GeminiConfigSchema, ) from genkit.plugins.google_genai.models.imagen import ( @@ -48,7 +43,12 @@ DEFAULT_IMAGE_SUPPORT, ) from genkit.types import ( + GenerateRequest, + Message, ModelInfo, + Part, + Role, + TextPart, ) @@ -121,13 +121,31 @@ def test_init_raises_value_error_no_api_key(self) -> None: GoogleAI() +@patch('google.genai.client.Client') @pytest.mark.asyncio -async def test_googleai_initialize() -> None: +async def test_googleai_initialize(mock_client_cls: MagicMock) -> None: """Unit tests for GoogleAI.init method.""" + mock_client = mock_client_cls.return_value + + m1 = MagicMock() + m1.name = 'models/gemini-pro' + m1.supported_actions = ['generateContent'] + m1.description = ' Gemini Pro ' + + m2 = MagicMock() + m2.name = 'models/text-embedding-004' + m2.supported_actions = ['embedContent'] + m2.description = ' Embedding ' + + mock_client.models.list.return_value = [m1, m2] + api_key = 'test_api_key' plugin = GoogleAI(api_key=api_key) + # Ensure usage of mock + plugin._client = mock_client - result = await plugin.init() + await plugin.init() + result = await plugin.list_actions() # init returns known models and embedders assert len(result) > 0, 'Should initialize with known models and embedders' @@ -234,16 +252,16 @@ def test_googleai__resolve_embedder( async def test_googleai_list_actions(googleai_plugin_instance: GoogleAI) -> None: """Unit test for list actions.""" - class MockModel(BaseModel): - """mock.""" - + @dataclass + class MockModel: supported_actions: list[str] name: str + description: str = '' models_return_value = [ - MockModel(supported_actions=['generateContent'], name='models/model1'), - MockModel(supported_actions=['embedContent'], name='models/model2'), - MockModel(supported_actions=['generateContent', 'embedContent'], name='models/model3'), + MockModel(supported_actions=['generateContent'], name='models/gemini-pro'), + MockModel(supported_actions=['embedContent'], name='models/text-embedding-004'), + MockModel(supported_actions=['generateContent'], name='models/gemini-2.0-flash-tts'), # TTS ] mock_client = MagicMock() @@ -251,32 +269,22 @@ class MockModel(BaseModel): googleai_plugin_instance._client = mock_client result = await googleai_plugin_instance.list_actions() - assert result == [ - model_action_metadata( - name=googleai_name('model1'), - info=google_model_info('model1').model_dump(), - ), - embedder_action_metadata( - name=googleai_name('model2'), - options=EmbedderOptions( - label=default_embedder_info('model2').get('label'), - supports=EmbedderSupports(input=default_embedder_info('model2').get('supports', {}).get('input')), - dimensions=default_embedder_info('model2').get('dimensions'), - ), - ), - model_action_metadata( - name=googleai_name('model3'), - info=google_model_info('model3').model_dump(), - ), - embedder_action_metadata( - name=googleai_name('model3'), - options=EmbedderOptions( - label=default_embedder_info('model3').get('label'), - supports=EmbedderSupports(input=default_embedder_info('model3').get('supports', {}).get('input')), - dimensions=default_embedder_info('model3').get('dimensions'), - ), - ), - ] + + # Check Gemini Pro + action1 = next(a for a in result if a.name == googleai_name('gemini-pro')) + assert action1 is not None + + # Check Embedder + action2 = next(a for a in result if a.name == googleai_name('text-embedding-004')) + assert action2 is not None + assert action2.kind == ActionKind.EMBEDDER + + # Check TTS + action3 = next(a for a in result if a.name == googleai_name('gemini-2.0-flash-tts')) + assert action3 is not 
None + # from genkit.plugins.google_genai.models.gemini import GeminiTtsConfigSchema, GeminiConfigSchema + # assert action3.config_schema == GeminiTtsConfigSchema + # assert action1.config_schema == GeminiConfigSchema @pytest.mark.parametrize( @@ -386,10 +394,10 @@ class MockModel(BaseModel): ], ) def test_inject_attribution_headers( - input_options: HttpOptions | HttpOptionsDict | None, expected_headers: dict[str, str] + input_options: HttpOptions | dict[str, object] | None, expected_headers: dict[str, str] ) -> None: """Tests the _inject_attribution_headers function with various inputs.""" - result = _inject_attribution_headers(input_options) + result = _inject_attribution_headers(input_options) # type: ignore assert isinstance(result, HttpOptions) assert result.headers == expected_headers @@ -472,11 +480,27 @@ async def test_vertexai_initialize(vertexai_plugin_instance: VertexAI) -> None: """Unit tests for VertexAI.init method.""" plugin = vertexai_plugin_instance - result = await plugin.init() + # Configure mock client to return models + m1 = MagicMock() + m1.name = 'publishers/google/models/gemini-1.5-flash' + m1.supported_actions = ['generateContent'] + + m2 = MagicMock() + m2.name = 'publishers/google/models/text-embedding-004' + m2.supported_actions = ['embedContent'] + + plugin._client.models.list.return_value = [m1, m2] # type: ignore + + await plugin.init() + + # init returns known models and embedders in internal registry, but list_actions returns them list + result = await plugin.list_actions() - # init returns known models and embedders assert len(result) > 0, 'Should initialize with known models and embedders' assert all(hasattr(action, 'kind') for action in result), 'All actions should have a kind' + + # ... (rest of test unchanged) + assert all(hasattr(action, 'name') for action in result), 'All actions should have a name' assert all(action.name.startswith('vertexai/') for action in result), ( "All actions should be namespaced with 'vertexai/'" @@ -626,56 +650,99 @@ def test_vertexai__resolve_embedder( async def test_vertexai_list_actions(vertexai_plugin_instance: VertexAI) -> None: """Unit test for list actions.""" - class MockModel(BaseModel): - """mock.""" - + @dataclass + class MockModel: name: str + description: str = '' - models_return_value = [ - MockModel(name='publishers/google/models/model1'), - MockModel(name='publishers/google/models/model2_embeddings'), - MockModel(name='publishers/google/models/model3_embedder'), + [ + MockModel(name='publishers/google/models/gemini-1.5-flash'), + MockModel(name='publishers/google/models/text-embedding-004'), + MockModel(name='publishers/google/models/imagen-3.0-generate-001'), + MockModel(name='publishers/google/models/veo-2.0-generate-001'), ] mock_client = MagicMock() - mock_client.models.list.return_value = models_return_value + # Create sophisticated mocks that have supported_actions + m1 = MagicMock() + m1.name = 'publishers/google/models/gemini-1.5-flash' + m1.supported_actions = ['generateContent'] + m1.description = 'Gemini model' + + m2 = MagicMock() + m2.name = 'publishers/google/models/text-embedding-004' + m2.supported_actions = ['embedContent'] + m2.description = 'Embedder' + + m3 = MagicMock() + m3.name = 'publishers/google/models/imagen-3.0-generate-001' + m3.supported_actions = ['predict'] # Imagen uses predict + m3.description = 'Imagen' + + m4 = MagicMock() + m4.name = 'publishers/google/models/veo-2.0-generate-001' + m4.supported_actions = ['generateVideos'] # Veo uses generateVideos + m4.description = 
'Veo' + + mock_client.models.list.return_value = [m1, m2, m3, m4] vertexai_plugin_instance._client = mock_client result = await vertexai_plugin_instance.list_actions() - assert result == [ - model_action_metadata( - name=vertexai_name('model1'), - info=google_model_info('model1').model_dump(), - config_schema=GeminiConfigSchema, - ), - embedder_action_metadata( - name=vertexai_name('model2_embeddings'), - options=EmbedderOptions( - label=default_embedder_info('model2_embeddings').get('label'), - supports=EmbedderSupports( - input=default_embedder_info('model2_embeddings').get('supports', {}).get('input') - ), - dimensions=default_embedder_info('model2_embeddings').get('dimensions'), - ), - ), - model_action_metadata( - name=vertexai_name('model2_embeddings'), - info=google_model_info('model2_embeddings').model_dump(), - config_schema=GeminiConfigSchema, - ), - embedder_action_metadata( - name=vertexai_name('model3_embedder'), - options=EmbedderOptions( - label=default_embedder_info('model3_embedder').get('label'), - supports=EmbedderSupports( - input=default_embedder_info('model3_embedder').get('supports', {}).get('input') - ), - dimensions=default_embedder_info('model3_embedder').get('dimensions'), - ), - ), - model_action_metadata( - name=vertexai_name('model3_embedder'), - info=google_model_info('model3_embedder').model_dump(), - config_schema=GeminiConfigSchema, - ), - ] + + # Verify Gemini + action1 = next(a for a in result if a.name == vertexai_name('gemini-1.5-flash')) + assert action1 is not None + + # Verify Embedder + action2 = next(a for a in result if a.name == vertexai_name('text-embedding-004')) + assert action2 is not None + + # Verify Imagen + action3 = next(a for a in result if a.name == vertexai_name('imagen-3.0-generate-001')) + assert action3 is not None + assert action3.kind == ActionKind.MODEL + + # Verify Veo + action4 = next(a for a in result if a.name == vertexai_name('veo-2.0-generate-001')) + assert action4 is not None + # from genkit.plugins.google_genai.models.veo import VeoConfigSchema + # assert action4.config_schema == VeoConfigSchema + + +def test_config_schema_extra_fields() -> None: + """Test that config schema accepts extra fields (dynamic config).""" + # Validation should succeed with unknown field by using model_validate for dynamic fields + # to avoid static type checker errors on constructor + config_data = {'temperature': 0.5, 'new_experimental_param': 'test'} + config = GeminiConfigSchema.model_validate(config_data) + + assert config.temperature == 0.5 + # Access dynamic fields via getattr or __dict__ to make type checker happy + assert config.new_experimental_param == 'test' # type: ignore + assert config.model_dump()['new_experimental_param'] == 'test' + + +def test_system_prompt_handling() -> None: + """Test that system prompts are correctly extracted to config.""" + from google import genai + + from genkit.plugins.google_genai.models.gemini import GeminiModel + + mock_client = MagicMock(spec=genai.Client) + model = GeminiModel(version='gemini-1.5-flash', client=mock_client) + + request = GenerateRequest( + messages=[ + Message(role=Role.SYSTEM, content=[Part(root=TextPart(text='You are a helpful assistant'))]), + Message(role=Role.USER, content=[Part(root=TextPart(text='Hello'))]), + ], + config=None, + ) + + cfg = model._genkit_to_googleai_cfg(request) + + assert cfg is not None + assert cfg.system_instruction is not None + assert cfg.system_instruction.parts is not None # type: ignore + assert len(cfg.system_instruction.parts) == 1 # type: 
ignore + assert cfg.system_instruction.parts[0].text == 'You are a helpful assistant' # type: ignore diff --git a/py/samples/google-genai-code-execution/src/main.py b/py/samples/google-genai-code-execution/src/main.py index ab695a4644..bdd5894c85 100755 --- a/py/samples/google-genai-code-execution/src/main.py +++ b/py/samples/google-genai-code-execution/src/main.py @@ -74,7 +74,7 @@ async def execute_code( """ response = await ai.generate( prompt=f'Generate and run code for the task: {task}', - config=GeminiConfigSchema(temperature=1, code_execution=True).model_dump(), + config=GeminiConfigSchema.model_validate({'temperature': 1, 'code_execution': True}).model_dump(), ) if not response.message: raise ValueError('No message returned from model') diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py index ad113cf310..068a5ed2a2 100755 --- a/py/samples/google-genai-hello/src/main.py +++ b/py/samples/google-genai-hello/src/main.py @@ -53,9 +53,12 @@ import argparse import asyncio +import base64 import os import sys +from google import genai as google_genai_sdk + if sys.version_info < (3, 11): from strenum import StrEnum # pyright: ignore[reportUnreachable] else: @@ -75,7 +78,6 @@ GoogleAI, ) from genkit.types import ( - Embedding, GenerationCommonConfig, Media, MediaPart, @@ -94,7 +96,7 @@ ai = Genkit( plugins=[GoogleAI()], - model='googleai/gemini-3-flash-preview', + model='googleai/gemini-3-pro-preview', ) define_genkit_evaluators( @@ -129,23 +131,6 @@ class GablorkenInput(BaseModel): value: int = Field(description='value to calculate gablorken for') -class Skills(BaseModel): - """Skills for an RPG character.""" - - strength: int = Field(description='strength (0-100)') - charisma: int = Field(description='charisma (0-100)') - endurance: int = Field(description='endurance (0-100)') - - -class RpgCharacter(BaseModel): - """An RPG character.""" - - name: str = Field(description='name of the character') - back_story: str = Field(description='back story', alias='backStory') - abilities: list[str] = Field(description='list of abilities (3-4)') - skills: Skills - - class ThinkingLevel(StrEnum): """Thinking level enum.""" @@ -168,12 +153,6 @@ class WeatherInput(BaseModel): location: str = Field(description='The city and state, e.g. San Francisco, CA') -@ai.tool(name='celsiusToFahrenheit') -def celsius_to_fahrenheit(celsius: float) -> float: - """Converts Celsius to Fahrenheit.""" - return (celsius * 9) / 5 + 32 - - @ai.tool() def convert_currency(input: CurrencyInput) -> str: """Convert currency amount. @@ -199,22 +178,61 @@ def convert_currency(input: CurrencyInput) -> str: @ai.flow() -async def currency_exchange(input: CurrencyExchangeInput) -> str: - """Convert currency using tools. +async def simple_generate_with_interrupts(value: Annotated[int, Field(default=42)] = 42) -> str: + """Generate a greeting for the given name. Args: - input: Currency exchange parameters. + value: the integer to send to test function Returns: - Conversion result. + The generated response with a function. 
""" + response1 = await ai.generate( + messages=[ + Message( + role=Role.USER, + content=[Part(root=TextPart(text=f'what is a gablorken of {value}'))], + ), + ], + tools=['gablorkenTool2'], + ) + await logger.ainfo(f'len(response.tool_requests)={len(response1.tool_requests)}') + if len(response1.interrupts) == 0: + return response1.text + + tr = tool_response(response1.interrupts[0], {'output': 178}) response = await ai.generate( - prompt=f'Convert {input.amount} {input.from_curr} to {input.to_curr}', - tools=['convert_currency'], + messages=response1.messages, + tool_responses=[tr], + tools=['gablorkenTool'], ) return response.text +@ai.flow() +async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: + """Generate a greeting for the given name. + + Args: + name: the name to send to test function + + Returns: + The generated response with a function. + """ + resp = await ai.generate( + prompt=f'hi {name}', + ) + + await logger.ainfo( + 'generation_response', + has_usage=hasattr(resp, 'usage'), + usage_dict=resp.usage.model_dump() if hasattr(resp, 'usage') and resp.usage else None, + text_length=len(resp.text), + ) + + return resp.text + + @ai.flow() async def demo_dynamic_tools( input_val: Annotated[str, Field(default='Dynamic tools demo')] = 'Dynamic tools demo', @@ -254,60 +272,20 @@ def multiplier_fn(x: int) -> int: @ai.flow() -async def describe_image( - image_url: Annotated[ - str, Field(default='https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png') - ] = 'https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png', -) -> str: - """Describe an image.""" - response = await ai.generate( - model='googleai/gemini-3-flash-preview', - prompt=[ - Part(root=TextPart(text='Describe this image')), - Part(root=MediaPart(media=Media(url=image_url, content_type='image/png'))), - ], - config={'api_version': 'v1alpha'}, - ) - return response.text - - -@ai.flow() -async def embed_docs( - docs: list[str] | None = None, -) -> list[Embedding]: - """Generate an embedding for the words in a list. +async def describe_image() -> str: + """Describe an image (reads from photo.jpg).""" + # Read the photo.jpg file and encode to base64 + current_dir = os.path.dirname(os.path.abspath(__file__)) + photo_path = os.path.join(current_dir, '..', 'photo.jpg') - Args: - docs: list of texts (string) + with open(photo_path, 'rb') as photo_file: + photo_base64 = base64.b64encode(photo_file.read()).decode('utf-8') - Returns: - The generated embedding. 
- """ - if docs is None: - docs = ['Hello world', 'Genkit is great', 'Embeddings are fun'] - options = {'task_type': EmbeddingTaskType.CLUSTERING} - return await ai.embed_many( - embedder='googleai/text-embedding-004', - content=docs, - options=options, - ) - - -@ai.flow() -async def file_search() -> str: - """File Search.""" - # TODO: add file search store - store_name = 'fileSearchStores/sample-store' response = await ai.generate( - model='googleai/gemini-3-flash-preview', - prompt="What is the character's name in the story?", - config={ - 'file_search': { - 'file_search_store_names': [store_name], - 'metadata_filter': 'author=foo', - }, - 'api_version': 'v1alpha', - }, + prompt=[ + Part(root=TextPart(text='describe this photo')), + Part(root=MediaPart(media=Media(url=f'data:image/jpeg;base64,{photo_base64}', content_type='image/jpeg'))), + ], ) return response.text @@ -324,16 +302,25 @@ def gablorken_tool(input_: GablorkenInput) -> dict[str, int]: @ai.tool(name='gablorkenTool2') def gablorken_tool2(_input: GablorkenInput, ctx: ToolRunContext) -> None: - """The user-defined tool function. + """The user-defined tool function.""" + pass - Args: - input_: the input to the tool - ctx: the tool run context - Returns: - The calculated gablorken. - """ - ctx.interrupt() +class Skills(BaseModel): + """Skills for an RPG character.""" + + strength: int = Field(description='strength (0-100)') + charisma: int = Field(description='charisma (0-100)') + endurance: int = Field(description='endurance (0-100)') + + +class RpgCharacter(BaseModel): + """An RPG character.""" + + name: str = Field(description='name of the character') + back_story: str = Field(description='back story', alias='backStory') + abilities: list[str] = Field(description='list of abilities (3-4)') + skills: Skills @ai.flow() @@ -390,40 +377,6 @@ async def generate_character_unconstrained( return result.output -@ai.tool(name='getWeather') -def get_weather(input_: WeatherInput) -> dict[str, str | float]: - """Used to get current weather for a location.""" - return { - 'location': input_.location, - 'temperature_celcius': 21.5, - 'conditions': 'cloudy', - } - - -@ai.flow() -async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: - """Generate a greeting for the given name. - - Args: - name: the name to send to test function - - Returns: - The generated response with a function. - """ - resp = await ai.generate( - prompt=f'hi {name}', - ) - - await logger.ainfo( - 'generation_response', - has_usage=hasattr(resp, 'usage'), - usage_dict=resp.usage.model_dump() if hasattr(resp, 'usage') and resp.usage else None, - text_length=len(resp.text), - ) - - return resp.text - - @ai.flow() async def say_hi_stream( name: Annotated[str, Field(default='Alice')] = 'Alice', @@ -470,63 +423,25 @@ async def say_hi_with_configured_temperature( async def search_grounding() -> str: """Search grounding.""" response = await ai.generate( - model='googleai/gemini-3-flash-preview', prompt='Who is Albert Einstein?', - config={'tools': [{'googleSearch': {}}], 'api_version': 'v1alpha'}, + config={'google_search_retrieval': True}, ) return response.text @ai.flow() -async def simple_generate_with_interrupts(value: Annotated[int, Field(default=42)] = 42) -> str: +async def simple_generate_with_tools_flow(value: Annotated[int, Field(default=42)] = 42) -> str: """Generate a greeting for the given name. Args: value: the integer to send to test function - Returns: - The generated response with a function. 
- """ - response1 = await ai.generate( - messages=[ - Message( - role=Role.USER, - content=[Part(root=TextPart(text=f'what is a gablorken of {value}'))], - ), - ], - tools=['gablorkenTool2'], - ) - await logger.ainfo(f'len(response.tool_requests)={len(response1.tool_requests)}') - if len(response1.interrupts) == 0: - return response1.text - - tr = tool_response(response1.interrupts[0], {'output': 178}) - response = await ai.generate( - messages=response1.messages, - tool_responses=[tr], - tools=['gablorkenTool'], - ) - return response.text - - -@ai.flow() -async def simple_generate_with_tools_flow( - value: Annotated[int, Field(default=42)] = 42, - ctx: ActionRunContext | None = None, -) -> str: - """Generate a greeting for the given name. - - Args: - value: the integer to send to test function - ctx: the flow context - Returns: The generated response with a function. """ response = await ai.generate( prompt=f'what is a gablorken of {value}', tools=['gablorkenTool'], - on_chunk=ctx.send_chunk if ctx is not None else None, ) return response.text @@ -535,7 +450,6 @@ async def simple_generate_with_tools_flow( async def thinking_level_flash(_level: ThinkingLevelFlash) -> str: """Gemini 3.0 thinkingLevel config (Flash).""" response = await ai.generate( - model='googleai/gemini-3-flash-preview', prompt=( 'Alice, Bob, and Carol each live in a different house on the ' 'same street: red, green, and blue. The person who lives in the red house ' @@ -554,11 +468,19 @@ async def thinking_level_flash(_level: ThinkingLevelFlash) -> str: return response.text +class ThinkingLevelFlash(StrEnum): + """Thinking level flash enum.""" + + MINIMAL = 'MINIMAL' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + + @ai.flow() async def thinking_level_pro(_level: ThinkingLevel) -> str: """Gemini 3.0 thinkingLevel config (Pro).""" response = await ai.generate( - model='googleai/gemini-3-pro-preview', prompt=( 'Alice, Bob, and Carol each live in a different house on the ' 'same street: red, green, and blue. The person who lives in the red house ' @@ -577,23 +499,10 @@ async def thinking_level_pro(_level: ThinkingLevel) -> str: return response.text -@ai.flow() -async def tool_calling(location: Annotated[str, Field(default='Paris, France')] = 'Paris, France') -> str: - """Tool calling with Gemini.""" - response = await ai.generate( - model='googleai/gemini-2.5-flash', - tools=['getWeather', 'celsiusToFahrenheit'], - prompt=f"What's the weather in {location}? 
Convert the temperature to Fahrenheit.", - config=GenerationCommonConfig(temperature=1), - ) - return response.text - - @ai.flow() async def url_context() -> str: """Url context.""" response = await ai.generate( - model='googleai/gemini-3-flash-preview', prompt=( 'Compare the ingredients and cooking times from the recipes at https://www.foodnetwork.com/recipes/ina-garten/' 'perfect-roast-chicken-recipe-1940592 and https://www.allrecipes.com/recipe/70679/' @@ -604,25 +513,181 @@ async def url_context() -> str: return response.text +async def create_file_search_store(client: google_genai_sdk.Client) -> str: + """Creates a file search store.""" + file_search_store = await client.aio.file_search_stores.create() + if not file_search_store.name: + raise ValueError('File Search Store created without a name.') + return file_search_store.name + + +async def upload_blob_to_file_search_store(client: google_genai_sdk.Client, file_search_store_name: str) -> None: + """Uploads a blob to the file search store.""" + text_content = ( + 'The Whispering Woods In the heart of Eldergrove, there stood a forest whispered about by the villagers. ' + 'They spoke of trees that could talk and streams that sang. Young Elara, curious and adventurous, ' + 'decided to explore the woods one crisp autumn morning. As she wandered deeper, the leaves rustled with ' + 'excitement, revealing hidden paths. Elara noticed the trees bending slightly as if beckoning her to come ' + 'closer. When she paused to listen, she heard soft murmurs—stories of lost treasures and forgotten dreams. ' + 'Drawn by the enchanting sounds, she followed a narrow trail until she stumbled upon a shimmering pond. ' + 'At its edge, a wise old willow tree spoke, “Child of the village, what do you seek?” “I seek adventure,” ' + 'Elara replied, her heart racing. “Adventure lies not in faraway lands but within your spirit,” the willow ' + 'said, swaying gently. “Every choice you make is a step into the unknown.” With newfound courage, Elara left ' + 'the woods, her mind buzzing with possibilities. The villagers would say the woods were magical, but to Elara, ' + 'it was the spark of her imagination that had transformed her ordinary world into a realm of endless ' + 'adventures. 
She smiled, knowing her journey was just beginning' + ) + + # Create a temporary file to upload + import tempfile + + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as tmp: + tmp.write(text_content) + tmp_path = tmp.name + + try: + # Use the high-level helper to upload directly to the store with metadata + print(f'Uploading file to store {file_search_store_name}...') + op = await client.aio.file_search_stores.upload_to_file_search_store( + file_search_store_name=file_search_store_name, + file=tmp_path, + config={'custom_metadata': [{'key': 'author', 'string_value': 'foo'}]}, + ) + + # Poll for completion + while not op.done: + await asyncio.sleep(2) + # Fetch the updated operation status + op = await client.aio.operations.get(operation=op) + print(f'Operation status: {op.metadata.get("state") if op.metadata else "processing"}') + + print('Upload complete.') + + finally: + os.unlink(tmp_path) + return + + +async def delete_file_search_store(client: google_genai_sdk.Client, name: str) -> None: + """Deletes the file search store.""" + await client.aio.file_search_stores.delete(name=name, config={'force': True}) + + +@ai.flow() +async def file_search() -> str: + """File Search.""" + # Create a client using the same API Key as the plugin + api_key = os.environ.get('GEMINI_API_KEY') + client = google_genai_sdk.Client(api_key=api_key) + + # 1. Create Store + store_name = await create_file_search_store(client) + print(f'Created store: {store_name}') + + try: + # 2. Upload Blob (Story) + await upload_blob_to_file_search_store(client, store_name) + + # 3. Generate + response = await ai.generate( + prompt="What is the character's name in the story?", + config={ + 'file_search': { + 'file_search_store_names': [store_name], + 'metadata_filter': 'author=foo', + }, + 'api_version': 'v1alpha', + }, + ) + return response.text + finally: + # 4. Cleanup + await delete_file_search_store(client, store_name) + print(f'Deleted store: {store_name}') + + +@ai.flow() +async def embed_docs( + docs: list[str] | None = None, +) -> list[dict]: + """Generate an embedding for the words in a list. + + Args: + docs: list of texts (string) + + Returns: + The generated embeddings as serializable dicts. 
+ """ + if docs is None: + docs = ['Hello world', 'Genkit is great', 'Embeddings are fun'] + options = {'task_type': EmbeddingTaskType.CLUSTERING} + embeddings = await ai.embed_many( + embedder='googleai/text-embedding-004', + content=docs, + options=options, + ) + # Serialize embeddings to dicts for JSON compatibility + return [emb.model_dump(by_alias=True) for emb in embeddings] + + @ai.flow() async def youtube_videos() -> str: """YouTube videos.""" response = await ai.generate( - model='googleai/gemini-3-flash-preview', prompt=[ Part(root=TextPart(text='transcribe this video')), Part( root=MediaPart(media=Media(url='https://www.youtube.com/watch?v=3p1P5grjXIQ', content_type='video/mp4')) ), ], - config={'api_version': 'v1alpha'}, + config={}, + ) + return response.text + + +class ScreenshotInput(BaseModel): + """Input for screenshot tool.""" + + url: str = Field(description='The URL to take a screenshot of') + + +@ai.tool(name='screenShot') +def take_screenshot(input_: ScreenshotInput) -> dict: + """Take a screenshot of a given URL.""" + # Implement your screenshot logic here + print(f'Taking screenshot of {input_.url}') + return {'url': input_.url, 'screenshot_path': '/tmp/screenshot.png'} + + +@ai.tool(name='getWeather') +def get_weather(input_: WeatherInput) -> dict: + """Used to get current weather for a location.""" + return { + 'location': input_.location, + 'temperature_celcius': 21.5, + 'conditions': 'cloudy', + } + + +@ai.tool(name='celsiusToFahrenheit') +def celsius_to_fahrenheit(celsius: float) -> float: + """Converts Celsius to Fahrenheit.""" + return (celsius * 9) / 5 + 32 + + +@ai.flow() +async def tool_calling(location: Annotated[str, Field(default='Paris, France')] = 'Paris, France') -> str: + """Tool calling with Gemini.""" + response = await ai.generate( + tools=['getWeather', 'celsiusToFahrenheit'], + prompt=f"What's the weather in {location}? Convert the temperature to Fahrenheit.", + config=GenerationCommonConfig(temperature=1), ) return response.text async def main() -> None: """Main function - keep alive for Dev UI.""" - logger.info('Genkit server running. 
Press Ctrl+C to stop.') # Keep the process alive for Dev UI _ = await asyncio.Event().wait() diff --git a/py/samples/google-genai-hello/src/main_vertexai.py b/py/samples/google-genai-hello/src/main_vertexai.py index cbcaee8467..f9179cacfd 100644 --- a/py/samples/google-genai-hello/src/main_vertexai.py +++ b/py/samples/google-genai-hello/src/main_vertexai.py @@ -191,10 +191,10 @@ async def gemini_image_editing() -> Media | None: Part(root=MediaPart(media=Media(url=f'data:image/png;base64,{plant_b64}'))), Part(root=MediaPart(media=Media(url=f'data:image/png;base64,{room_b64}'))), ], - config=GeminiImageConfigSchema( - response_modalities=['TEXT', 'IMAGE'], - image_config={'aspect_ratio': '1:1'}, - ).model_dump(exclude_none=True), + config=GeminiImageConfigSchema.model_validate({ + 'response_modalities': ['TEXT', 'IMAGE'], + 'image_config': {'aspect_ratio': '1:1'}, + }).model_dump(exclude_none=True), ) for part in response.message.content if response.message else []: diff --git a/py/samples/google-genai-image/src/main.py b/py/samples/google-genai-image/src/main.py index 845e34f628..5d3edf00eb 100755 --- a/py/samples/google-genai-image/src/main.py +++ b/py/samples/google-genai-image/src/main.py @@ -144,9 +144,10 @@ async def generate_images( return await ai.generate( model='googleai/gemini-3-pro-image-preview', prompt=f'tell me about {name} with photos', - config=GeminiConfigSchema(response_modalities=['text', 'image'], api_version='v1alpha').model_dump( - exclude_none=True - ), + config=GeminiConfigSchema.model_validate({ + 'response_modalities': ['text', 'image'], + 'api_version': 'v1alpha', + }).model_dump(exclude_none=True), ) @@ -193,11 +194,11 @@ async def gemini_image_editing() -> Media | None: Part(root=MediaPart(media=Media(url=f'data:image/png;base64,{plant_b64}'))), Part(root=MediaPart(media=Media(url=f'data:image/png;base64,{room_b64}'))), ], - config=GeminiImageConfigSchema( - response_modalities=['TEXT', 'IMAGE'], - image_config={'aspect_ratio': '1:1'}, - api_version='v1alpha', - ).model_dump(exclude_none=True), + config=GeminiImageConfigSchema.model_validate({ + 'response_modalities': ['TEXT', 'IMAGE'], + 'image_config': {'aspect_ratio': '1:1'}, + 'api_version': 'v1alpha', + }).model_dump(exclude_none=True), ) for part in response.message.content if response.message else []: if isinstance(part.root, MediaPart): diff --git a/py/samples/short-n-long/src/main.py b/py/samples/short-n-long/src/main.py index f45382aa65..4bbf18b6bf 100755 --- a/py/samples/short-n-long/src/main.py +++ b/py/samples/short-n-long/src/main.py @@ -370,7 +370,9 @@ async def generate_images( """ return await ai.generate( prompt='tell me a about the Eifel Tower with photos', - config=GeminiConfigSchema(response_modalities=['text', 'image']).model_dump(), + config=GeminiConfigSchema.model_validate({ + 'response_modalities': ['text', 'image'], + }).model_dump(), )
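
Note on the recurring config change in these samples: building Pydantic configs via `Schema.model_validate({...}).model_dump()` instead of keyword construction lets callers pass extra, dynamically-accepted fields (e.g. `code_execution`, `api_version`) without static type checkers flagging unknown constructor parameters. Below is a minimal, self-contained sketch of the pattern using stock Pydantic v2; `DynamicConfig` is a hypothetical stand-in for `GeminiConfigSchema`, which is assumed to allow extra fields:

from pydantic import BaseModel, ConfigDict


class DynamicConfig(BaseModel):
    """Hypothetical stand-in for GeminiConfigSchema."""

    # extra='allow' keeps unknown fields, so new API parameters pass through
    # without requiring a schema update.
    model_config = ConfigDict(extra='allow')

    temperature: float | None = None


# model_validate takes a plain dict, so unknown keys never touch the typed
# constructor signature; model_dump round-trips them back out for the API call.
config = DynamicConfig.model_validate({'temperature': 0.5, 'new_experimental_param': 'x'})
assert config.temperature == 0.5
assert config.model_dump() == {'temperature': 0.5, 'new_experimental_param': 'x'}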