diff --git a/.github/.release-please-manifest-v1.json b/.github/.release-please-manifest-v1.json new file mode 100644 index 0000000000..9fa2aaa5ed --- /dev/null +++ b/.github/.release-please-manifest-v1.json @@ -0,0 +1,3 @@ +{ + ".": "1.35.0" +} diff --git a/.github/.release-please-manifest-v2.json b/.github/.release-please-manifest-v2.json deleted file mode 100644 index 0739396e62..0000000000 --- a/.github/.release-please-manifest-v2.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - ".": "2.0.0-alpha.0" -} diff --git a/.github/release-please-config-v2.json b/.github/release-please-config-v1.json similarity index 90% rename from .github/release-please-config-v2.json rename to .github/release-please-config-v1.json index 6947d9e15a..9668c5c115 100644 --- a/.github/release-please-config-v2.json +++ b/.github/release-please-config-v1.json @@ -3,13 +3,11 @@ "packages": { ".": { "release-type": "python", - "versioning": "prerelease", - "prerelease": true, - "prerelease-type": "alpha", + "versioning": "default", "package-name": "google-adk", "include-component-in-tag": false, "skip-github-release": true, - "changelog-path": "CHANGELOG-v2.md", + "changelog-path": "CHANGELOG.md", "changelog-sections": [ { "type": "feat", @@ -58,5 +56,6 @@ } ] } - } + }, + "last-release-sha": "ecfdaf5f5e7accfbb4294cb8cc56c910dad2b1a8" } diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 23a032f87a..f18020a86b 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,13 +16,13 @@ name: Pre-commit Checks on: push: - branches: [main, v2] + branches: [main, v1, v2] paths: - '**.py' - '.pre-commit-config.yaml' - 'pyproject.toml' pull_request: - branches: [main, v2] + branches: [main, v1, v2] paths: - '**.py' - '.pre-commit-config.yaml' diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 6e204a8e67..457460db9d 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -16,9 +16,9 @@ name: Python Unit Tests on: push: - branches: [ main ] + branches: [ main, v1 ] pull_request: - branches: [ main ] + branches: [ main, v1 ] permissions: contents: read diff --git a/.github/workflows/release-v2-cherry-pick.yml b/.github/workflows/release-v1-cherry-pick.yml similarity index 75% rename from .github/workflows/release-v2-cherry-pick.yml rename to .github/workflows/release-v1-cherry-pick.yml index f5641a55b3..91e858590b 100644 --- a/.github/workflows/release-v2-cherry-pick.yml +++ b/.github/workflows/release-v1-cherry-pick.yml @@ -1,7 +1,7 @@ -# Step 3 (v2, optional): Cherry-picks a commit from v2 to the release/v2-candidate branch. +# Step 3 (v1, optional): Cherry-picks a commit from v1 to the release/v1-candidate branch. # Use between step 1 and step 4 to include bug fixes in an in-progress release. # Note: Does NOT auto-trigger release-please to preserve manual changelog edits. -name: "Release v2: Cherry-pick" +name: "Release v1: Cherry-pick" on: workflow_dispatch: @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v6 with: - ref: release/v2-candidate + ref: release/v1-candidate fetch-depth: 0 - name: Configure git @@ -30,14 +30,14 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/v2-candidate" + echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/v1-candidate" git cherry-pick ${INPUTS_COMMIT_SHA} env: INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} - name: Push changes run: | - git push origin release/v2-candidate - echo "Successfully cherry-picked commit to release/v2-candidate" + git push origin release/v1-candidate + echo "Successfully cherry-picked commit to release/v1-candidate" echo "Note: Release Please is NOT auto-triggered to preserve manual changelog edits." - echo "Run release-v2-please.yml manually if you want to regenerate the changelog." + echo "Run release-v1-please.yml manually if you want to regenerate the changelog." diff --git a/.github/workflows/release-v1-cut.yml b/.github/workflows/release-v1-cut.yml new file mode 100644 index 0000000000..7a35b6a120 --- /dev/null +++ b/.github/workflows/release-v1-cut.yml @@ -0,0 +1,46 @@ +# Step 1 (v1): Starts the v1 release process by creating a release/v1-candidate branch. +# Generates a changelog PR for review (step 2). +name: "Release v1: Cut" + +on: + workflow_dispatch: + inputs: + commit_sha: + description: 'Commit SHA to cut from (leave empty for latest v1)' + required: false + type: string + +permissions: + contents: write + actions: write + +jobs: + cut-release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + ref: ${{ inputs.commit_sha || 'v1' }} + + - name: Check for existing release/v1-candidate branch + env: + GH_TOKEN: ${{ github.token }} + run: | + if git ls-remote --exit-code --heads origin release/v1-candidate &>/dev/null; then + echo "Error: release/v1-candidate branch already exists" + echo "Please finalize or delete the existing release candidate before starting a new one" + exit 1 + fi + + - name: Create and push release/v1-candidate branch + run: | + git checkout -b release/v1-candidate + git push origin release/v1-candidate + echo "Created branch: release/v1-candidate" + + - name: Trigger Release Please + env: + GH_TOKEN: ${{ github.token }} + run: | + gh workflow run release-v1-please.yml --repo ${{ github.repository }} --ref release/v1-candidate + echo "Triggered Release Please workflow for v1" diff --git a/.github/workflows/release-v2-finalize.yml b/.github/workflows/release-v1-finalize.yml similarity index 71% rename from .github/workflows/release-v2-finalize.yml rename to .github/workflows/release-v1-finalize.yml index c8b1020948..58ce170563 100644 --- a/.github/workflows/release-v2-finalize.yml +++ b/.github/workflows/release-v1-finalize.yml @@ -1,12 +1,12 @@ -# Step 4 (v2): Triggers when the changelog PR is merged to release/v2-candidate. -# Records last-release-sha and renames release/v2-candidate to release/v{version}. -name: "Release v2: Finalize" +# Step 4 (v1): Triggers when the changelog PR is merged to release/v1-candidate. +# Records last-release-sha and renames release/v1-candidate to release/v{version}. +name: "Release v1: Finalize" on: pull_request: types: [closed] branches: - - release/v2-candidate + - release/v1-candidate permissions: contents: write @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v6 if: steps.check.outputs.is_release_pr == 'true' with: - ref: release/v2-candidate + ref: release/v1-candidate token: ${{ secrets.RELEASE_PAT }} fetch-depth: 0 @@ -40,7 +40,7 @@ jobs: if: steps.check.outputs.is_release_pr == 'true' id: version run: | - VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v2.json) + VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v1.json) echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Extracted version: $VERSION" @@ -56,21 +56,21 @@ jobs: - name: Record last-release-sha for release-please if: steps.check.outputs.is_release_pr == 'true' run: | - git fetch origin v2 - CUT_SHA=$(git merge-base origin/v2 HEAD) - echo "Release was cut from v2 at: $CUT_SHA" + git fetch origin v1 + CUT_SHA=$(git merge-base origin/v1 HEAD) + echo "Release was cut from v1 at: $CUT_SHA" jq --arg sha "$CUT_SHA" '. + {"last-release-sha": $sha}' \ - .github/release-please-config-v2.json > tmp.json && mv tmp.json .github/release-please-config-v2.json - git add .github/release-please-config-v2.json - git commit -m "chore: update last-release-sha for next v2 release" - git push origin release/v2-candidate + .github/release-please-config-v1.json > tmp.json && mv tmp.json .github/release-please-config-v1.json + git add .github/release-please-config-v1.json + git commit -m "chore: update last-release-sha for next v1 release" + git push origin release/v1-candidate - - name: Rename release/v2-candidate to release/v{version} + - name: Rename release/v1-candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" - git push origin "release/v2-candidate:refs/heads/release/$VERSION" ":release/v2-candidate" - echo "Renamed release/v2-candidate to release/$VERSION" + git push origin "release/v1-candidate:refs/heads/release/$VERSION" ":release/v1-candidate" + echo "Renamed release/v1-candidate to release/$VERSION" env: STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} @@ -79,6 +79,7 @@ jobs: env: GH_TOKEN: ${{ github.token }} run: | + gh label create "autorelease: tagged" --color "EDEDED" --description "Tagged release" || true gh pr edit ${{ github.event.pull_request.number }} \ --remove-label "autorelease: pending" \ --add-label "autorelease: tagged" diff --git a/.github/workflows/release-v2-please.yml b/.github/workflows/release-v1-please.yml similarity index 60% rename from .github/workflows/release-v2-please.yml rename to .github/workflows/release-v1-please.yml index d659ca1c4e..9f1344dd31 100644 --- a/.github/workflows/release-v2-please.yml +++ b/.github/workflows/release-v1-please.yml @@ -1,10 +1,10 @@ -# Runs release-please to create/update a PR with version bump and changelog for v2. -# Triggered only by workflow_dispatch (from release-v2-cut.yml). +# Runs release-please to create/update a PR with version bump and changelog for v1. +# Triggered only by workflow_dispatch (from release-v1-cut.yml). # Does NOT auto-run on push to preserve manual changelog edits after cherry-picks. -name: "Release v2: Please" +name: "Release v1: Please" on: - # Only run via workflow_dispatch (triggered by release-v2-cut.yml) + # Only run via workflow_dispatch (triggered by release-v1-cut.yml) workflow_dispatch: permissions: @@ -15,27 +15,27 @@ jobs: release-please: runs-on: ubuntu-latest steps: - - name: Check if release/v2-candidate still exists + - name: Check if release/v1-candidate still exists id: check env: GH_TOKEN: ${{ github.token }} run: | - if gh api repos/${{ github.repository }}/branches/release/v2-candidate --silent 2>/dev/null; then + if gh api repos/${{ github.repository }}/branches/release/v1-candidate --silent 2>/dev/null; then echo "exists=true" >> $GITHUB_OUTPUT else - echo "release/v2-candidate branch no longer exists, skipping" + echo "release/v1-candidate branch no longer exists, skipping" echo "exists=false" >> $GITHUB_OUTPUT fi - uses: actions/checkout@v6 if: steps.check.outputs.exists == 'true' with: - ref: release/v2-candidate + ref: release/v1-candidate - uses: googleapis/release-please-action@v4 if: steps.check.outputs.exists == 'true' with: token: ${{ secrets.RELEASE_PAT }} - config-file: .github/release-please-config-v2.json - manifest-file: .github/.release-please-manifest-v2.json - target-branch: release/v2-candidate + config-file: .github/release-please-config-v1.json + manifest-file: .github/.release-please-manifest-v1.json + target-branch: release/v1-candidate diff --git a/.github/workflows/release-v2-publish.yml b/.github/workflows/release-v1-publish.yml similarity index 81% rename from .github/workflows/release-v2-publish.yml rename to .github/workflows/release-v1-publish.yml index 41edc78d9e..a4f3b1419f 100644 --- a/.github/workflows/release-v2-publish.yml +++ b/.github/workflows/release-v1-publish.yml @@ -1,8 +1,8 @@ -# Step 6 (v2): Builds and publishes the v2 package to PyPI from a release/v{version} branch. -# Reads version from .release-please-manifest-v2.json, converts to PEP 440, +# Step 6 (v1): Builds and publishes the v1 package to PyPI from a release/v{version} branch. +# Reads version from .release-please-manifest-v1.json, converts to PEP 440, # updates version.py, then builds and publishes. -# Creates a merge-back PR (step 7) to sync release changes to v2. -name: "Release v2: Publish to PyPi" +# Creates a merge-back PR (step 7) to sync release changes to v1. +name: "Release v1: Publish to PyPi" on: workflow_dispatch: @@ -18,7 +18,7 @@ jobs: - name: Validate branch run: | if [[ ! "${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+ ]]; then - echo "Error: Must run from a release/v* branch (e.g., release/v2.0.0-alpha.2)" + echo "Error: Must run from a release/v* branch (e.g., release/v1.34.1)" exit 1 fi @@ -27,15 +27,15 @@ jobs: - name: Extract version from manifest and convert to PEP 440 id: version run: | - VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v2.json) + VERSION=$(jq -r '.["."]' .github/.release-please-manifest-v1.json) echo "semver=$VERSION" >> $GITHUB_OUTPUT echo "Semver version: $VERSION" # Convert semver pre-release to PEP 440: - # 2.0.0-alpha.1 -> 2.0.0a1 - # 2.0.0-beta.1 -> 2.0.0b1 - # 2.0.0-rc.1 -> 2.0.0rc1 - # 2.0.0 -> 2.0.0 (no change for stable) + # 1.35.0-alpha.1 -> 1.35.0a1 + # 1.35.0-beta.1 -> 1.35.0b1 + # 1.35.0-rc.1 -> 1.35.0rc1 + # 1.35.0 -> 1.35.0 (no change for stable) PEP440=$(echo "$VERSION" | sed -E 's/-alpha\./a/; s/-beta\./b/; s/-rc\./rc/') echo "pep440=$PEP440" >> $GITHUB_OUTPUT echo "PEP 440 version: $PEP440" @@ -73,7 +73,7 @@ jobs: PEP440_VERSION: ${{ steps.version.outputs.pep440 }} run: | gh pr create \ - --base v2 \ + --base v1 \ --head "${GITHUB_REF_NAME}" \ - --title "chore: merge release v${PEP440_VERSION} to v2" \ - --body "Syncs version bump and CHANGELOG from release v${SEMVER_VERSION} to v2." + --title "chore: merge release v${PEP440_VERSION} to v1" \ + --body "Syncs version bump and CHANGELOG from release v${SEMVER_VERSION} to v1." diff --git a/.github/workflows/release-v2-cut.yml b/.github/workflows/release-v2-cut.yml deleted file mode 100644 index 52af5bf038..0000000000 --- a/.github/workflows/release-v2-cut.yml +++ /dev/null @@ -1,46 +0,0 @@ -# Step 1 (v2): Starts the v2 release process by creating a release/v2-candidate branch. -# Generates a changelog PR for review (step 2). -name: "Release v2: Cut" - -on: - workflow_dispatch: - inputs: - commit_sha: - description: 'Commit SHA to cut from (leave empty for latest v2)' - required: false - type: string - -permissions: - contents: write - actions: write - -jobs: - cut-release: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - with: - ref: ${{ inputs.commit_sha || 'v2' }} - - - name: Check for existing release/v2-candidate branch - env: - GH_TOKEN: ${{ github.token }} - run: | - if git ls-remote --exit-code --heads origin release/v2-candidate &>/dev/null; then - echo "Error: release/v2-candidate branch already exists" - echo "Please finalize or delete the existing release candidate before starting a new one" - exit 1 - fi - - - name: Create and push release/v2-candidate branch - run: | - git checkout -b release/v2-candidate - git push origin release/v2-candidate - echo "Created branch: release/v2-candidate" - - - name: Trigger Release Please - env: - GH_TOKEN: ${{ github.token }} - run: | - gh workflow run release-v2-please.yml --repo ${{ github.repository }} --ref release/v2-candidate - echo "Triggered Release Please workflow for v2" diff --git a/.github/workflows/v2-sync.yml b/.github/workflows/v2-sync.yml deleted file mode 100644 index c627f40d46..0000000000 --- a/.github/workflows/v2-sync.yml +++ /dev/null @@ -1,59 +0,0 @@ -# Automatically creates a PR to merge main (v1) into v2 to keep v2 up to date. -# The oncall is responsible for reviewing and merging the sync PR. -name: "Sync: main -> v2" - -on: - schedule: - - cron: '0 6 * * *' # Daily at 6am UTC - workflow_dispatch: - -permissions: - contents: write - pull-requests: write - -jobs: - sync: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - with: - ref: v2 - fetch-depth: 0 - token: ${{ secrets.RELEASE_PAT }} - - - name: Check for new commits on main - id: check - run: | - git fetch origin main - BEHIND=$(git rev-list --count HEAD..origin/main) - echo "behind=$BEHIND" >> $GITHUB_OUTPUT - if [ "$BEHIND" -eq 0 ]; then - echo "v2 is up to date with main, nothing to sync" - else - echo "v2 is $BEHIND commit(s) behind main" - fi - - - name: Check for existing sync PR - if: steps.check.outputs.behind != '0' - id: existing - env: - GH_TOKEN: ${{ github.token }} - run: | - PR=$(gh pr list --base v2 --head main --state open --json number --jq '.[0].number // empty') - if [ -n "$PR" ]; then - echo "Sync PR #$PR already exists, skipping" - echo "exists=true" >> $GITHUB_OUTPUT - else - echo "exists=false" >> $GITHUB_OUTPUT - fi - - - name: Create sync PR - if: steps.check.outputs.behind != '0' && steps.existing.outputs.exists == 'false' - env: - GH_TOKEN: ${{ secrets.RELEASE_PAT }} - run: | - gh pr create \ - --base v2 \ - --head main \ - --title "chore: sync main -> v2" \ - --body "Automated sync of v1 changes from main into v2. The oncall is responsible for reviewing and merging this PR. Resolve conflicts in favor of the v2 implementation." diff --git a/CHANGELOG.md b/CHANGELOG.md index 08799f7a13..224327bcfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,38 @@ # Changelog +## [1.35.0](https://github.com/google/adk-python/compare/v1.34.2...v1.35.0) (2026-06-10) + + +### Features + +* **live:** Handle input transcription differently for Gemini Live 3.1 models ([#6045](https://github.com/google/adk-python/issues/6045)) ([ecfdaf5](https://github.com/google/adk-python/commit/ecfdaf5f5e7accfbb4294cb8cc56c910dad2b1a8)) + + +### Bug Fixes + +* add missing Gemini imports in base_llm_flow ([#5943](https://github.com/google/adk-python/issues/5943)) ([6d027b4](https://github.com/google/adk-python/commit/6d027b4ce8bc1c5d288b02e1e3819917117038ec)) +* **flows:** Reset reconnect attempts on connection success ([#6042](https://github.com/google/adk-python/issues/6042)) ([87abf23](https://github.com/google/adk-python/commit/87abf230dbc21b49fa5606e18627c7f62df0d37b)) +* **models:** Default grounding metadata for Gemini 3.1 live ([#6018](https://github.com/google/adk-python/issues/6018)) ([fafafb3](https://github.com/google/adk-python/commit/fafafb38e1027a5cfe185357f6b8a107bbd3779e)) +* Support generalized history config injection for Gemini 3.1 Live on Vertex AI ([#5999](https://github.com/google/adk-python/issues/5999)) ([aafd97f](https://github.com/google/adk-python/commit/aafd97f6f0ae114b0ca772b4f5176602e3677e79)) + +## [1.34.2](https://github.com/google/adk-python/compare/v1.34.1...v1.34.2) (2026-06-01) + + +### Bug Fixes + +* Fix bug where grounding metadata in Gemini 3.1 live was being silently discarded ([9b6b9e9](https://github.com/google/adk-python/commit/9b6b9e976300ab223c77075554d6cd66ce1179ff)) +* fix input and output transcription finished events for Gemini v3.1 ([13763d7](https://github.com/google/adk-python/commit/13763d71f883b215dae08feb3f042869b9cd5d18)) +* **tools:** Prevent session drop on MCP tool error ([1fd406b](https://github.com/google/adk-python/commit/1fd406b90ae00c59d84093c33bc04530825bc760)) + +## [1.34.1](https://github.com/google/adk-python/compare/v1.34.0...v1.34.1) (2026-05-22) + + +### Bug Fixes + +* Fix bug where grounding metadata in Gemini 3.1 live was being silently discarded ([9b6b9e9](https://github.com/google/adk-python/commit/9b6b9e976300ab223c77075554d6cd66ce1179ff)) +* fix input and output transcription finished events for Gemini v3.1 ([13763d7](https://github.com/google/adk-python/commit/13763d71f883b215dae08feb3f042869b9cd5d18)) +* **tools:** Prevent session drop on MCP tool error ([1fd406b](https://github.com/google/adk-python/commit/1fd406b90ae00c59d84093c33bc04530825bc760)) + ## [1.34.0](https://github.com/google/adk-python/compare/v1.33.0...v1.34.0) (2026-05-18) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index e059cd957d..8126ac5bf3 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -247,6 +247,9 @@ class RunConfig(BaseModel): session_resumption: Optional[types.SessionResumptionConfig] = None """Configures session resumption mechanism. Only support transparent session resumption mode now.""" + history_config: Optional[types.HistoryConfig] = None + """Configures the exchange of history between the client and the server.""" + context_window_compression: Optional[types.ContextWindowCompressionConfig] = ( None ) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index b5f51f2825..5f67e16607 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -138,7 +138,7 @@ class FeatureConfig: FeatureStage.WIP, default_on=False ), FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig( - FeatureStage.EXPERIMENTAL, default_on=False + FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db897637c3..39a66244ca 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -39,6 +39,8 @@ from ...auth.auth_tool import AuthConfig from ...events.event import Event from ...models.base_llm_connection import BaseLlmConnection +from ...models.google_llm import Gemini +from ...models.google_llm import GoogleLLMVariant from ...models.llm_request import LlmRequest from ...models.llm_response import LlmResponse from ...telemetry import tracing @@ -47,6 +49,7 @@ from ...telemetry.tracing import tracer from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext +from ...utils import model_name_utils from ...utils.context_utils import Aclosing from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event @@ -514,13 +517,47 @@ async def run_live( llm_request.live_connect_config.session_resumption.handle = ( invocation_context.live_session_resumption_handle ) - llm_request.live_connect_config.session_resumption.transparent = True + # Only set transparent=True for Vertex AI backend, as the Gemini API + # backend explicitly rejects it. + if ( + isinstance(llm, Gemini) + and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access + ): + session_resumption = ( + llm_request.live_connect_config.session_resumption + ) + if session_resumption.transparent is None: + session_resumption.transparent = True + + # When seeding a fresh connection with prior conversation history, set + # initial_history_in_client_content to True. This tells the Live server + # that the provided history already includes the model's past responses, + # preventing the server from generating duplicate responses for those replayed turns. + if ( + llm_request.contents + and not invocation_context.live_session_resumption_handle + ): + if not llm_request.live_connect_config: + llm_request.live_connect_config = types.LiveConnectConfig() + if not llm_request.live_connect_config.history_config: + llm_request.live_connect_config.history_config = ( + types.HistoryConfig() + ) + if ( + llm_request.live_connect_config.history_config.initial_history_in_client_content + is None + ): + llm_request.live_connect_config.history_config.initial_history_in_client_content = ( + True + ) logger.info( 'Establishing live connection for agent: %s', invocation_context.agent.name, ) async with llm.connect(llm_request) as llm_connection: + # Reset attempt counter on successful connection. + attempt = 1 # Skip sending history if we are resuming a session. The server # already has the state associated with the resumption handle. if ( @@ -550,8 +587,6 @@ async def run_live( ) ) as agen: async for event in agen: - # Reset attempt counter on successful communication. - attempt = 1 # Empty event means the queue is closed. if not event: break @@ -976,6 +1011,7 @@ async def _postprocess_async( not llm_response.content and not llm_response.error_code and not llm_response.interrupted + and not llm_response.grounding_metadata ): return @@ -1040,6 +1076,7 @@ async def _postprocess_live( and not llm_response.output_transcription and not llm_response.usage_metadata and not llm_response.live_session_resumption_update + and not llm_response.grounding_metadata ): return diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 8e9bfa514c..50f03d0bf1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -25,6 +25,7 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest +from ...utils import model_name_utils from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor @@ -78,15 +79,25 @@ def _build_basic_request( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) + active_model_name = ( + getattr(getattr(agent, 'canonical_live_model', None), 'model', None) + or llm_request.model + ) + is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name) llm_request.live_connect_config.enable_affective_dialog = ( - invocation_context.run_config.enable_affective_dialog + None + if is_gemini_31 + else invocation_context.run_config.enable_affective_dialog ) llm_request.live_connect_config.proactivity = ( - invocation_context.run_config.proactivity + None if is_gemini_31 else invocation_context.run_config.proactivity ) llm_request.live_connect_config.session_resumption = ( invocation_context.run_config.session_resumption ) + llm_request.live_connect_config.history_config = ( + invocation_context.run_config.history_config + ) llm_request.live_connect_config.context_window_compression = ( invocation_context.run_config.context_window_compression ) diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 11ed8386e1..ff2f81e657 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -50,6 +50,9 @@ def __init__( self._output_transcription_text: str = '' self._api_backend = api_backend self._model_version = model_version + self._is_gemini_3_1_flash_live = model_name_utils.is_gemini_3_1_flash_live( + model_version + ) async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -80,10 +83,30 @@ async def send_history(self, history: list[types.Content]): ] if contents: + # Gemini Enterprise Agent Platform does not support history_config in the SDK. + # To initialize a live session with prior history without hitting a 1007 + # protocol error (invalid role mid-session), we consolidate previous multi-turn + # interactions into a unified contextual preamble on a single user role turn. + if ( + self._is_gemini_3_1_flash_live + and self._api_backend != GoogleLLMVariant.GEMINI_API + ): + collapsed_text = 'Previous conversation history:\n' + for c in contents: + text_parts = ''.join(p.text for p in c.parts if p.text) + collapsed_text += f'[{c.role}]: {text_parts}\n' + contents = [ + types.Content( + role='user', parts=[types.Part.from_text(text=collapsed_text)] + ) + ] + logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send_client_content( turns=contents, - turn_complete=contents[-1].role == 'user', + turn_complete=True + if self._is_gemini_3_1_flash_live + else (contents[-1].role == 'user'), ) else: logger.info('no content is sent') @@ -108,10 +131,11 @@ async def send_content(self, content: types.Content): ) else: logger.debug('Sending LLM new content %s', content) - is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live( - self._model_version - ) - if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text: + if ( + self._is_gemini_3_1_flash_live + and len(content.parts) == 1 + and content.parts[0].text + ): logger.debug('Using send_realtime_input for Gemini 3.1 text input') await self._gemini_session.send_realtime_input( text=content.parts[0].text @@ -133,10 +157,7 @@ async def send_realtime(self, input: RealtimeInput): if isinstance(input, types.Blob): # The blob is binary and is very large. So let's not log it. logger.debug('Sending LLM Blob.') - is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live( - self._model_version - ) - if is_gemini_31: + if self._is_gemini_3_1_flash_live: if input.mime_type and input.mime_type.startswith('audio/'): await self._gemini_session.send_realtime_input(audio=input) elif input.mime_type and input.mime_type.startswith('image/'): @@ -159,7 +180,12 @@ async def send_realtime(self, input: RealtimeInput): else: raise ValueError('Unsupported input type: %s' % type(input)) - def __build_full_text_response(self, text: str): + def __build_full_text_response( + self, + text: str, + is_thought: bool = False, + grounding_metadata: types.GroundingMetadata | None = None, + ): """Builds a full text response. The text should not be partial and the returned LlmResponse is not @@ -167,15 +193,24 @@ def __build_full_text_response(self, text: str): Args: text: The text to be included in the response. + is_thought: Whether the text is a thought. + grounding_metadata: The grounding metadata to include. Returns: An LlmResponse containing the full text. """ + part = types.Part.from_text(text=text) + if is_thought: + part.thought = True + if grounding_metadata is None and self._is_gemini_3_1_flash_live: + grounding_metadata = types.GroundingMetadata() return LlmResponse( content=types.Content( role='model', - parts=[types.Part.from_text(text=text)], + parts=[part], ), + grounding_metadata=grounding_metadata, + partial=False, live_session_id=self._gemini_session.session_id, ) @@ -187,7 +222,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: """ text = '' + is_thought = False tool_call_parts = [] + pending_grounding_metadata = None async with Aclosing(self._gemini_session.receive()) as agen: # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate # partial content and emit responses as needed. @@ -203,6 +240,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) if message.server_content: content = message.server_content.model_turn + if message.server_content.grounding_metadata: + pending_grounding_metadata = ( + message.server_content.grounding_metadata + ) # Standalone grounding_metadata event (when content is empty) if ( @@ -215,6 +256,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) if content and content.parts: @@ -223,51 +267,83 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) # grounding_metadata is yielded again at turn_complete, # so avoid duplicating it here if turn_complete is true. if not message.server_content.turn_complete: - llm_response.grounding_metadata = ( - message.server_content.grounding_metadata - ) - if content.parts[0].text: - text += content.parts[0].text - llm_response.partial = True - # don't yield the merged text event when receiving audio data - elif text and not content.parts[0].inline_data: - yield self.__build_full_text_response(text) + if message.server_content.grounding_metadata is not None: + llm_response.grounding_metadata = ( + message.server_content.grounding_metadata + ) + elif self._is_gemini_3_1_flash_live: + llm_response.grounding_metadata = types.GroundingMetadata() + has_inline_data = any(p.inline_data for p in content.parts) + for part in content.parts: + if part.text: + current_is_thought = getattr(part, 'thought', False) + if text and current_is_thought != is_thought: + yield self.__build_full_text_response(text, is_thought) + text = '' + is_thought = False + + text += part.text + is_thought = current_is_thought + llm_response.partial = True + if ( + text + and not any(p.text for p in content.parts) + and not has_inline_data + ): + yield self.__build_full_text_response(text, is_thought) text = '' yield llm_response # Note: in some cases, tool_call may arrive before # generation_complete, causing transcription to appear after # tool_call in the session log. if message.server_content.input_transcription: - if message.server_content.input_transcription.text: - self._input_transcription_text += ( - message.server_content.input_transcription.text - ) - yield LlmResponse( - input_transcription=types.Transcription( - text=message.server_content.input_transcription.text, - finished=False, - ), - partial=True, - model_version=self._model_version, - live_session_id=live_session_id, - ) - # finished=True and partial transcription may happen in the same - # message. - if message.server_content.input_transcription.finished: - yield LlmResponse( - input_transcription=types.Transcription( - text=self._input_transcription_text, - finished=True, - ), - partial=False, - model_version=self._model_version, - live_session_id=live_session_id, - ) - self._input_transcription_text = '' + # Gemini 3.1 Flash Live only sends a single final input + # transcription + if self._is_gemini_3_1_flash_live: + if message.server_content.input_transcription.text: + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + else: + if message.server_content.input_transcription.text: + self._input_transcription_text += ( + message.server_content.input_transcription.text + ) + yield LlmResponse( + input_transcription=types.Transcription( + text=message.server_content.input_transcription.text, + finished=False, + ), + partial=True, + model_version=self._model_version, + live_session_id=live_session_id, + ) + # finished=True and partial transcription may happen in the same + # message. + if message.server_content.input_transcription.finished: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + model_version=self._model_version, + live_session_id=live_session_id, + ) + self._input_transcription_text = '' if message.server_content.output_transcription: if message.server_content.output_transcription.text: self._output_transcription_text += ( @@ -293,10 +369,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: live_session_id=live_session_id, ) self._output_transcription_text = '' - # The Gemini API might not send a transcription finished signal. + # The Gemini API or Vertex AI might not send a transcription finished signal. # Instead, we rely on generation_complete, turn_complete or # interrupted signals to flush any pending transcriptions. - if self._api_backend == GoogleLLMVariant.GEMINI_API and ( + if ( message.server_content.interrupted or message.server_content.turn_complete or message.server_content.generation_complete @@ -324,9 +400,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) self._output_transcription_text = '' if message.server_content.turn_complete: + g_metadata_to_yield = pending_grounding_metadata if text: - yield self.__build_full_text_response(text) + yield self.__build_full_text_response( + text, is_thought, g_metadata_to_yield + ) text = '' + is_thought = False + g_metadata_to_yield = None if tool_call_parts: logger.debug('Returning aggregated tool_call_parts') yield LlmResponse( @@ -338,9 +419,18 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield LlmResponse( turn_complete=True, interrupted=message.server_content.interrupted, - grounding_metadata=message.server_content.grounding_metadata, + grounding_metadata=message.server_content.grounding_metadata + or g_metadata_to_yield + or ( + types.GroundingMetadata() + if self._is_gemini_3_1_flash_live + else None + ), model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) break # in case of empty content or parts, we still surface it @@ -371,10 +461,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: # deadlocking the conversation. Other models (e.g. 2.5-pro, # native-audio) send turn_complete after tool calls, so buffer # and merge them into a single response at turn_complete. - if ( - model_name_utils.is_gemini_3_1_flash_live(self._model_version) - and tool_call_parts - ): + if self._is_gemini_3_1_flash_live and tool_call_parts: logger.debug( 'Yielding tool_call_parts immediately for Gemini 3.1 live tool' ' call' @@ -383,6 +470,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: content=types.Content(role='model', parts=tool_call_parts), model_version=self._model_version, live_session_id=live_session_id, + grounding_metadata=types.GroundingMetadata(), ) tool_call_parts = [] if message.session_resumption_update: diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index c921f197c3..333034565f 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -81,6 +81,12 @@ class LlmResponse(BaseModel): Only used for streaming mode. """ + turn_complete_reason: Optional[types.TurnCompleteReason] = None + """The reason why the turn is complete. + + Only used for streaming mode. + """ + finish_reason: Optional[types.FinishReason] = None """The finish reason of the response.""" diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index 0ad63044d4..5249423cd1 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -130,6 +130,12 @@ async def start(self) -> ClientSession: if not self._task: self._task = asyncio.create_task(self._run()) + def _retrieve_exception(t: asyncio.Task): + if not t.cancelled(): + t.exception() + + self._task.add_done_callback(_retrieve_exception) + await self._ready_event.wait() if self._task.cancelled(): diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index b2e032e0d1..dbb3a08193 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -172,4 +172,5 @@ def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool: """ if not model_string: return False - return model_string.startswith('gemini-3.1-flash-live') + model_name = extract_model_name(model_string) + return model_name.startswith('gemini-3.1-flash-live') diff --git a/src/google/adk/version.py b/src/google/adk/version.py index cf2713c03f..209fa91742 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.34.0" +__version__ = "1.35.0" diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index ce2e83b6f7..b5c3f1a612 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -17,12 +17,14 @@ from unittest import mock from unittest.mock import AsyncMock +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import GoogleLLMVariant from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin @@ -30,6 +32,7 @@ from google.adk.tools.google_search_tool import GoogleSearchTool from google.genai import types import pytest +from websockets.exceptions import ConnectionClosed from ... import testing_utils @@ -490,8 +493,6 @@ async def call(self, **kwargs): @pytest.mark.asyncio async def test_run_live_reconnects_on_connection_closed(): """Test that run_live reconnects when ConnectionClosed occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -558,7 +559,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_reconnects_on_api_error(): """Test that run_live reconnects when APIError occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue from google.genai.errors import APIError real_model = Gemini() @@ -626,7 +626,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_skips_send_history_on_resumption(): """Test that run_live skips send_history when resuming a session.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -684,7 +683,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_live_session_resumption_go_away(): """Test that go_away triggers reconnection.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -743,8 +741,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_no_reconnect_without_handle(): """Test that run_live does not reconnect when handle is missing.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -786,8 +782,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_limit(): """Test that run_live stops reconnecting after 5 attempts.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -796,17 +790,17 @@ async def test_run_live_reconnect_limit(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 + if connection_cnt > 1: + raise ConnectionClosed(None, None) conn = mock.AsyncMock() async def mock_receive(): - if connection_cnt == 1: - # Yield handle only on the first connection. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) # All subsequent receives (and all receives on later connections) fail. raise ConnectionClosed(None, None) @@ -842,10 +836,8 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_reset_attempt(): - """Test that attempt counter is reset on successful communication.""" - from google.adk.agents.live_request_queue import LiveRequestQueue + """Test that attempt counter is reset on successful connection establishment.""" from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -854,23 +846,29 @@ async def test_run_live_reconnect_reset_attempt(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 - conn = mock.AsyncMock() + # Establish connection successfully on attempts 1, 2, and 5 + if connection_cnt in (1, 2, 5): + conn = mock.AsyncMock() - async def mock_receive(): - if connection_cnt <= 2: - # Yield handle on the first two connections. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) - # All subsequent receives fail. + async def mock_receive(): + if connection_cnt == 1: + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) + else: + if False: + yield + raise ConnectionClosed(None, None) + + conn.receive = mock.Mock(side_effect=mock_receive) + return conn + else: + # Failed connection establishments on other attempts raise ConnectionClosed(None, None) - conn.receive = mock.Mock(side_effect=mock_receive) - return conn - agent = Agent(name='test_agent', model=real_model) invocation_context = await testing_utils.create_invocation_context( agent=agent @@ -891,9 +889,13 @@ async def mock_receive(): async for _ in flow.run_live(invocation_context): pass - # We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts - # Total calls = 2 + 5 = 7 - assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2 + # Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed. + # Connection 2: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 3: fails (attempt becomes 2) + # Connection 4: fails (attempt becomes 3) + # Connection 5: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates. + assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5 @pytest.mark.asyncio @@ -987,7 +989,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_clears_resumption_handle_on_transfer(): """Test that run_live clears session resumption handles when transferring to another agent.""" - from google.adk.agents.live_request_queue import LiveRequestQueue agent = Agent(name='test_agent') invocation_context = await testing_utils.create_invocation_context( @@ -1069,3 +1070,345 @@ async def mock_run_live_sub_agent(child_ctx, *args, **kwargs): assert ( invocation_context.run_config.session_resumption.handle == 'test_handle' ) + + +@pytest.mark.asyncio +async def test_postprocess_live_yields_grounding_metadata_only(): + """Test that _postprocess_live yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_live( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata + + +@pytest.mark.asyncio +async def test_postprocess_async_yields_grounding_metadata_only(): + """Test that _postprocess_async yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_async( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata + + +@pytest.mark.asyncio +async def test_run_live_reconnect_does_not_set_transparent(): + """Test that run_live reconnect does not set transparent=True.""" + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.run_config = RunConfig() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.live_connect_config.session_resumption = ( + ctx.run_config.session_resumption + ) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + mock_connection_2 = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive_2(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch.object( + Gemini, '_api_backend', new_callable=mock.PropertyMock + ) as mock_backend: + mock_backend.return_value = GoogleLLMVariant.GEMINI_API + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 2 + second_call_req = mock_connect.call_args_list[1][0][0] + session_resump = ( + second_call_req.live_connect_config.session_resumption + ) + assert session_resump.transparent is None + + +@pytest.mark.asyncio +async def test_run_live_reconnect_sets_transparent_for_vertex(): + """Test that run_live reconnect sets transparent=True for vertex backend.""" + + real_model = Gemini( + model='projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp' + ) + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.run_config = RunConfig() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.live_connect_config.session_resumption = ( + ctx.run_config.session_resumption + ) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + mock_connection_2 = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive_2(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 2 + second_call_req = mock_connect.call_args_list[1][0][0] + session_resump = second_call_req.live_connect_config.session_resumption + assert session_resump.transparent + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'api_backend', + [ + GoogleLLMVariant.GEMINI_API, + GoogleLLMVariant.VERTEX_AI, + ], +) +async def test_run_live_history_config_set_for_all_backends(api_backend): + """Test that run_live sets history_config for all backends.""" + + real_model = Gemini(model='gemini-3.1-flash-live-preview') + mock_connection = mock.AsyncMock() + + class StopTestError(Exception): + pass + + async def mock_receive(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + async def mock_preprocess(ctx, req): + req.model = 'gemini-3.1-flash-live-preview' + req.contents = [ + types.Content(parts=[types.Part.from_text(text='history')]) + ] + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object( + Gemini, '_api_backend', new_callable=mock.PropertyMock + ) as mock_backend: + mock_backend.return_value = api_backend + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 1 + called_req = mock_connect.call_args[0][0] + assert called_req.live_connect_config is not None + assert called_req.live_connect_config.history_config is not None + assert ( + called_req.live_connect_config.history_config.initial_history_in_client_content + is True + ) + + +@pytest.mark.asyncio +async def test_run_live_respects_explicit_initial_history_in_client_content_false(): + """Test that run_live respects explicit initial_history_in_client_content=False in RunConfig.""" + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + run_config = RunConfig( + history_config=types.HistoryConfig( + initial_history_in_client_content=False + ) + ) + invocation_context.run_config = run_config + + flow = BaseLlmFlowForTesting() + + async def mock_preprocess(ctx, req): + req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] + from google.adk.flows.llm_flows.basic import _build_basic_request + + _build_basic_request(ctx, req) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + class StopTestError(Exception): + pass + + async def mock_receive(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 1 + call_req = mock_connect.call_args[0][0] + assert call_req.live_connect_config.history_config is not None + assert ( + call_req.live_connect_config.history_config.initial_history_in_client_content + is False + ) diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 133a455738..62548dac30 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -285,11 +285,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_interrupt( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -345,7 +351,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].interrupted is True @@ -365,11 +371,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_generation_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on generation_complete signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -425,7 +437,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 4 @@ -444,11 +456,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_turn_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt or complete signals.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -504,7 +522,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].turn_complete is True @@ -867,6 +885,7 @@ async def test_receive_grounding_metadata_standalone( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -911,6 +930,7 @@ async def test_receive_grounding_metadata_with_content( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -981,6 +1001,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_metadata_msg = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1001,6 +1022,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_turn_complete_content.interrupted = False mock_turn_complete_content.input_transcription = None mock_turn_complete_content.output_transcription = None + mock_turn_complete_content.generation_complete = False mock_turn_complete_msg = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1240,3 +1262,483 @@ async def mock_receive_generator(): content_response = next((r for r in responses if r.content), None) assert content_response is not None assert content_response.content == mock_content + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_pending( + gemini_connection, mock_gemini_session +): + """Test that grounding metadata in partial chunks is pending and yielded on full text.""" + grounding_metadata = types.GroundingMetadata( + web_search_queries=['stock price of google'], + ) + + def make_msg(text=None, g_meta=None, tc=False): + msg = mock.Mock( + usage_metadata=None, + tool_call=None, + session_resumption_update=None, + go_away=None, + ) + msg.server_content = mock.Mock( + interrupted=False, + input_transcription=None, + output_transcription=None, + generation_complete=False, + turn_complete=tc, + grounding_metadata=g_meta, + model_turn=types.Content( + role='model', parts=[types.Part.from_text(text=text)] + ) + if text + else None, + ) + return msg + + msg1 = make_msg(text='hello', g_meta=grounding_metadata) + msg2 = make_msg(text=' world') + msg3 = make_msg(tc=True) + + async def gen(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=gen()) + + responses = [resp async for resp in gemini_connection.receive()] + + # Expected responses: + # 1. Msg 1 partial (hello) with grounding_metadata + # 2. Msg 2 partial ( world) without grounding_metadata + # 3. Full text response (hello world) with PENDING grounding_metadata + # 4. Turn complete response without grounding_metadata (already cleared) + assert len(responses) == 4 + + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].partial is True + assert responses[0].grounding_metadata == grounding_metadata + + assert responses[1].content.parts[0].text == ' world' + assert responses[1].partial is True + assert responses[1].grounding_metadata is None + + assert responses[2].content.parts[0].text == 'hello world' + assert responses[2].partial is False + assert responses[2].grounding_metadata == grounding_metadata + + assert responses[3].turn_complete is True + assert responses[3].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = True + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].turn_complete is True + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_standalone_grounding( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = mock.create_autospec( + types.GroundingMetadata, instance=True + ) + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].grounding_metadata is not None + assert responses[0].turn_complete is None + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_with_content( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts.""" + mock_content = types.Content( + role='model', + parts=[types.Part.from_text(text='hello')], + ) + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = mock_content + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].content == mock_content + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_multiplexed_parts( + gemini_connection, mock_gemini_session +): + """Test receive with multiplexed inline data and text content.""" + mock_content = types.Content( + role='model', + parts=[ + types.Part( + inline_data=types.Blob(data=b'audio_data', mime_type='audio/pcm') + ), + types.Part.from_text(text='transcription text'), + ], + ) + mock_server_content = mock.Mock() + mock_server_content.model_turn = mock_content + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.turn_complete = False + mock_server_content.grounding_metadata = None + + mock_message = mock.AsyncMock() + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + assert responses + content_response = next((r for r in responses if r.content), None) + assert content_response is not None + assert content_response.content == mock_content + assert content_response.partial is True + + +@pytest.mark.asyncio +async def test_send_history_gemini_31_turn_complete(mock_gemini_session): + """Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True.""" + from google.adk.models.google_llm import GoogleLLMVariant + + conn = GeminiLlmConnection( + mock_gemini_session, + api_backend=GoogleLLMVariant.GEMINI_API, + model_version='gemini-3.1-flash-live-preview', + ) + mock_gemini_session.send_client_content = mock.AsyncMock() + + mock_contents = [ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]), + types.Content(role='model', parts=[types.Part.from_text(text='hello')]), + ] + await conn.send_history(mock_contents) + + mock_gemini_session.send_client_content.assert_called_once_with( + turns=mock_contents, + turn_complete=True, + ) + + +@pytest.mark.asyncio +async def test_send_history_collapse_vertex_ai(mock_gemini_session): + """Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend.""" + from google.adk.models.google_llm import GoogleLLMVariant + + conn = GeminiLlmConnection( + mock_gemini_session, + api_backend=GoogleLLMVariant.VERTEX_AI, + model_version='gemini-3.1-flash-live-preview', + ) + mock_gemini_session.send_client_content = mock.AsyncMock() + + mock_contents = [ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]), + types.Content(role='model', parts=[types.Part.from_text(text='hello')]), + ] + await conn.send_history(mock_contents) + + assert mock_gemini_session.send_client_content.call_count == 1 + called_turns = mock_gemini_session.send_client_content.call_args.kwargs[ + 'turns' + ] + assert len(called_turns) == 1 + assert called_turns[0].role == 'user' + assert 'Previous conversation history:' in called_turns[0].parts[0].text + assert '[user]: hi' in called_turns[0].parts[0].text + assert '[model]: hello' in called_turns[0].parts[0].text + assert ( + mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] + is True + ) + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_default_gemini_3_1( + mock_gemini_session, +): + """Verify grounding_metadata defaults to empty GroundingMetadata for Gemini 3.1.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-3.1-flash-live-preview', + ) + + def make_msg(text=None, tc=False, tool_call=None): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = tool_call + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = None + msg.server_content.output_transcription = None + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = ( + types.Content(role='model', parts=[types.Part.from_text(text=text)]) + if text + else None + ) + return msg + + # 1. Content event + msg1 = make_msg(text='hello') + # 2. Tool call event (yields immediately for Gemini 3.1) + function_call = types.FunctionCall(name='foo', args={}) + tool_call = mock.create_autospec(types.LiveServerToolCall, instance=True) + tool_call.function_calls = [function_call] + msg2 = make_msg(tool_call=tool_call) + # 3. Turn complete event + msg3 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + responses = [resp async for resp in conn.receive()] + # Expected: + # responses[0] -> partial content response for msg1 (has grounding_metadata) + # responses[1] -> full text response for msg1 (has grounding_metadata) + # responses[2] -> tool call response for msg2 (has grounding_metadata) + # responses[3] -> turn_complete response for msg3 (has grounding_metadata) + assert len(responses) == 4 + assert responses[0].content.parts[0].text == 'hello' + assert isinstance(responses[0].grounding_metadata, types.GroundingMetadata) + assert responses[0].grounding_metadata.web_search_queries is None + assert responses[0].partial is True + assert responses[1].content.parts[0].text == 'hello' + assert isinstance(responses[1].grounding_metadata, types.GroundingMetadata) + assert responses[1].partial is False + assert responses[2].content.parts[0].function_call.name == 'foo' + assert isinstance(responses[2].grounding_metadata, types.GroundingMetadata) + assert responses[3].turn_complete is True + assert isinstance(responses[3].grounding_metadata, types.GroundingMetadata) + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_default_non_gemini_3_1( + mock_gemini_session, +): + """Verify grounding_metadata stays None for non-Gemini 3.1 models.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-2.5-flash-live', + ) + + def make_msg(text=None, tc=False): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = None + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = None + msg.server_content.output_transcription = None + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = ( + types.Content(role='model', parts=[types.Part.from_text(text=text)]) + if text + else None + ) + return msg + + msg1 = make_msg(text='hello') + msg2 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + responses = [resp async for resp in conn.receive()] + assert len(responses) == 3 + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].grounding_metadata is None + assert responses[0].partial is True + assert responses[1].content.parts[0].text == 'hello' + assert responses[1].grounding_metadata is None + assert responses[1].partial is False + assert responses[2].turn_complete is True + assert responses[2].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_input_transcription_gemini_3_1( + mock_gemini_session, +): + """Verify input_transcription yields finished=True immediately for Gemini 3.1.""" + conn = GeminiLlmConnection( + mock_gemini_session, + model_version='gemini-3.1-flash-live-preview', + ) + + def make_msg( + input_text=None, output_text=None, output_finished=False, tc=False + ): + msg = mock.create_autospec(types.LiveServerMessage, instance=True) + msg.usage_metadata = None + msg.tool_call = None + msg.session_resumption_update = None + msg.go_away = None + msg.server_content = mock.Mock() + msg.server_content.interrupted = False + msg.server_content.input_transcription = ( + types.Transcription(text=input_text, finished=False) + if input_text + else None + ) + msg.server_content.output_transcription = ( + types.Transcription(text=output_text, finished=output_finished) + if output_text + else None + ) + msg.server_content.generation_complete = False + msg.server_content.turn_complete = tc + msg.server_content.grounding_metadata = None + msg.server_content.model_turn = None + return msg + + msg1 = make_msg(input_text='Hello') + msg2 = make_msg(output_text='Hi there!', output_finished=True) + msg3 = make_msg(tc=True) + + async def mock_receive_generator(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in conn.receive()] + + assert len(responses) == 4 + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is True + assert responses[0].partial is False + + assert responses[1].output_transcription.text == 'Hi there!' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + + assert responses[2].output_transcription.text == 'Hi there!' + assert responses[2].output_transcription.finished is True + assert responses[2].partial is False + + assert responses[3].turn_complete is True diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index a94b2eb885..f7e16014ff 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -588,6 +588,9 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( self, ): """Verify that sessions from different loops are cleaned up without calling aclose().""" + from google.adk.features import FeatureName + from google.adk.features._feature_registry import temporary_feature_override + manager = MCPSessionManager(self.mock_stdio_connection_params) # 1. Simulate a session created in a "different" loop @@ -617,8 +620,11 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( mock_wait_for.return_value = new_session mock_session_context_class.return_value = AsyncMock() - # 3. Call create_session - session = await manager.create_session() + # 3. Call create_session with flag off to hit wait_for branch + with temporary_feature_override( + FeatureName._MCP_GRACEFUL_ERROR_HANDLING, False + ): + session = await manager.create_session() # 4. Verify results assert session == new_session @@ -969,8 +975,8 @@ class TestMCPGracefulErrorHandlingFlagContract: loudly so we don't silently break GE's rollout. """ - def test_default_state_is_off_so_cl_is_a_noop(self): - """The CL must be a no-op until GE explicitly enables it.""" + def test_default_state_is_on(self): + """The fix must be enabled by default.""" import os from google.adk.features import FeatureName @@ -981,34 +987,34 @@ def test_default_state_is_off_so_cl_is_a_noop(self): saved = {k: os.environ.pop(k) for k in (enable, disable) if k in os.environ} try: assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: os.environ.update(saved) - def test_env_var_enable_flips_flag_on_at_runtime(self): - """The env var GE will set must turn the fix on without a rebuild.""" + def test_env_var_disable_flips_flag_off_at_runtime(self): + """The env var must turn the fix off without a rebuild.""" import os from google.adk.features import FeatureName from google.adk.features import is_feature_enabled - enable = "ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING" - saved = os.environ.pop(enable, None) + disable = "ADK_DISABLE_MCP_GRACEFUL_ERROR_HANDLING" + saved = os.environ.pop(disable, None) try: - os.environ[enable] = "1" + os.environ[disable] = "1" assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False ) # And once it's removed, we revert. Confirms the value is read # live from os.environ on every call (no caching, no binary push). - del os.environ[enable] + del os.environ[disable] assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: if saved is not None: - os.environ[enable] = saved + os.environ[disable] = saved def test_env_var_disable_acts_as_kill_switch(self): """The disable env var lets consumers turn off without a rebuild.""" diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index bb2654c3db..f962fb6f5c 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -16,6 +16,7 @@ from google.adk.utils.model_name_utils import extract_model_name from google.adk.utils.model_name_utils import is_gemini_1_model +from google.adk.utils.model_name_utils import is_gemini_3_1_flash_live from google.adk.utils.model_name_utils import is_gemini_eap_or_2_or_above from google.adk.utils.model_name_utils import is_gemini_model from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled @@ -338,3 +339,28 @@ def test_default_is_disabled(self, monkeypatch): def test_true_enables_check_bypass(self, monkeypatch): monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') assert is_gemini_model_id_check_disabled() is True + + +class TestIsGemini31FlashLive: + """Test the is_gemini_3_1_flash_live function.""" + + def test_is_gemini_3_1_flash_live_simple_name(self): + """Test with simple model name format.""" + assert is_gemini_3_1_flash_live('gemini-3.1-flash-live') is True + assert is_gemini_3_1_flash_live('gemini-3.1-flash-live-preview') is True + assert is_gemini_3_1_flash_live('gemini-3.1-pro-live') is False + assert is_gemini_3_1_flash_live('gemini-2.5-flash-live') is False + + def test_is_gemini_3_1_flash_live_path_based_name(self): + """Test with path-based format (Vertex AI etc.).""" + vertex_path = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash-live' + assert is_gemini_3_1_flash_live(vertex_path) is True + vertex_path_preview = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash-live-preview' + assert is_gemini_3_1_flash_live(vertex_path_preview) is True + non_live_path = 'projects/123/locations/us-central1/publishers/google/models/gemini-3.1-flash' + assert is_gemini_3_1_flash_live(non_live_path) is False + + def test_is_gemini_3_1_flash_live_edge_cases(self): + """Test edge cases.""" + assert is_gemini_3_1_flash_live(None) is False + assert is_gemini_3_1_flash_live('') is False