diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart index e6c53963bd5f..b5d1b01fb9cd 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart @@ -14,9 +14,9 @@ import 'dart:async'; import 'dart:developer' as developer; import 'package:flutter/foundation.dart'; - import 'package:flutter/material.dart'; import 'package:firebase_ai/firebase_ai.dart'; +import 'package:waveform_flutter/waveform_flutter.dart'; import '../utils/audio_input.dart'; import '../utils/audio_output.dart'; @@ -25,466 +25,373 @@ import '../widgets/message_widget.dart'; import '../widgets/audio_visualizer.dart'; import '../widgets/camera_previews.dart'; -class BidiPage extends StatefulWidget { - const BidiPage({ - super.key, - required this.title, +// ============================================================================ +// MEDIA MANAGER +// Isolates Audio and Video hardware stream setup, start, stop, and cleanup. +// ============================================================================ +class BidiMediaManager { + final AudioOutput _audioOutput = AudioOutput(); + final AudioInput _audioInput = AudioInput(); + final VideoInput _videoInput = VideoInput(); + + StreamSubscription? _audioSubscription; + StreamSubscription? _videoSubscription; + + bool videoIsInitialized = false; + + // Expose hardware state/streams to the Controller and UI + Stream? get amplitudeStream => _audioInput.amplitudeStream; + dynamic get cameraController => _videoInput.cameraController; + String? 
get selectedCameraId => _videoInput.selectedCameraId; + bool get controllerInitialized => _videoInput.controllerInitialized; + + void setMacOSController(dynamic controller) { + _videoInput.setMacOSController(controller); + } + + Future init() async { + try { + await _audioOutput.init(); + } catch (e) { + developer.log('Audio Output init error: $e'); + } + + try { + await _audioInput.init(); + } catch (e) { + developer.log('Audio Input init error: $e'); + } + + try { + await _videoInput.init(); + videoIsInitialized = true; + } catch (e) { + developer.log('Error during video initialization: $e'); + } + } + + Future startAudio(void Function(Uint8List) onData) async { + await stopAudio(); + try { + var inputStream = await _audioInput.startRecordingStream(); + await _audioOutput.playStream(); + if (inputStream != null) { + _audioSubscription = inputStream.listen( + onData, + onError: (e) { + developer.log('Audio Stream Error: $e'); + stopAudio(); + }, + cancelOnError: true, + ); + } + } catch (e) { + developer.log('BidiMediaManager.startAudio(): $e'); + rethrow; + } + } + + Future stopAudio() async { + await _audioSubscription?.cancel(); + _audioSubscription = null; + await _audioInput.stopRecording(); + } + + Future startVideo(void Function(Uint8List, String) onData) async { + if (!videoIsInitialized) return; + + if (!_videoInput.controllerInitialized || + _videoInput.cameraController == null) { + await _videoInput.initializeCameraController(); + } + + if (!kIsWeb && defaultTargetPlatform == TargetPlatform.macOS) { + int attempts = 0; + while (_videoInput.cameraController == null) { + if (attempts > 50) break; // 5 second timeout safety + await Future.delayed(const Duration(milliseconds: 100)); + attempts++; + } + } + + // Wait for Mac Camera to Settle (Prevent audio hijack) + await Future.delayed(const Duration(milliseconds: 1000)); + + _videoSubscription = _videoInput.startStreamingImages().listen( + (data) { + String mimeType = 'image/jpeg'; + if (!kIsWeb && 
defaultTargetPlatform == TargetPlatform.macOS) { + if (data.length > 3 && data[0] == 0x89 && data[1] == 0x50) { + mimeType = 'image/png'; + } + } + onData(data, mimeType); + }, + onError: (e) => developer.log('Video Stream Error: $e'), + ); + } + + Future stopVideo() async { + await _videoSubscription?.cancel(); + _videoSubscription = null; + await _videoInput.stopStreamingImages(); + } + + void playAudioChunk(Uint8List bytes) { + _audioOutput.addDataToAudioStream(bytes); + } +} + +// ============================================================================ +// BIDI SESSION CONTROLLER +// Isolates business logic, session start/stop, reconnection, and tool execution. +// ============================================================================ +class BidiSessionController extends ChangeNotifier { + BidiSessionController({ required this.model, required this.useVertexBackend, - }); + this.onShowError, + this.onScrollDown, + }) { + _initLiveModel(); + } - final String title; final GenerativeModel model; final bool useVertexBackend; + final void Function(String)? onShowError; + final VoidCallback? onScrollDown; - @override - State createState() => _BidiPageState(); -} + late LiveGenerativeModel _liveModel; + LiveSession? _session; + final BidiMediaManager mediaManager = BidiMediaManager(); -class LightControl { - final int? brightness; - final String? 
colorTemperature; + bool isLoading = false; + bool isSessionActive = false; + bool isMicOn = false; + bool isCameraOn = false; - LightControl({this.brightness, this.colorTemperature}); -} + // Intention state for robust stream reconnection + bool _intendedMicOn = false; + bool _intendedCameraOn = false; -class _BidiPageState extends State { - final ScrollController _scrollController = ScrollController(); - final TextEditingController _textController = TextEditingController(); - final FocusNode _textFieldFocus = FocusNode(); - final List _messages = []; - bool _loading = false; - bool _sessionOpening = false; - bool _recording = false; - late LiveGenerativeModel _liveModel; - late LiveSession _session; - final AudioOutput _audioOutput = AudioOutput(); - final AudioInput _audioInput = AudioInput(); - final VideoInput _videoInput = VideoInput(); - StreamSubscription? _audioSubscription; + final List messages = []; + String? _activeSessionHandle; + int? _lastProcessedIndex; int? _inputTranscriptionMessageIndex; int? _outputTranscriptionMessageIndex; - bool _isCameraOn = false; - bool _videoIsInitialized = false; - - @override - void initState() { - super.initState(); + void _initLiveModel() { final config = LiveGenerationConfig( speechConfig: SpeechConfig(voiceName: 'Fenrir'), - responseModalities: [ - ResponseModalities.audio, - ], + responseModalities: [ResponseModalities.audio], inputAudioTranscription: AudioTranscriptionConfig(), outputAudioTranscription: AudioTranscriptionConfig(), ); - _liveModel = widget.useVertexBackend + final tools = [ + Tool.functionDeclarations([_lightControlTool]) + ]; + + _liveModel = useVertexBackend ? 
FirebaseAI.vertexAI().liveGenerativeModel( model: 'gemini-live-2.5-flash-preview-native-audio-09-2025', liveGenerationConfig: config, - tools: [ - Tool.functionDeclarations([lightControlTool]), - ], + tools: tools, ) : FirebaseAI.googleAI().liveGenerativeModel( model: 'gemini-2.5-flash-native-audio-preview-09-2025', liveGenerationConfig: config, - tools: [ - Tool.functionDeclarations([lightControlTool]), - ], + tools: tools, ); } - Future _initAudio() async { - try { - await _audioOutput.init(); - } catch (e) { - developer.log('Audio Output init error: $e'); - } + Future initialize() async { + isLoading = true; + notifyListeners(); + await mediaManager.init(); + isLoading = false; + notifyListeners(); + } - try { - await _audioInput.init(); - } catch (e) { - developer.log('Audio Input init error: $e'); + Future toggleSession() async { + if (isSessionActive) { + await _stopSession(explicit: true); + } else { + await _startSession(explicit: true); } } - Future _initVideo() async { + Future _startSession({required bool explicit}) async { + isLoading = true; + notifyListeners(); + try { - await _videoInput.init(); - setState(() { - _videoIsInitialized = true; - }); - } catch (e) { - developer.log('Error during video initialization: $e'); + _session = await _liveModel.connect( + sessionResumption: + SessionResumptionConfig(handle: _activeSessionHandle), + ); + } on Exception catch (e) { + developer.log('Error setting up session: $e, starting a new one.'); + _session = await _liveModel.connect(); } - } - void _scrollDown() { - if (!_scrollController.hasClients) return; + isSessionActive = true; + unawaited(_processMessagesContinuously()); - _scrollController.jumpTo( - _scrollController.position.maxScrollExtent, - ); + if (explicit) { + // Reconnect previously active hardware seamlessly into the new session + if (_intendedMicOn) await _startMicStream(); + if (_intendedCameraOn) await _startCameraStream(); + } + + isLoading = false; + notifyListeners(); } - @override - 
void dispose() { - if (_sessionOpening) { - _sessionOpening = false; - _session.close(); + Future _stopSession({required bool explicit}) async { + isLoading = true; + notifyListeners(); + + if (explicit) { + await mediaManager.stopAudio(); + await mediaManager.stopVideo(); + isMicOn = false; + isCameraOn = false; + // We purposefully DO NOT reset _intendedMicOn/CameraOn so we know what + // the user had active when they reconnect! } - super.dispose(); - } - @override - Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.all(8), - child: Column( - mainAxisAlignment: MainAxisAlignment.center, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - if (_isCameraOn) - Container( - height: 200, - color: Colors.black, - alignment: Alignment.center, - child: (!kIsWeb && defaultTargetPlatform == TargetPlatform.macOS) - ? FullCameraPreview( - controller: _videoInput.cameraController, - deviceId: _videoInput.selectedCameraId, - onInitialized: (controller) { - // This is where the controller actually gets born on macOS - _videoInput.setMacOSController(controller); - }, - ) - : (_videoInput.cameraController != null && - _videoInput.controllerInitialized) - ? FullCameraPreview( - controller: _videoInput.cameraController, - deviceId: _videoInput.selectedCameraId, - onInitialized: (controller) { - // Web/Mobile callback (often unused if controller passed in) - }, - ) - : const Center(child: CircularProgressIndicator()), - ), - Expanded( - child: ListView.builder( - controller: _scrollController, - itemBuilder: (context, idx) { - return MessageWidget( - text: _messages[idx].text, - image: _messages[idx].imageBytes != null - ? Image.memory( - _messages[idx].imageBytes!, - cacheWidth: 400, - cacheHeight: 400, - ) - : null, - isFromUser: _messages[idx].fromUser ?? 
false, - isThought: _messages[idx].isThought, - ); - }, - itemCount: _messages.length, - ), - ), - Padding( - padding: const EdgeInsets.symmetric( - vertical: 25, - horizontal: 15, - ), - child: Row( - children: [ - Expanded( - child: TextField( - focusNode: _textFieldFocus, - controller: _textController, - onSubmitted: _sendTextPrompt, - ), - ), - const SizedBox.square( - dimension: 15, - ), - AudioVisualizer( - audioStreamIsActive: _recording, - amplitudeStream: _audioInput.amplitudeStream, - ), - const SizedBox.square( - dimension: 15, - ), - IconButton( - tooltip: 'Start Streaming', - onPressed: !_loading - ? () async { - await _setupSession(); - } - : null, - icon: Icon( - Icons.network_wifi, - color: _sessionOpening - ? Theme.of(context).colorScheme.secondary - : Theme.of(context).colorScheme.primary, - ), - ), - IconButton( - tooltip: 'Send Stream Message', - onPressed: !_loading - ? () async { - if (_recording) { - await _stopRecording(); - } else { - await _startRecording(); - } - } - : null, - icon: Icon( - _recording ? Icons.stop : Icons.mic, - color: _loading - ? Theme.of(context).colorScheme.secondary - : Theme.of(context).colorScheme.primary, - ), - ), - IconButton( - tooltip: 'Toggle Camera', - onPressed: _isCameraOn ? _stopVideoStream : _startVideoStream, - icon: Icon( - _isCameraOn ? Icons.videocam_off : Icons.videocam, - color: _loading - ? 
Theme.of(context).colorScheme.secondary - : Theme.of(context).colorScheme.primary, - ), - ), - if (!_loading) - IconButton( - onPressed: () async { - await _sendTextPrompt(_textController.text); - }, - icon: Icon( - Icons.send, - color: Theme.of(context).colorScheme.primary, - ), - ) - else - const CircularProgressIndicator(), - ], - ), - ), - ], - ), - ); - } + await _session?.close(); + _session = null; + isSessionActive = false; - final lightControlTool = FunctionDeclaration( - 'setLightValues', - 'Set the brightness and color temperature of a room light.', - parameters: { - 'brightness': Schema.integer( - description: 'Light level from 0 to 100. ' - 'Zero is off and 100 is full brightness.', - ), - 'colorTemperature': Schema.string( - description: 'Color temperature of the light fixture, ' - 'which can be `daylight`, `cool` or `warm`.', - ), - }, - ); + isLoading = false; + notifyListeners(); + } - Future> _setLightValues({ - int? brightness, - String? colorTemperature, - }) async { - final apiResponse = { - 'colorTemprature': 'warm', - 'brightness': brightness, - }; - return apiResponse; + Future _sessionResume() async { + if (isSessionActive) { + // Explicit false means we keep the hardware pipelines active + // so they immediately route data to the new session object + await _stopSession(explicit: false); + await _startSession(explicit: false); + } } - Future _setupSession() async { - setState(() { - _loading = true; - }); - await _initAudio(); + void _onAudioData(Uint8List data) { + if (isSessionActive && _session != null) { + _session!.sendAudioRealtime(InlineDataPart('audio/pcm', data)); + } + } - try { - if (!_videoIsInitialized) { - await _initVideo(); - } else { - await _videoInput.initializeCameraController(); - } - } catch (e) { - developer.log('Video Hardware init error: $e'); + void _onVideoData(Uint8List data, String mimeType) { + if (isSessionActive && _session != null) { + _session!.sendVideoRealtime(InlineDataPart(mimeType, data)); } + } - if 
(!_sessionOpening) { - _session = await _liveModel.connect(); - _sessionOpening = true; - unawaited( - _processMessagesContinuously(), - ); + Future toggleMic() async { + _intendedMicOn = !_intendedMicOn; + if (_intendedMicOn) { + await _startMicStream(); } else { - await _session.close(); - _sessionOpening = false; + await mediaManager.stopAudio(); + isMicOn = false; + notifyListeners(); } - - setState(() { - _loading = false; - }); } - Future _startRecording() async { - await _audioSubscription?.cancel(); - _audioSubscription = null; - setState(() { - _recording = true; - }); + Future _startMicStream() async { + if (!isSessionActive) { + isMicOn = true; + notifyListeners(); + return; + } try { - var inputStream = await _audioInput.startRecordingStream(); - await _audioOutput.playStream(); - if (inputStream != null) { - _audioSubscription = inputStream.listen( - (data) { - _session.sendAudioRealtime(InlineDataPart('audio/pcm', data)); - }, - onError: (e) { - developer.log('Audio Stream Error: $e'); - _stopRecording(); - }, - cancelOnError: true, - ); - } + await mediaManager.startAudio(_onAudioData); + isMicOn = true; + notifyListeners(); } catch (e) { - developer.log('bidi_page._startRecording(): $e'); - _showError('bidi_page._startRecording(): $e'); - setState(() => _recording = false); + onShowError?.call(e.toString()); + isMicOn = false; + notifyListeners(); } } - Future _stopRecording() async { - await _audioSubscription?.cancel(); - _audioSubscription = null; + Future _startCameraStream() async { + if (!isSessionActive) { + isCameraOn = true; + notifyListeners(); + return; + } try { - await _audioInput.stopRecording(); + await mediaManager.startVideo(_onVideoData); + isCameraOn = true; + notifyListeners(); } catch (e) { - _showError(e.toString()); + developer.log('Error starting video stream: $e'); + onShowError?.call(e.toString()); + isCameraOn = false; + notifyListeners(); } - - setState(() { - _recording = false; - }); } - Future _startVideoStream() async 
{ - // 1. Re-entry Guard: Prevent multiple clicks while switching - if (_loading || !_videoIsInitialized) return; + Future toggleCamera() async { + if (isLoading) return; // Prevent multiple clicks + _intendedCameraOn = !_intendedCameraOn; - // 2. Capture the current recording state - bool wasRecording = _recording; - - setState(() { - _loading = true; // Lock the UI during the switch - }); + isLoading = true; + notifyListeners(); try { - if (wasRecording) { - await _stopRecording(); - } - - // 4. Wait for ripple/UI (Prevent freeze) - await Future.delayed(const Duration(milliseconds: 250)); - - // 5. Initialize Camera if needed - if (!_videoInput.controllerInitialized || - _videoInput.cameraController == null) { - await _videoInput.initializeCameraController(); - } + if (!_intendedCameraOn) { + await mediaManager.stopVideo(); + isCameraOn = false; + } else { + // Stop audio momentarily to prevent hijacking (Mac quirk workaround) + bool wasMicOn = isMicOn; + if (wasMicOn) await mediaManager.stopAudio(); - // 6. Mount Camera UI - setState(() { - _isCameraOn = true; - }); - - if (!kIsWeb && defaultTargetPlatform == TargetPlatform.macOS) { - // ✅ Because we set _cameraController to null in stopStreamingImages, - // this loop will now CORRECTLY wait for the new View to initialize. - int attempts = 0; - while (_videoInput.cameraController == null) { - if (attempts > 50) break; // 5 second timeout safety - await Future.delayed(const Duration(milliseconds: 100)); - attempts++; - } - } + await Future.delayed(const Duration(milliseconds: 250)); - // 7. Wait for Mac Camera to Settle (Prevent audio hijack) - await Future.delayed(const Duration(milliseconds: 1000)); + await _startCameraStream(); - // 8. CLEAN RESTART: Use the helper method! - // Only restart if we were recording before. - if (wasRecording) { - developer.log('Resuming audio session...'); - await _startRecording(); + // Restart Audio + if (wasMicOn) await _startMicStream(); } - - // 9. 
Start Video Stream - _videoInput.startStreamingImages().listen( - (data) { - String mimeType = 'image/jpeg'; - if (!kIsWeb && defaultTargetPlatform == TargetPlatform.macOS) { - if (data.length > 3 && data[0] == 0x89 && data[1] == 0x50) { - mimeType = 'image/png'; - } - } - _session.sendVideoRealtime(InlineDataPart(mimeType, data)); - }, - onError: (e) => developer.log('Video Stream Error: $e'), - ); } catch (e) { developer.log('Error switching to video: $e'); - _showError(e.toString()); + onShowError?.call(e.toString()); + isCameraOn = false; } finally { - // 10. Always unlock the UI - setState(() { - _loading = false; - }); + isLoading = false; + notifyListeners(); } } - Future _stopVideoStream() async { - await _videoInput.stopStreamingImages(); - setState(() { - _isCameraOn = false; - }); - } + Future sendTextPrompt(String textPrompt) async { + if (!isSessionActive || _session == null) return; + isLoading = true; + notifyListeners(); - Future _sendTextPrompt(String textPrompt) async { - setState(() { - _loading = true; - }); try { - //final prompt = Content.text(textPrompt); - // await _session.send(input: prompt, turnComplete: true); - await _session.sendTextRealtime(textPrompt); + await _session!.sendTextRealtime(textPrompt); } catch (e) { - _showError(e.toString()); + onShowError?.call(e.toString()); } - setState(() { - _loading = false; - }); + isLoading = false; + notifyListeners(); } Future _processMessagesContinuously() async { + if (_session == null) return; try { - await for (final message in _session.receive()) { - if (!mounted) break; + await for (final message in _session!.receive()) { await _handleLiveServerMessage(message); } } catch (e) { - _showError(e.toString()); + onShowError?.call(e.toString()); } } @@ -496,40 +403,6 @@ class _BidiPageState extends State { await _handleLiveServerContent(message); } - int? _handleTranscription( - Transcription? transcription, - int? messageIndex, - String prefix, - bool fromUser, - ) { - int? 
currentIndex = messageIndex; - if (transcription?.text != null) { - if (currentIndex != null) { - _messages[currentIndex] = _messages[currentIndex].copyWith( - text: '${_messages[currentIndex].text}${transcription!.text!}', - ); - } else { - _messages.add( - MessageData( - text: '$prefix${transcription!.text!}', - fromUser: fromUser, - ), - ); - currentIndex = _messages.length - 1; - } - if (transcription.finished ?? false) { - currentIndex = null; - setState(_scrollDown); - } else { - // Use a scheduled frame instead of an immediate setState - WidgetsBinding.instance.addPostFrameCallback((_) { - if (mounted) setState(() {}); - }); - } - } - return currentIndex; - } - _inputTranscriptionMessageIndex = _handleTranscription( message.inputTranscription, _inputTranscriptionMessageIndex, @@ -549,8 +422,48 @@ class _BidiPageState extends State { } else if (message is LiveServerToolCall && message.functionCalls != null) { await _handleLiveServerToolCall(message); } else if (message is GoingAwayNotice) { - developer.log('Session is going away in ${message.timeLeft} seconds'); + developer.log('GoAway message received, time left: ${message.timeLeft}'); + if (_activeSessionHandle != null) { + developer.log('====================================================='); + developer.log('Resume Session with handle: $_activeSessionHandle'); + await _sessionResume(); + } + } else if (message is SessionResumptionUpdate && + message.resumable != null && + message.resumable!) { + _activeSessionHandle = message.newHandle; + _lastProcessedIndex = message.lastConsumedClientMessageIndex; + developer.log( + 'SessionResumptionUpdate: handle ${message.newHandle}, index $_lastProcessedIndex'); + } + } + + int? _handleTranscription(Transcription? transcription, int? messageIndex, + String prefix, bool fromUser) { + int? 
currentIndex = messageIndex; + if (transcription?.text != null) { + if (currentIndex != null) { + messages[currentIndex] = messages[currentIndex].copyWith( + text: '${messages[currentIndex].text}${transcription!.text!}', + ); + } else { + messages.add( + MessageData( + text: '$prefix${transcription!.text!}', + fromUser: fromUser, + ), + ); + currentIndex = messages.length - 1; + } + + if (transcription.finished ?? false) { + currentIndex = null; + onScrollDown?.call(); + } else { + notifyListeners(); // Trigger UI rebuild for streaming text + } } + return currentIndex; } Future _handleLiveServerContent(LiveServerContent response) async { @@ -558,9 +471,19 @@ class _BidiPageState extends State { if (partList != null) { for (final part in partList) { if (part is TextPart) { - await _handleTextPart(part); + messages.add( + MessageData( + text: part.text, + fromUser: false, + isThought: part.isThought ?? false, + ), + ); + onScrollDown?.call(); + notifyListeners(); } else if (part is InlineDataPart) { - await _handleInlineDataPart(part); + if (part.mimeType.startsWith('audio')) { + mediaManager.playAudioChunk(part.bytes); + } } else { developer.log('receive part with type ${part.runtimeType}'); } @@ -568,31 +491,6 @@ class _BidiPageState extends State { } } - Future _handleTextPart(TextPart part) async { - if (!_loading) { - setState(() { - _loading = true; - }); - } - _messages.add( - MessageData( - text: part.text, - fromUser: false, - isThought: part.isThought ?? 
false, - ), - ); - setState(() { - _loading = false; - _scrollDown(); - }); - } - - Future _handleInlineDataPart(InlineDataPart part) async { - if (part.mimeType.startsWith('audio')) { - _audioOutput.addDataToAudioStream(part.bytes); - } - } - Future _handleLiveServerToolCall(LiveServerToolCall response) async { final functionCalls = response.functionCalls!.toList(); if (functionCalls.isNotEmpty) { @@ -600,11 +498,15 @@ class _BidiPageState extends State { if (functionCall.name == 'setLightValues') { var color = functionCall.args['colorTemperature']! as String; var brightness = functionCall.args['brightness']! as int; - final functionResult = await _setLightValues( - brightness: brightness, - colorTemperature: color, - ); - await _session.sendToolResponse([ + + // Mock Tool Execution + final functionResult = { + 'colorTemprature': + color, // original had a typo, keeping to preserve functionality intent + 'brightness': brightness, + }; + + await _session?.sendToolResponse([ FunctionResponse( functionCall.name, functionResult, @@ -612,27 +514,99 @@ class _BidiPageState extends State { ), ]); } else { - throw UnimplementedError( - 'Function not declared to the model: ${functionCall.name}', - ); + throw UnimplementedError('Function not declared: ${functionCall.name}'); } } } + @override + void dispose() { + _stopSession(explicit: true); + super.dispose(); + } + + static final _lightControlTool = FunctionDeclaration( + 'setLightValues', + 'Set the brightness and color temperature of a room light.', + parameters: { + 'brightness': Schema.integer( + description: + 'Light level from 0 to 100. Zero is off and 100 is full brightness.', + ), + 'colorTemperature': Schema.string( + description: + 'Color temperature of the light fixture, which can be `daylight`, `cool` or `warm`.', + ), + }, + ); +} + +// ============================================================================ +// UI WIDGET +// Isolates presentation, keeping state out of the visual hierarchy. 
+// ============================================================================ +class BidiPage extends StatefulWidget { + const BidiPage({ + super.key, + required this.title, + required this.model, + required this.useVertexBackend, + }); + + final String title; + final GenerativeModel model; + final bool useVertexBackend; + + @override + State createState() => _BidiPageState(); +} + +class _BidiPageState extends State { + late final BidiSessionController _controller; + final ScrollController _scrollController = ScrollController(); + final TextEditingController _textController = TextEditingController(); + final FocusNode _textFieldFocus = FocusNode(); + + @override + void initState() { + super.initState(); + _controller = BidiSessionController( + model: widget.model, + useVertexBackend: widget.useVertexBackend, + onShowError: _showError, + onScrollDown: _scrollDown, + ); + _controller.initialize(); + } + + @override + void dispose() { + _controller.dispose(); + _scrollController.dispose(); + _textController.dispose(); + _textFieldFocus.dispose(); + super.dispose(); + } + + void _scrollDown() { + if (!_scrollController.hasClients) return; + WidgetsBinding.instance.addPostFrameCallback((_) { + if (mounted) { + _scrollController.jumpTo(_scrollController.position.maxScrollExtent); + } + }); + } + void _showError(String message) { showDialog( context: context, builder: (context) { return AlertDialog( title: const Text('Something went wrong'), - content: SingleChildScrollView( - child: SelectableText(message), - ), + content: SingleChildScrollView(child: SelectableText(message)), actions: [ TextButton( - onPressed: () { - Navigator.of(context).pop(); - }, + onPressed: () => Navigator.of(context).pop(), child: const Text('OK'), ), ], @@ -640,4 +614,145 @@ class _BidiPageState extends State { }, ); } + + @override + Widget build(BuildContext context) { + return AnimatedBuilder( + animation: _controller, + builder: (context, child) { + return Padding( + padding: const 
EdgeInsets.all(8), + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (_controller.isCameraOn) + Container( + height: 200, + color: Colors.black, + alignment: Alignment.center, + child: (!kIsWeb && + defaultTargetPlatform == TargetPlatform.macOS) + ? FullCameraPreview( + controller: _controller.mediaManager.cameraController, + deviceId: _controller.mediaManager.selectedCameraId, + onInitialized: (controller) { + _controller.mediaManager + .setMacOSController(controller); + }, + ) + : (_controller.mediaManager.cameraController != null && + _controller.mediaManager.controllerInitialized) + ? FullCameraPreview( + controller: + _controller.mediaManager.cameraController, + deviceId: + _controller.mediaManager.selectedCameraId, + onInitialized: (controller) {}, + ) + : const Center(child: CircularProgressIndicator()), + ), + Expanded( + child: ListView.builder( + controller: _scrollController, + itemCount: _controller.messages.length, + itemBuilder: (context, idx) { + final message = _controller.messages[idx]; + return MessageWidget( + text: message.text, + image: message.imageBytes != null + ? Image.memory( + message.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ) + : null, + isFromUser: message.fromUser ?? false, + isThought: message.isThought, + ); + }, + ), + ), + Padding( + padding: + const EdgeInsets.symmetric(vertical: 25, horizontal: 15), + child: Row( + children: [ + Expanded( + child: TextField( + focusNode: _textFieldFocus, + controller: _textController, + onSubmitted: (text) { + _controller.sendTextPrompt(text); + _textController.clear(); + }, + ), + ), + const SizedBox.square(dimension: 15), + AudioVisualizer( + audioStreamIsActive: _controller.isMicOn, + amplitudeStream: _controller.mediaManager.amplitudeStream, + ), + const SizedBox.square(dimension: 15), + IconButton( + tooltip: 'Start Streaming', + onPressed: !_controller.isLoading + ? 
() => _controller.toggleSession() + : null, + icon: Icon( + Icons.network_wifi, + color: _controller.isSessionActive + ? Theme.of(context).colorScheme.secondary + : Theme.of(context).colorScheme.primary, + ), + ), + IconButton( + tooltip: 'Send Stream Message', + onPressed: !_controller.isLoading + ? () => _controller.toggleMic() + : null, + icon: Icon( + _controller.isMicOn ? Icons.stop : Icons.mic, + color: _controller.isLoading + ? Theme.of(context).colorScheme.secondary + : Theme.of(context).colorScheme.primary, + ), + ), + IconButton( + tooltip: 'Toggle Camera', + onPressed: !_controller.isLoading + ? () => _controller.toggleCamera() + : null, + icon: Icon( + _controller.isCameraOn + ? Icons.videocam_off + : Icons.videocam, + color: _controller.isLoading + ? Theme.of(context).colorScheme.secondary + : Theme.of(context).colorScheme.primary, + ), + ), + if (!_controller.isLoading) + IconButton( + tooltip: 'Send Text', + onPressed: () { + _controller.sendTextPrompt(_textController.text); + _textController.clear(); + }, + icon: Icon( + Icons.send, + color: Theme.of(context).colorScheme.primary, + ), + ) + else + const CircularProgressIndicator(), + ], + ), + ), + ], + ), + ); + }, + ); + } } diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 681989831e53..d2586202923b 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -99,15 +99,19 @@ export 'src/imagen/imagen_reference.dart' ImagenControlReference; export 'src/live_api.dart' show - LiveGenerationConfig, - SpeechConfig, AudioTranscriptionConfig, + ContextWindowCompressionConfig, + GoingAwayNotice, + LiveGenerationConfig, LiveServerMessage, LiveServerContent, LiveServerToolCall, LiveServerToolCallCancellation, LiveServerResponse, - GoingAwayNotice, + SessionResumptionConfig, + SessionResumptionUpdate, + SlidingWindow, + SpeechConfig, Transcription; export 
'src/live_session.dart' show LiveSession; export 'src/schema.dart' show Schema, SchemaType; diff --git a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart index 01ac7eb834b3..ea95fdb09918 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart @@ -18,11 +18,8 @@ import 'dart:convert'; import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; -import 'package:flutter/foundation.dart'; import 'package:http/http.dart' as http; import 'package:meta/meta.dart'; -import 'package:web_socket_channel/io.dart'; -import 'package:web_socket_channel/web_socket_channel.dart'; import 'api.dart'; import 'client.dart'; diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart index 9db6bd845476..ba70050c7d3b 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart @@ -77,21 +77,90 @@ class AudioTranscriptionConfig { Map toJson() => {}; } +/// Configures the sliding window context compression mechanism. +/// +/// The context window will be truncated by keeping only a suffix of it. +class SlidingWindow { + /// Creates a [SlidingWindow] instance. + /// + /// [targetTokens] (optional): The target number of tokens to keep in the + /// context window. + SlidingWindow({this.targetTokens}); + + /// The session reduction target, i.e., how many tokens we should keep. + final int? targetTokens; + // ignore: public_member_api_docs + Map toJson() => + {if (targetTokens case final targetTokens?) 'targetTokens': targetTokens}; +} + +/// Enables context window compression to manage the model's context window. +/// +/// This mechanism prevents the context from exceeding a given length. 
+class ContextWindowCompressionConfig { + /// Creates a [ContextWindowCompressionConfig] instance. + /// + /// [triggerTokens] (optional): The number of tokens that triggers the + /// compression mechanism. + /// [slidingWindow] (optional): The sliding window compression mechanism to + /// use. + ContextWindowCompressionConfig({this.triggerTokens, this.slidingWindow}); + + /// The number of tokens (before running a turn) that triggers the context + /// window compression. + final int? triggerTokens; + + /// The sliding window compression mechanism. + final SlidingWindow? slidingWindow; + // ignore: public_member_api_docs + Map toJson() => { + if (triggerTokens case final triggerTokens?) + 'triggerTokens': triggerTokens, + if (slidingWindow case final slidingWindow?) + 'slidingWindow': slidingWindow.toJson() + }; +} + +/// Configuration for the session resumption mechanism. +/// +/// When included in the session setup, the server will send +/// [SessionResumptionUpdate] messages. +class SessionResumptionConfig { + /// Creates a [SessionResumptionConfig] instance. + /// + /// [handle] (optional): The session resumption handle of the previous session + /// to restore. + SessionResumptionConfig({this.handle}); + + /// The session resumption handle of the previous session to restore. + /// + /// If not present, a new session will be started. + final String? handle; + + // ignore: public_member_api_docs + Map toJson() => { + if (handle case final handle?) 'handle': handle, + }; +} + /// Configures live generation settings.
final class LiveGenerationConfig extends BaseGenerationConfig { // ignore: public_member_api_docs - LiveGenerationConfig({ - this.speechConfig, - this.inputAudioTranscription, - this.outputAudioTranscription, - super.responseModalities, - super.maxOutputTokens, - super.temperature, - super.topP, - super.topK, - super.presencePenalty, - super.frequencyPenalty, - }); + LiveGenerationConfig( + {this.speechConfig, + this.inputAudioTranscription, + this.outputAudioTranscription, + this.contextWindowCompression, + super.responseModalities, + super.maxOutputTokens, + super.temperature, + super.topP, + super.topK, + super.presencePenalty, + super.frequencyPenalty}); /// The speech configuration. final SpeechConfig? speechConfig; @@ -103,6 +172,9 @@ final class LiveGenerationConfig extends BaseGenerationConfig { /// the output audio. final AudioTranscriptionConfig? outputAudioTranscription; + /// The context window compression configuration. + final ContextWindowCompressionConfig? contextWindowCompression; + @override Map toJson() => { ...super.toJson(), @@ -222,6 +294,34 @@ class GoingAwayNotice implements LiveServerMessage { final String? timeLeft; } +/// An update of the session resumption state. +/// +/// This message is only sent if [SessionResumptionConfig] was set in the +/// session setup. +class SessionResumptionUpdate implements LiveServerMessage { + /// Creates a [SessionResumptionUpdate] instance. + /// + /// [newHandle] (optional): The new handle that represents the state that can + /// be resumed. + /// [resumable] (optional): Indicates if the session can be resumed at this + /// point. + /// [lastConsumedClientMessageIndex] (optional): The index of the last client + /// message that is included in the state represented by this update. + SessionResumptionUpdate( + {this.newHandle, this.resumable, this.lastConsumedClientMessageIndex}); + + /// The new handle that represents the state that can be resumed. Empty if + /// `resumable` is false. + final String? 
newHandle; + + /// Indicates if the session can be resumed at this point. + final bool? resumable; + + /// The index of the last client message that is included in the state + /// represented by this update. + final int? lastConsumedClientMessageIndex; +} + /// A single response chunk received during a live content generation. /// /// It can contain generated content, function calls to be executed, or @@ -449,8 +549,17 @@ LiveServerMessage _parseServerMessage(Object jsonObject) { } else if (json.containsKey('setupComplete')) { return LiveServerSetupComplete(); } else if (json.containsKey('goAway')) { - final goAwayJson = json['goAway'] as Map; + final goAwayJson = json['goAway'] as Map; return GoingAwayNotice(timeLeft: goAwayJson['timeLeft'] as String?); + } else if (json.containsKey('sessionResumptionUpdate')) { + final sessionResumptionUpdateJson = + json['sessionResumptionUpdate'] as Map; + return SessionResumptionUpdate( + newHandle: sessionResumptionUpdateJson['newHandle'] as String?, + resumable: sessionResumptionUpdateJson['resumable'] as bool?, + lastConsumedClientMessageIndex: + sessionResumptionUpdateJson['lastConsumedClientMessageIndex'] as int?, + ); } else { throw unhandledFormat('LiveServerMessage', json); } diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart index 3414b1376af1..8d6bfd478965 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart @@ -93,30 +93,12 @@ final class LiveGenerativeModel extends BaseModel { /// /// Returns a [Future] that resolves to an [LiveSession] object upon successful /// connection. - Future connect() async { + Future connect( + {SessionResumptionConfig? sessionResumption}) async { final uri = _useVertexBackend ? _vertexAIUri() : _googleAIUri(); final modelString = _useVertexBackend ? 
_vertexAIModelString() : _googleAIModelString(); - final setupJson = { - 'setup': { - 'model': modelString, - if (_systemInstruction != null) - 'system_instruction': _systemInstruction.toJson(), - if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(), - if (_liveGenerationConfig != null) ...{ - 'generation_config': _liveGenerationConfig.toJson(), - if (_liveGenerationConfig.inputAudioTranscription != null) - 'input_audio_transcription': - _liveGenerationConfig.inputAudioTranscription!.toJson(), - if (_liveGenerationConfig.outputAudioTranscription != null) - 'output_audio_transcription': - _liveGenerationConfig.outputAudioTranscription!.toJson(), - }, - } - }; - - final request = jsonEncode(setupJson); final headers = await BaseModel.firebaseTokens( _appCheck, _auth, @@ -124,14 +106,15 @@ final class LiveGenerativeModel extends BaseModel { _useLimitedUseAppCheckTokens, )(); - var ws = kIsWeb - ? WebSocketChannel.connect(Uri.parse(uri)) - : IOWebSocketChannel.connect(Uri.parse(uri), headers: headers); - await ws.ready; - - ws.sink.add(request); - - return LiveSession(ws); + return LiveSession.connect( + uri: uri, + headers: headers, + modelString: modelString, + systemInstruction: _systemInstruction, + tools: _tools, + sessionResumption: sessionResumption, + liveGenerationConfig: _liveGenerationConfig, + ); } } diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_session.dart b/packages/firebase_ai/firebase_ai/lib/src/live_session.dart index f136a644d03d..9e21ea530bb2 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_session.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_session.dart @@ -16,17 +16,128 @@ import 'dart:async'; import 'dart:convert'; import 'dart:developer'; +import 'package:flutter/foundation.dart'; +import 'package:web_socket_channel/io.dart'; import 'package:web_socket_channel/web_socket_channel.dart'; import 'content.dart'; import 'error.dart'; import 'live_api.dart'; +import 'tool.dart'; /// Manages 
asynchronous communication with Gemini model over a WebSocket /// connection. class LiveSession { // ignore: public_member_api_docs - LiveSession(this._ws) { + LiveSession._( + this._ws, { + required String uri, + required Map headers, + required String modelString, + Content? systemInstruction, + List? tools, + LiveGenerationConfig? liveGenerationConfig, + }) : _uri = uri, + _headers = headers, + _modelString = modelString, + _systemInstruction = systemInstruction, + _tools = tools, + _liveGenerationConfig = liveGenerationConfig, + _messageController = StreamController.broadcast() { + _listenToWebSocket(); + } + + /// Establishes a connection to a live generation service. + /// + /// This function handles the WebSocket connection setup and returns an [LiveSession] + /// object that can be used to communicate with the service. + /// + /// Returns a [Future] that resolves to an [LiveSession] object upon successful + /// connection. + static Future connect({ + required String uri, + required Map headers, + required String modelString, + Content? systemInstruction, + List? tools, + SessionResumptionConfig? sessionResumption, + LiveGenerationConfig? liveGenerationConfig, + }) async { + final ws = await _performWebSocketSetup( + uri: uri, + headers: headers, + modelString: modelString, + systemInstruction: systemInstruction, + tools: tools, + sessionResumption: sessionResumption, + liveGenerationConfig: liveGenerationConfig, + ); + return LiveSession._( + ws, + uri: uri, + headers: headers, + modelString: modelString, + systemInstruction: systemInstruction, + tools: tools, + liveGenerationConfig: liveGenerationConfig, + ); + } + + // Persisted values for session resumption. + final String _uri; + final Map _headers; + final String _modelString; + final Content? _systemInstruction; + final List? _tools; + final LiveGenerationConfig? 
_liveGenerationConfig; + + WebSocketChannel _ws; + StreamController _messageController; + late StreamSubscription _wsSubscription; + + static Future _performWebSocketSetup({ + required String uri, + required Map headers, + required String modelString, + Content? systemInstruction, + List? tools, + SessionResumptionConfig? sessionResumption, + LiveGenerationConfig? liveGenerationConfig, + }) async { + final setupJson = { + 'setup': { + 'model': modelString, + if (systemInstruction != null) + 'system_instruction': systemInstruction.toJson(), + if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), + if (sessionResumption != null) + 'session_resumption': sessionResumption.toJson(), + if (liveGenerationConfig != null) ...{ + 'generation_config': liveGenerationConfig.toJson(), + if (liveGenerationConfig.inputAudioTranscription != null) + 'input_audio_transcription': + liveGenerationConfig.inputAudioTranscription!.toJson(), + if (liveGenerationConfig.outputAudioTranscription != null) + 'output_audio_transcription': + liveGenerationConfig.outputAudioTranscription!.toJson(), + if (liveGenerationConfig.contextWindowCompression + case final contextWindowCompression?) + 'contextWindowCompression': contextWindowCompression.toJson() + }, + } + }; + + final request = jsonEncode(setupJson); + final ws = kIsWeb + ? WebSocketChannel.connect(Uri.parse(uri)) + : IOWebSocketChannel.connect(Uri.parse(uri), headers: headers); + await ws.ready; + + ws.sink.add(request); + return ws; + } + + void _listenToWebSocket() { _wsSubscription = _ws.stream.listen( (message) { try { @@ -44,9 +155,32 @@ class LiveSession { onDone: _messageController.close, ); } - final WebSocketChannel _ws; - final _messageController = StreamController.broadcast(); - late StreamSubscription _wsSubscription; + + /// Resumes an existing live session with the server. 
+ /// + /// This closes the current WebSocket connection and establishes a new one using + /// the same configuration (URI, headers, model, system instruction, tools, etc.) + /// as the original session. + /// + /// [sessionResumption] (optional): The configuration for session resumption, + /// such as the handle to the previous session state to restore. + Future resumeSession( + {SessionResumptionConfig? sessionResumption}) async { + await _wsSubscription.cancel(); + await _ws.sink.close(); + + _ws = await _performWebSocketSetup( + uri: _uri, + headers: _headers, + modelString: _modelString, + systemInstruction: _systemInstruction, + tools: _tools, + sessionResumption: sessionResumption, + liveGenerationConfig: _liveGenerationConfig, + ); + + _listenToWebSocket(); + } /// Sends content to the server. /// @@ -107,6 +241,7 @@ class LiveSession { /// [text]: The text data to send. Future sendTextRealtime(String text) async { _checkWsStatus(); + log('sendTextRealtime: $text'); var clientMessage = LiveClientRealtimeInput.text(text); var clientJson = jsonEncode(clientMessage.toJson()); _ws.sink.add(clientJson); @@ -165,10 +300,8 @@ class LiveSession { _checkWsStatus(); await for (final result in _messageController.stream) { + log('live_session.received result, ${result.message.runtimeType}'); yield result; - if (result case LiveServerContent(turnComplete: true)) { - break; // Exit the loop when the turn is complete - } } }