diff --git a/apps/llm/package.json b/apps/llm/package.json
index 8243a5eefe..acdc2379e7 100644
--- a/apps/llm/package.json
+++ b/apps/llm/package.json
@@ -28,7 +28,7 @@
"metro-config": "^0.83.0",
"react": "19.2.5",
"react-native": "0.83.4",
- "react-native-audio-api": "0.12.0",
+ "react-native-audio-api": "0.12.2",
"react-native-device-info": "^15.0.2",
"react-native-executorch": "workspace:*",
"react-native-executorch-expo-resource-fetcher": "workspace:*",
diff --git a/apps/speech/App.tsx b/apps/speech/App.tsx
index d336319805..3532fbce1a 100644
--- a/apps/speech/App.tsx
+++ b/apps/speech/App.tsx
@@ -2,6 +2,7 @@ import React, { useState } from 'react';
import { View, Text, StyleSheet, TouchableOpacity } from 'react-native';
import { TextToSpeechScreen } from './screens/TextToSpeechScreen';
import { SpeechToTextScreen } from './screens/SpeechToTextScreen';
+import { VoiceActivityDetectionScreen } from './screens/VoiceActivityDetectionScreen';
import ColorPalette from './colors';
import ExecutorchLogo from './assets/executorch.svg';
import { Quiz } from './screens/Quiz';
@@ -15,7 +16,12 @@ initExecutorch({
export default function App() {
const [currentScreen, setCurrentScreen] = useState<
- 'menu' | 'speech-to-text' | 'text-to-speech' | 'quiz' | 'text-to-speech-llm'
+ | 'menu'
+ | 'speech-to-text'
+ | 'text-to-speech'
+ | 'quiz'
+ | 'text-to-speech-llm'
+ | 'vad'
>('menu');
const goToMenu = () => setCurrentScreen('menu');
@@ -28,6 +34,10 @@ export default function App() {
return ;
}
+ if (currentScreen === 'vad') {
+ return ;
+ }
+
if (currentScreen === 'quiz') {
return ;
}
@@ -47,6 +57,12 @@ export default function App() {
>
Speech to Text
+ setCurrentScreen('vad')}
+ >
+ Voice Activity Detection
+
setCurrentScreen('text-to-speech')}
diff --git a/apps/speech/screens/SpeechToTextScreen.tsx b/apps/speech/screens/SpeechToTextScreen.tsx
index 5c6e2228fc..5f88f3764a 100644
--- a/apps/speech/screens/SpeechToTextScreen.tsx
+++ b/apps/speech/screens/SpeechToTextScreen.tsx
@@ -9,6 +9,7 @@ import {
KeyboardAvoidingView,
Platform,
Switch,
+ Keyboard,
} from 'react-native';
import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
import {
@@ -19,6 +20,7 @@ import {
} from 'react-native-executorch';
import { ModelPicker, ModelOption } from '../components/ModelPicker';
const speechToText = models.speech_to_text;
+const vad = models.vad;
type STTModelSources = SpeechToTextProps['model'];
@@ -51,6 +53,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
const model = useSpeechToText({
model: selectedModel,
+ vad: vad.fsmn_vad(),
});
const [transcription, setTranscription] =
@@ -65,6 +68,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
} | null>(null);
const [enableTimestamps, setEnableTimestamps] = useState(false);
+ const [useVAD, setUseVAD] = useState(true);
const [error, setError] = useState(null);
const [audioURL, setAudioURL] = useState('');
const [hasMicPermission, setHasMicPermission] = useState(false);
@@ -104,11 +108,15 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
}
const handleTranscribeFromURL = async () => {
- if (!audioURL.trim()) {
- console.warn('Please provide a valid audio file URL');
+ if (!audioURL.trim() || model.isGenerating) {
+ if (!audioURL.trim()) {
+ console.warn('Please provide a valid audio file URL');
+ }
return;
}
+ Keyboard.dismiss();
+
// Reset previous states
setTranscription(null);
setLiveResult(null);
@@ -131,8 +139,10 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
};
const handleStartTranscribeFromMicrophone = async () => {
- if (!hasMicPermission) {
- setError('Microphone permission denied. Please enable it in Settings.');
+ if (!hasMicPermission || model.isGenerating || liveTranscribing) {
+ if (!hasMicPermission) {
+ setError('Microphone permission denied. Please enable it in Settings.');
+ }
return;
}
@@ -177,7 +187,9 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
try {
const streamIter = model.stream({
verbose: enableTimestamps,
- timeout: 100,
+ timeout: 200,
+ useVAD: useVAD,
+ vadDetectionMargin: 1200,
});
for await (const { committed, nonCommitted } of streamIter) {
@@ -352,22 +364,64 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
Stop Live Transcription
) : (
-
-
-
- {isSimulator
- ? 'Recording is not available on Simulator'
- : 'Start Live Transcription'}
-
-
+
+
+
+
+ {isSimulator ? 'No Mic' : 'Start Live'}
+
+
+
+ setUseVAD(!useVAD)}
+ activeOpacity={0.7}
+ accessibilityRole="switch"
+ accessibilityState={{ checked: useVAD }}
+ accessibilityLabel={`Voice Activity Detection ${useVAD ? 'on' : 'off'}`}
+ style={[
+ styles.vadButton,
+ useVAD ? styles.vadActive : styles.vadInactive,
+ recordingButtonDisabled && styles.disabled,
+ ]}
+ >
+
+
+
+ VAD
+
+
+ {useVAD ? 'ON' : 'OFF'}
+
+
+
+
)}
@@ -492,6 +546,54 @@ const styles = StyleSheet.create({
backgroundBlue: {
backgroundColor: '#0f186e',
},
+ buttonRow: {
+ flexDirection: 'row',
+ gap: 8,
+ marginTop: 12,
+ },
+ flex1: {
+ flex: 1,
+ marginTop: 0,
+ },
+ vadButton: {
+ flexDirection: 'row',
+ alignItems: 'center',
+ justifyContent: 'center',
+ paddingHorizontal: 14,
+ borderRadius: 12,
+ gap: 10,
+ },
+ vadActive: {
+ backgroundColor: '#0f186e',
+ },
+ vadInactive: {
+ backgroundColor: '#f1f5f9',
+ },
+ vadTextContainer: {
+ alignItems: 'flex-start',
+ },
+ vadButtonLabel: {
+ fontWeight: '800',
+ fontSize: 13,
+ letterSpacing: 0.5,
+ },
+ vadButtonLabelActive: {
+ color: 'white',
+ },
+ vadButtonLabelInactive: {
+ color: '#64748b',
+ },
+ vadButtonState: {
+ fontWeight: '700',
+ fontSize: 10,
+ letterSpacing: 1,
+ },
+ vadButtonStateActive: {
+ color: '#bbf7d0',
+ },
+ vadButtonStateInactive: {
+ color: '#94a3b8',
+ },
disabled: {
opacity: 0.5,
},
diff --git a/apps/speech/screens/TextToSpeechScreen.tsx b/apps/speech/screens/TextToSpeechScreen.tsx
index 198c40faf8..919076dc35 100644
--- a/apps/speech/screens/TextToSpeechScreen.tsx
+++ b/apps/speech/screens/TextToSpeechScreen.tsx
@@ -7,6 +7,7 @@ import {
TextInput,
KeyboardAvoidingView,
Platform,
+ Keyboard,
} from 'react-native';
import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
import {
@@ -124,6 +125,7 @@ export const TextToSpeechScreen = ({ onBack }: { onBack: () => void }) => {
return;
}
+ Keyboard.dismiss();
setIsPlaying(true);
try {
diff --git a/apps/speech/screens/VoiceActivityDetectionScreen.tsx b/apps/speech/screens/VoiceActivityDetectionScreen.tsx
new file mode 100644
index 0000000000..724ea52500
--- /dev/null
+++ b/apps/speech/screens/VoiceActivityDetectionScreen.tsx
@@ -0,0 +1,329 @@
+import React, { useEffect, useRef, useState } from 'react';
+import {
+ Text,
+ View,
+ StyleSheet,
+ TouchableOpacity,
+ ScrollView,
+ Platform,
+} from 'react-native';
+import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
+import {
+ models,
+ useVAD
+} from 'react-native-executorch';
+import FontAwesome from '@expo/vector-icons/FontAwesome';
+import { AudioManager, AudioRecorder } from 'react-native-audio-api';
+import SWMIcon from '../assets/swm_icon.svg';
+import DeviceInfo from 'react-native-device-info';
+import ErrorBanner from '../components/ErrorBanner';
+
+const isSimulator = DeviceInfo.isEmulatorSync();
+
+export const VoiceActivityDetectionScreen = ({
+ onBack,
+}: {
+ onBack: () => void;
+}) => {
+ const model = useVAD({
+ model: models.vad.fsmn_vad(),
+ });
+
+ const [isSpeaking, setIsSpeaking] = useState(false);
+ const [error, setError] = useState(null);
+ const [hasMicPermission, setHasMicPermission] = useState(false);
+ const [isStreaming, setIsStreaming] = useState(false);
+
+ const recorder = useRef(new AudioRecorder());
+ const logScrollRef = useRef(null);
+ const [logs, setLogs] = useState([]);
+
+ const addLog = (msg: string) => {
+ setLogs((prev) => [...prev, `${new Date().toLocaleTimeString()}: ${msg}`]);
+ };
+
+ useEffect(() => {
+ AudioManager.setAudioSessionOptions({
+ iosCategory: 'playAndRecord',
+ iosMode: 'spokenAudio',
+ iosOptions: ['allowBluetoothHFP', 'defaultToSpeaker'],
+ });
+ const checkPerms = async () => {
+ const status = await AudioManager.requestRecordingPermissions();
+ setHasMicPermission(status === 'Granted');
+ };
+ checkPerms();
+ }, []);
+
+ const handleStartStreaming = async () => {
+ if (isStreaming || model.isGenerating || !model.isReady) {
+ return;
+ }
+
+ setIsStreaming(true);
+ if (!hasMicPermission) {
+ setError('Microphone permission denied. Please enable it in Settings.');
+ setIsStreaming(false);
+ return;
+ }
+
+ setLogs([]);
+ addLog('Starting VAD stream...');
+
+ const sampleRate = 16000;
+
+ recorder.current.onAudioReady(
+ {
+ sampleRate,
+ bufferLength: 0.1 * sampleRate,
+ channelCount: 1,
+ },
+ ({ buffer }) => {
+ model.streamInsert(buffer.getChannelData(0));
+ }
+ );
+
+ try {
+ const success = await AudioManager.setAudioSessionActivity(true);
+ if (!success) {
+ setError('Cannot start audio session correctly');
+ }
+ const result = recorder.current.start();
+ if (result.status === 'error') {
+ setError(`Recording problems: ${result.status}`);
+ }
+
+ await model.stream({
+ onSpeechBegin: () => {
+ setIsSpeaking(true);
+ addLog('Speech detected (Begin)');
+ },
+ onSpeechEnd: () => {
+ setIsSpeaking(false);
+ addLog('Silence detected (End)');
+ },
+ options: {
+ timeout: 100,
+ detectionMargin: 300,
+ },
+ });
+ } catch (e) {
+ setError(e instanceof Error ? e.message : String(e));
+ setIsStreaming(false);
+ }
+ };
+
+ const handleStopStreaming = () => {
+ recorder.current.stop();
+ model.streamStop();
+ setIsStreaming(false);
+ setIsSpeaking(false);
+ addLog('VAD stream stopped');
+ };
+
+ const getModelStatus = () => {
+ if (isStreaming || model.isGenerating) return 'Processing...';
+ if (model.isReady) return 'Ready';
+ return `Loading model: ${(100 * model.downloadProgress).toFixed(2)}%`;
+ };
+
+ useEffect(() => {
+ if (model.error) setError(String(model.error));
+ }, [model.error]);
+
+ const readyToStream = model.isReady;
+ const recordingButtonDisabled =
+ isSimulator || !readyToStream || model.isGenerating;
+
+ return (
+
+
+
+
+
+
+
+ React Native ExecuTorch
+ Voice Activity Detection
+
+
+
+ Status: {getModelStatus()}
+
+
+ setError(null)} />
+
+
+
+
+ {isSpeaking ? 'SPEAKING' : 'SILENT'}
+
+
+
+
+ VAD Events
+
+ logScrollRef.current?.scrollToEnd({ animated: true })
+ }
+ >
+ {logs.length > 0 ? (
+ logs.map((log, i) => (
+
+ {log}
+
+ ))
+ ) : (
+
+ No events logged yet...
+
+ )}
+
+
+
+
+ {isStreaming ? (
+
+
+ Stop VAD Stream
+
+ ) : (
+
+
+
+ {isSimulator
+ ? 'Recording not available on Simulator'
+ : 'Start VAD Stream'}
+
+
+ )}
+
+
+
+ );
+};
+
+const styles = StyleSheet.create({
+ container: {
+ flex: 1,
+ alignItems: 'center',
+ backgroundColor: 'white',
+ paddingHorizontal: 16,
+ },
+ header: {
+ alignItems: 'center',
+ position: 'relative',
+ width: '100%',
+ },
+ backButton: {
+ position: 'absolute',
+ left: 0,
+ top: 10,
+ padding: 10,
+ zIndex: 1,
+ },
+ headerText: {
+ fontSize: 22,
+ fontWeight: 'bold',
+ color: '#0f186e',
+ },
+ statusContainer: {
+ marginTop: 12,
+ alignItems: 'center',
+ },
+ visualizerContainer: {
+ flex: 1,
+ justifyContent: 'center',
+ alignItems: 'center',
+ },
+ visualizerText: {
+ marginTop: 20,
+ fontSize: 24,
+ fontWeight: '800',
+ letterSpacing: 2,
+ },
+ speakingText: {
+ color: '#22c55e',
+ },
+ silentText: {
+ color: '#ef4444',
+ },
+ logContainer: {
+ height: 150,
+ width: '100%',
+ marginVertical: 12,
+ },
+ logLabel: {
+ marginLeft: 12,
+ marginBottom: 4,
+ color: '#0f186e',
+ fontWeight: '600',
+ },
+ logScrollContainer: {
+ borderRadius: 12,
+ borderWidth: 1,
+ borderColor: '#0f186e',
+ padding: 12,
+ backgroundColor: '#f8fafc',
+ },
+ logText: {
+ fontSize: 12,
+ fontFamily: Platform.OS === 'ios' ? 'Menlo' : 'monospace',
+ color: '#334155',
+ marginBottom: 2,
+ },
+ placeholderText: {
+ color: '#aaa',
+ fontStyle: 'italic',
+ },
+ inputContainer: {
+ marginBottom: 30,
+ width: '100%',
+ },
+ liveButton: {
+ flexDirection: 'row',
+ justifyContent: 'center',
+ alignItems: 'center',
+ padding: 16,
+ borderRadius: 12,
+ gap: 8,
+ },
+ backgroundRed: {
+ backgroundColor: '#ef4444',
+ },
+ backgroundBlue: {
+ backgroundColor: '#0f186e',
+ },
+ buttonText: {
+ color: 'white',
+ fontWeight: '600',
+ fontSize: 16,
+ },
+ disabled: {
+ opacity: 0.5,
+ },
+});
diff --git a/docs/docs/03-hooks/01-natural-language-processing/useSpeechToText.md b/docs/docs/03-hooks/01-natural-language-processing/useSpeechToText.md
index 80141db45f..02d0008dda 100644
--- a/docs/docs/03-hooks/01-natural-language-processing/useSpeechToText.md
+++ b/docs/docs/03-hooks/01-natural-language-processing/useSpeechToText.md
@@ -49,7 +49,7 @@ import { AudioContext } from 'react-native-audio-api';
import * as FileSystem from 'expo-file-system';
const model = useSpeechToText({
- model: models.speech_to_text.whisper_tiny_en(),
+ model: models.speech_to_text.whisper_tiny_en(), // Use whisper_tiny_en for English or whisper_tiny for multilingual support
});
// 1. Get audio file
@@ -89,7 +89,13 @@ The `stream()` function accepts several optional parameters:
- `language`: The language code (e.g., `'es'`, `'fr'`). Required for multilingual models.
- `verbose`: If `true`, includes word-level timestamps and segment metadata in the result objects.
+- `useVAD`: Enable the Voice Activity Detection submodule (if configured in `useSpeechToText` props) to optimize performance by filtering silence. Defaults to `false`.
- `timeout`: (Advanced) The interval (in milliseconds) between processing consecutive audio chunks in streaming mode. Lower values provide more frequent updates and lower latency, while higher values reduce CPU consumption. Defaults to `100`.
+- `vadDetectionMargin`: (Advanced) The duration of silence (in milliseconds) required after speech is detected before "committing" a segment. Defaults to `500`. Only active when VAD module is used.
+
+### Voice Activity Detection (VAD)
+
+Integrating a VAD submodule is highly recommended for streaming. It improves performance by automatically removing silence, which reduces CPU usage, saves battery, and prevents the model from "hallucinating" text during silent periods.
### Example
@@ -102,6 +108,7 @@ import { AudioManager, AudioRecorder } from 'react-native-audio-api';
export default function LiveTranscriber() {
const model = useSpeechToText({
model: models.speech_to_text.whisper_tiny_en(),
+ vad: models.vad.fsmn_vad(),
});
const [text, setText] = useState('');
const isRecordingRef = useRef(false);
@@ -111,7 +118,7 @@ export default function LiveTranscriber() {
isRecordingRef.current = true;
setText('');
- // 1. Capture microphone input
+ // 2. Capture microphone input
recorder.onAudioReady(
{ sampleRate: 16000, bufferLength: 1600, channelCount: 1 },
(chunk) => model.streamInsert(chunk.buffer.getChannelData(0))
@@ -119,10 +126,14 @@ export default function LiveTranscriber() {
await recorder.start();
- // 2. Process the stream
+ // 3. Process the stream with VAD enabled
try {
let finalizedText = '';
- const streamIter = model.stream({ verbose: false });
+ const streamIter = model.stream({
+ verbose: false,
+ useVAD: true, // Enable VAD filter
+ vadDetectionMargin: 500, // Wait for 500ms of silence before committing
+ });
for await (const { committed, nonCommitted } of streamIter) {
if (!isRecordingRef.current) break;
@@ -163,6 +174,9 @@ To transcribe languages other than English, use a multilingual model (e.g., `mod
```typescript
// Transcribe in Spanish
+const model = useSpeechToText({
+ model: models.speech_to_text.whisper_tiny(),
+});
const result = await model.transcribe(spanishAudio, { language: 'es' });
```
diff --git a/docs/docs/03-hooks/01-natural-language-processing/useVAD.md b/docs/docs/03-hooks/01-natural-language-processing/useVAD.md
index 01cf8cd4e3..f05d53c10f 100644
--- a/docs/docs/03-hooks/01-natural-language-processing/useVAD.md
+++ b/docs/docs/03-hooks/01-natural-language-processing/useVAD.md
@@ -13,145 +13,79 @@ It is recommended to use models provided by us, which are available at our [Hugg
- For detailed API Reference for `useVAD` see: [`useVAD` API Reference](../../06-api-reference/functions/useVAD.md).
- For all VAD models available out-of-the-box in React Native ExecuTorch see: [VAD Models](../../06-api-reference/index.md#models---voice-activity-detection).
-## High Level Overview
+## Static Audio (Batch) processing
-You can obtain waveform from audio in any way most suitable to you, however in the snippet below we utilize [`react-native-audio-api`](https://docs.swmansion.com/react-native-audio-api/) library to process a `.mp3` file.
+This mode is best suited for processing pre-recorded audio files or existing buffers. You provide a full waveform to the `forward` method, which returns an array of detected speech segments.
```typescript
-import { models, useVAD } from 'react-native-executorch';
-import { AudioContext } from 'react-native-audio-api';
-import * as FileSystem from 'expo-file-system';
+import { useVAD, models } from 'react-native-executorch';
-const model = useVAD({
- model: models.vad.fsmn_vad(),
-});
+const model = useVAD({ model: models.vad.fsmn_vad() });
-const { uri } = await FileSystem.downloadAsync(
- 'https://some-audio-url.com/file.mp3',
- FileSystem.cacheDirectory + 'audio_file'
-);
-
-const audioContext = new AudioContext({ sampleRate: 16000 });
-const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
-const audioBuffer = decodedAudioData.getChannelData(0);
+// ... obtain audioBuffer (Float32Array) at 16kHz ...
try {
- // NOTE: to obtain segments in seconds, you need to divide
- // start / end of the segment by the sampling rate (16k)
-
const speechSegments = await model.forward(audioBuffer);
- console.log(speechSegments);
+ console.log('Speech detected at:', speechSegments);
} catch (error) {
- console.error('Error during running VAD model', error);
+ console.error('VAD Error:', error);
}
```
-### Arguments
-
-`useVAD` takes [`VADProps`](../../06-api-reference/interfaces/VADProps.md) that consists of:
+:::note
+Timestamps in `Segment[]` correspond to the indices of the input array. Divide them by your sampling rate (usually 16000) to get results in seconds.
+:::
-- `model` containing [`modelSource`](../../06-api-reference/interfaces/VADProps.md#modelsource).
-- An optional flag [`preventLoad`](../../06-api-reference/interfaces/VADProps.md#preventload) which prevents auto-loading of the model.
+## Live Streaming (Real-time detection)
-You need more details? Check the following resources:
+Live streaming allows you to process audio in real-time as it arrives from a microphone or network stream. It uses an internal state to track speech transitions across chunks.
-- For detailed information about `useVAD` arguments check this section: [`useVAD` arguments](../../06-api-reference/functions/useVAD.md#parameters).
-- For all VAD models available out-of-the-box in React Native ExecuTorch see: [VAD Models](../../06-api-reference/index.md#models---voice-activity-detection).
-- For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
+### How it works
-### Returns
+1. **Start the session**: Call `model.stream()` with callbacks for speech events. This returns a promise that stays active until the stream is stopped.
+2. **Feed audio**: Periodically push audio chunks using `model.streamInsert()`.
+3. **Handle events**: Use `onSpeechBegin` and `onSpeechEnd` callbacks to trigger UI updates or toggle recording for other tasks (like STT).
+4. **End the session**: Call `model.streamStop()` to clean up.
-`useVAD` returns an object called `VADType` containing bunch of functions to interact with VAD models. To get more details please read: [`VADType` API Reference](../../06-api-reference/interfaces/VADType.md).
+### Configuration Options
-## Running the model
+You can fine-tune the streaming behavior via the `options` object:
-Before running the model's [`forward`](../../06-api-reference/interfaces/VADType.md#forward) method, make sure to extract the audio waveform you want to process. You'll need to handle this step yourself, ensuring the audio is sampled at 16 kHz. Once you have the waveform, pass it as an argument to the [`forward`](../../06-api-reference/interfaces/VADType.md#forward) method. The method returns a promise that resolves to the array of detected speech [`Segment[]`](../../06-api-reference/interfaces/Segment.md).
-
-:::note
-Timestamps in returned speech segments, correspond to indices of input array (waveform).
-:::
-
-## Example
+- **`timeout`** (default: `100`ms): Specifies the interval between consecutive VAD inferences. A lower value makes the detection more responsive but increases CPU usage.
+- **`detectionMargin`** (default: `100`ms): Specifies the maximum allowed gap between the last detected speech segment and the current time to still consider the speech as "ongoing." This value determines how much silence is tolerated before `onSpeechEnd` is triggered.
```tsx
-import React from 'react';
-import { Button, Text, SafeAreaView } from 'react-native';
-import { models, useVAD } from 'react-native-executorch';
-import { AudioContext } from 'react-native-audio-api';
-import * as FileSystem from 'expo-file-system';
-
-export default function App() {
- const model = useVAD({
- model: models.vad.fsmn_vad(),
+import { useVAD, models } from 'react-native-executorch';
+
+const model = useVAD({ model: models.vad.fsmn_vad() });
+
+const startLiveVAD = async () => {
+ // Start the continuous streaming listener
+ model.stream({
+ onSpeechBegin: () => console.log('User started speaking'),
+ onSpeechEnd: () => console.log('User stopped speaking'),
+ options: {
+ timeout: 100, // Checks every 100ms
+ detectionMargin: 500, // 500ms of silence before ending speech
+ },
});
- const audioURL = 'https://some-audio-url.com/file.mp3';
-
- const handleAudio = async () => {
- if (!model) {
- console.error('VAD model is not loaded yet.');
- return;
- }
-
- console.log('Processing URL:', audioURL);
-
- try {
- const { uri } = await FileSystem.downloadAsync(
- audioURL,
- FileSystem.cacheDirectory + 'vad_example.tmp'
- );
-
- const audioContext = new AudioContext({ sampleRate: 16000 });
- const originalDecodedBuffer =
- await audioContext.decodeAudioDataSource(uri);
- const originalChannelData = originalDecodedBuffer.getChannelData(0);
-
- const segments = await model.forward(originalChannelData);
- if (segments.length === 0) {
- console.log('No speech segments were found.');
- return;
- }
- console.log(`Found ${segments.length} speech segments.`);
-
- const totalLength = segments.reduce(
- (sum, seg) => sum + (seg.end - seg.start),
- 0
- );
- const newAudioBuffer = audioContext.createBuffer(
- 1, // Mono
- totalLength,
- originalDecodedBuffer.sampleRate
- );
- const newChannelData = newAudioBuffer.getChannelData(0);
-
- let offset = 0;
- for (const segment of segments) {
- const slice = originalChannelData.subarray(segment.start, segment.end);
- newChannelData.set(slice, offset);
- offset += slice.length;
- }
-
- // Play the processed audio
- const source = audioContext.createBufferSource();
- source.buffer = newAudioBuffer;
- source.connect(audioContext.destination);
- source.start();
- } catch (error) {
- console.error('Error processing audio data:', error);
- }
- };
-
- return (
-
-
- Press the button to process and play speech from a sample file.
-
-
-
- );
-}
+ // Example: Hook into your audio recorder's data event
+ audioRecorder.on('data', (chunk: Float32Array) => {
+ model.streamInsert(chunk);
+ });
+};
+
+const stopLiveVAD = () => {
+ model.streamStop();
+};
```
+### Arguments & Returns
+
+- **Arguments**: `useVAD` takes a [`VADProps`](../../06-api-reference/interfaces/VADProps.md) object containing the `model` and an optional `preventLoad` flag.
+- **Returns**: A [`VADType`](../../06-api-reference/interfaces/VADType.md) object providing `forward`, `stream`, `streamInsert`, and `streamStop` methods, along with `isReady` and `error` states.
+
## Supported models
- [fsmn-vad](https://huggingface.co/collections/software-mansion/voice-activity-detection)
diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md
index 3b38e56984..2e2597397d 100644
--- a/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md
+++ b/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md
@@ -14,11 +14,12 @@ The `SpeechToTextModule` class provides a direct interface to the library's spee
You can transcribe audio in two ways: **one-shot** (for files/short clips) and **streaming** (for live microphone input).
```typescript
-import { models, SpeechToTextModule } from 'react-native-executorch';
+import { SpeechToTextModule, models } from 'react-native-executorch';
-// Initialize the model
+// Initialize the model with VAD submodule
const model = await SpeechToTextModule.fromModelName(
models.speech_to_text.whisper_tiny_en(),
+ models.vad.fsmn_vad(), // Optional VAD submodule
(progress) => {
console.log(`Loading: ${progress * 100}%`);
}
@@ -30,7 +31,7 @@ console.log(result.text);
// 2. Live streaming (yields partial/stable results)
model.streamInsert(audioChunk);
-const stream = model.stream();
+const stream = model.stream({ useVAD: true }); // Enable VAD logic
for await (const { committed, nonCommitted } of stream) {
// Update UI live with stable and partial text
}
@@ -61,42 +62,20 @@ The `transcribe()` function accepts an optional configuration object:
- `language`: The language code (e.g., `'es'`, `'fr'`). Required for multilingual models.
- `verbose`: If `true`, the method returns a detailed `TranscriptionResult` object following the OpenAI Whisper `verbose_json` format (including segments and word-level timestamps).
-### Multilingual transcription
-
-If you aim to obtain a transcription in languages other than English, use a multilingual Whisper model. To get the output in your desired language, pass the [`DecodingOptions`](../../06-api-reference/interfaces/DecodingOptions.md) object with the [`language`](../../06-api-reference/interfaces/DecodingOptions.md#language) field set to the target language code.
-
-```typescript
-const transcription = await model.transcribe(spanishAudio, { language: 'es' });
-```
-
-### Timestamps & Detailed Results
-
-Set `verbose: true` in the options to obtain word-level timestamps and other parameters. The result mimics the _verbose_json_ format from OpenAI Whisper API.
-
-```typescript
-const result = await model.transcribe(audioBuffer, { verbose: true });
-// Example result:
-// {
-// text: "Example text...",
-// segments: [
-// { start: 0, end: 5.4, text: "Example text", words: [...] }
-// ],
-// language: "en"
-// }
-```
-
## Live Streaming Transcription
The **Streaming API** is optimized for live microphone input or real-time audio feeds. It handles audio inputs of arbitrary length by automatically managing context windows to bypass the standard 30-second limit.
-:::iStreaming Options
+### Streaming Options
+
The `stream()` function accepts several optional parameters:
- `language`: The language code (e.g., `'es'`, `'fr'`). Required for multilingual models.
- `verbose`: If `true`, includes word-level timestamps and segment metadata in the result objects.
- `timeout`: (Advanced) The interval (in milliseconds) between processing consecutive audio chunks. Lower values provide more frequent updates, while higher values reduce CPU consumption. Defaults to `100`.
+- `useVAD`: Enable the Voice Activity Detection submodule (if configured during load) to improve performance by skipping silence.
-### nfo
+:::info
- **`committed`**: Finalized transcription that is stable and will not change. Useful for building a persistent transcript record.
- **`nonCommitted`**: Partial transcription that is still being processed and may update as more context arrives. Useful for live UI updates.
@@ -107,11 +86,12 @@ The `stream()` function accepts several optional parameters:
In this example, we use [`react-native-audio-api`](https://docs.swmansion.com/react-native-audio-api/) to feed live audio into the model.
```tsx
-import { models, SpeechToTextModule } from 'react-native-executorch';
+import { SpeechToTextModule, models } from 'react-native-executorch';
import { AudioManager, AudioRecorder } from 'react-native-audio-api';
const model = await SpeechToTextModule.fromModelName(
- models.speech_to_text.whisper_tiny_en()
+ models.speech_to_text.whisper_tiny_en(),
+ models.vad.fsmn_vad()
);
// 1. Configure audio session & permissions
@@ -138,7 +118,7 @@ await recorder.start();
// 3. Process the Stream
try {
let stableTranscript = '';
- const streamIter = model.stream({ verbose: false });
+ const streamIter = model.stream({ useVAD: true });
for await (const { committed, nonCommitted } of streamIter) {
if (committed.text) stableTranscript += committed.text;
@@ -155,6 +135,40 @@ model.streamStop();
recorder.stop();
```
+## Advanced Features
+
+### VAD Integration (Recommended for Live)
+
+Integrating **Voice Activity Detection (VAD)** as a submodule improves streaming performance by automatically removing silence. This reduces CPU usage, saves battery, and prevents hallucinations during silent periods.
+
+To use it, provide the VAD model configuration when loading the module and enable `useVAD` in the stream options:
+
+```typescript
+const model = await SpeechToTextModule.fromModelName(
+ models.speech_to_text.whisper_tiny_en(),
+ models.vad.fsmn_vad()
+);
+
+// Enable VAD logic in the stream context
+const stream = model.stream({ useVAD: true });
+```
+
+### Multilingual Transcription
+
+If you aim to obtain a transcription in languages other than English, use a multilingual Whisper model. To get the output in your desired language, pass the [`DecodingOptions`](../../06-api-reference/interfaces/DecodingOptions.md) object with the [`language`](../../06-api-reference/interfaces/DecodingOptions.md#language) field set to the target language code.
+
+```typescript
+const transcription = await model.transcribe(spanishAudio, { language: 'es' });
+```
+
+### Timestamps & Detailed Results
+
+Set `verbose: true` in the options to obtain word-level timestamps and other parameters. The result mimics the _verbose_json_ format from OpenAI Whisper API.
+
+```typescript
+const result = await model.transcribe(audioBuffer, { verbose: true });
+```
+
## Supported models
| Model | Language |
diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md
index e9dd56710e..c9113b561b 100644
--- a/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md
+++ b/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md
@@ -22,7 +22,7 @@ await model.forward(waveform);
### Methods
-All methods of `VADModule` are explained in details here: [`VADModule` API Reference](../../06-api-reference/classes/VADModule.md)
+All methods of `VADModule` are explained in detail here: [`VADModule` API Reference](../../06-api-reference/classes/VADModule.md)
## Loading the model
@@ -36,11 +36,23 @@ To create a ready-to-use instance, call the static [`fromModelName`](../../06-ap
The factory returns a promise that resolves to a loaded `VADModule` instance.
-For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/VADModule.md#forward) method on the module object. Before running the model's [`forward`](../../06-api-reference/classes/VADModule.md#forward) method, make sure to extract the audio waveform you want to process. You'll need to handle this step yourself, ensuring the audio is sampled at 16 kHz. Once you have the waveform, pass it as an argument to the forward method. The method returns a promise that resolves to the array of detected speech segments.
+### File Processing
+
+To process a full audio buffer at once, use the [`forward`](../../06-api-reference/classes/VADModule.md#forward) method. Before calling [`forward`](../../06-api-reference/classes/VADModule.md#forward), ensure you have the audio waveform sampled at 16 kHz. Pass the waveform as an argument; the method returns a promise that resolves to an array of detected speech segments.
+
+### Live Streaming
+
+For real-time applications, `VADModule` supports a streaming mode that identifies speech events as audio arrives.
+
+1. **Initialize the stream**: Call [`stream`](../../06-api-reference/classes/VADModule.md#stream) with `onSpeechBegin` and `onSpeechEnd` callbacks.
+2. **Insert audio**: Use [`streamInsert`](../../06-api-reference/classes/VADModule.md#streaminsert) to push new audio chunks into the internal buffer.
+3. **Stop the stream**: Use [`streamStop`](../../06-api-reference/classes/VADModule.md#streamstop) to finish detection and release resources.
+
+Refer to the [`useVAD`](../../03-hooks/01-natural-language-processing/useVAD.md#live-streaming-real-time-detection) hook documentation for a detailed example of the streaming architecture.
## Managing memory
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index c50410a4f7..f94ef918ac 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -242,6 +242,12 @@ inline std::span getValue>(const jsi::Value &val,
return getTypedArrayAsSpan(val, runtime);
}
+template <>
+inline std::span
+getValue>(const jsi::Value &val, jsi::Runtime &runtime) {
+ return getTypedArrayAsSpan(val, runtime);
+}
+
template <>
inline std::span getValue>(const jsi::Value &val,
jsi::Runtime &runtime) {
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index 4b1367ce3a..cb8313598f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include
namespace rnexecutorch {
@@ -101,6 +102,21 @@ template class ModelHostObject : public JsiHostObject {
"streamStop"));
}
+ if constexpr (meta::SameAs) {
+ addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+ promiseHostFunction<&Model::stream>,
+ "stream"));
+
+ addFunctions(JSI_EXPORT_FUNCTION(
+ ModelHostObject, synchronousHostFunction<&Model::streamInsert>,
+ "streamInsert"));
+
+ addFunctions(JSI_EXPORT_FUNCTION(
+ ModelHostObject, synchronousHostFunction<&Model::streamStop>,
+ "streamStop"));
+ }
+
if constexpr (meta::SameAs) {
addFunctions(JSI_EXPORT_FUNCTION(
ModelHostObject, promiseHostFunction<&Model::getVocabSize>,
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
index 3acd076779..f3ec9f755f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
@@ -12,14 +12,20 @@ namespace rnexecutorch::models::speech_to_text {
SpeechToText::SpeechToText(const std::string &modelName,
const std::string &modelSource,
const std::string &tokenizerSource,
+ const std::string &vadSource,
std::shared_ptr callInvoker)
: callInvoker_(std::move(callInvoker)) {
// Switch between the ASR implementations based on model name
if (modelName == "whisper") {
+ if (!vadSource.empty()) {
+ vad_ = std::make_unique(vadSource, callInvoker_);
+ }
+
transcriber_ = std::make_unique(modelSource, tokenizerSource,
callInvoker_);
streamer_ = std::make_unique(
- static_cast(transcriber_.get()));
+ static_cast(transcriber_.get()),
+ static_cast(vad_.get()));
} else {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
@@ -116,12 +122,19 @@ TranscriptionResult wordsToResult(const std::vector &words,
void SpeechToText::stream(std::shared_ptr callback,
std::string languageOption, bool verbose,
- uint32_t timeout) {
+ uint32_t timeout, bool useVAD,
+ uint32_t vadDetectionMargin) {
if (isStreaming_) {
throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
"Streaming is already in progress!");
}
+ // VAD must be initialized for this option to work.
+ if (useVAD && !vad_) {
+ throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
+ "Attempting to use VAD but it's not initialized!");
+ }
+
auto nativeCallback =
[this, callback](const TranscriptionResult &committed,
const TranscriptionResult &nonCommitted, bool isDone) {
@@ -139,18 +152,24 @@ void SpeechToText::stream(std::shared_ptr callback,
};
isStreaming_ = true;
- DecodingOptions options(languageOption, verbose);
+ StreamingOptions options(languageOption, verbose, useVAD, vadDetectionMargin);
while (isStreaming_) {
if (readyToProcess_ && streamer_->isReady()) {
ProcessResult res = streamer_->process(options);
- TranscriptionResult committedRes =
- wordsToResult(res.committed, languageOption, verbose);
- TranscriptionResult nonCommittedRes =
- wordsToResult(res.nonCommitted, languageOption, verbose);
+ // This is not only about optimization, but also about ensuring
+ // correctness when VAD is used. Otherwise we run into the vanishing text
+ // issue.
+ if (!res.committed.empty() || !res.nonCommitted.empty()) {
+ TranscriptionResult committedRes =
+ wordsToResult(res.committed, languageOption, verbose);
+ TranscriptionResult nonCommittedRes =
+ wordsToResult(res.nonCommitted, languageOption, verbose);
+
+ nativeCallback(committedRes, nonCommittedRes, false);
+ }
- nativeCallback(committedRes, nonCommittedRes, false);
readyToProcess_ = false;
}
@@ -166,7 +185,11 @@ void SpeechToText::stream(std::shared_ptr callback,
[this] { return !isStreaming_.load(); });
}
- std::vector finalWords = streamer_->finish(options);
+ // Do not use VAD on final processing to flush and commit everything properly.
+ StreamingOptions finishOptions = options;
+ finishOptions.useVAD = false;
+
+ std::vector finalWords = streamer_->finish(finishOptions);
TranscriptionResult finalRes =
wordsToResult(finalWords, languageOption, verbose);
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
index adcfd8ae99..9c084dcf6e 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h
@@ -10,15 +10,18 @@
#include "common/schema/ASR.h"
#include "common/schema/OnlineASR.h"
#include "common/types/TranscriptionResult.h"
+#include
namespace rnexecutorch {
namespace models::speech_to_text {
+using voice_activity_detection::VoiceActivityDetection;
+
class SpeechToText {
public:
SpeechToText(const std::string &modelName, const std::string &modelSource,
- const std::string &tokenizerSource,
+ const std::string &tokenizerSource, const std::string &vadSource,
std::shared_ptr callInvoker);
// Required because of std::atomic usage
@@ -40,7 +43,8 @@ class SpeechToText {
// Stream
void stream(std::shared_ptr callback,
- std::string languageOption, bool verbose, uint32_t timeout);
+ std::string languageOption, bool verbose,
+ uint32_t timeout, bool useVAD, uint32_t vadDetectionMargin);
void streamStop();
void streamInsert(std::span waveform);
@@ -61,12 +65,15 @@ class SpeechToText {
// waiting for the next throttling interval to expire.
std::mutex streamCvMutex_;
std::condition_variable streamCv_;
+
+ // VAD submodule
+ std::unique_ptr vad_ = nullptr;
};
} // namespace models::speech_to_text
REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string,
- std::string, std::string,
+ std::string, std::string, std::string,
std::shared_ptr);
} // namespace rnexecutorch
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h
index db259c82fc..61ca8b6871 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h
@@ -4,7 +4,7 @@
#include
#include
-#include "../types/DecodingOptions.h"
+#include "../types/Options.h"
#include "../types/Segment.h"
#include
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h
index efe6cc2819..b2ba2bb62b 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h
@@ -3,7 +3,7 @@
#include
#include
-#include "../types/DecodingOptions.h"
+#include "../types/Options.h"
#include "../types/ProcessResult.h"
#include "../types/Word.h"
@@ -34,9 +34,9 @@ class OnlineASR {
virtual bool isReady() const = 0;
- virtual ProcessResult process(const DecodingOptions &options) = 0;
+ virtual ProcessResult process(const StreamingOptions &options) = 0;
- virtual std::vector finish(const DecodingOptions &options) = 0;
+ virtual std::vector finish(const StreamingOptions &options) = 0;
virtual void reset() = 0;
};
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h
deleted file mode 100644
index de08308260..0000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h
+++ /dev/null
@@ -1,17 +0,0 @@
-#pragma once
-
-#include
-#include
-
-namespace rnexecutorch::models::speech_to_text {
-
-struct DecodingOptions {
- explicit DecodingOptions(const std::string &language, bool verbose = false)
- : language(language.empty() ? std::nullopt : std::optional(language)),
- verbose(verbose) {}
-
- std::optional language;
- bool verbose;
-};
-
-} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h
new file mode 100644
index 0000000000..760106d3f9
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include
+#include
+
+namespace rnexecutorch::models::speech_to_text {
+
+struct DecodingOptions {
+ DecodingOptions(const std::string &language, bool verbose = false)
+ : language(language.empty() ? std::nullopt : std::optional(language)),
+ verbose(verbose) {}
+
+ std::optional language;
+ bool verbose;
+};
+
+struct StreamingOptions : public DecodingOptions {
+ StreamingOptions(const std::string &language, bool verbose = false,
+ bool useVAD = false, uint32_t vadDetectionMargin = 500)
+ : DecodingOptions(language, verbose), useVAD(useVAD),
+ vadDetectionMargin(vadDetectionMargin) {}
+
+ bool useVAD;
+ uint32_t vadDetectionMargin;
+};
+
+} // namespace rnexecutorch::models::speech_to_text
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp
index e663c5bfab..963567d8b5 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp
@@ -10,7 +10,8 @@
namespace rnexecutorch::models::speech_to_text::whisper::stream {
-OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) {
+OnlineASR::OnlineASR(const ASR *asr, const VoiceActivityDetection *vad)
+ : asr_(asr), vad_(vad) {
audioBuffer_.reserve((constants::kChunkSize + 1) * constants::kSamplingRate);
}
@@ -42,9 +43,11 @@ void OnlineASR::insertAudioChunk(std::span audio) {
}
}
-ProcessResult OnlineASR::process(const DecodingOptions &options) {
+ProcessResult OnlineASR::process(const StreamingOptions &options) {
constexpr size_t kStreamSafeBufferMaxSamples = static_cast(
params::kStreamSafeBufferDuration * constants::kSamplingRate);
+ constexpr size_t kSafetyMarginSamples = static_cast(
+ params::kStreamSafetyThreshold * constants::kSamplingRate);
std::vector audioCopy;
@@ -55,12 +58,57 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) {
audioCopy = audioBuffer_;
}
- // Obtain a transcription for current audio buffer state.
- // It's very unlikely that buffer will exceed whisper's maximum capacity, but
- // for absolute safety we can additionally clip the buffer.
- std::span input(
- audioCopy.begin(),
- audioCopy.begin() + std::min(constants::kMaxSamples, audioCopy.size()));
+ std::span input;
+
+ // Allowing VAD changes logic significantly - we no longer commit and clean
+ // at max samples reached moments, but rather at the end of speech moments.
+ if (options.useVAD && vad_) {
+ auto speechSegments = vad_->generate(audioCopy, options.vadDetectionMargin *
+ params::kVadGapFactor);
+
+ if (speechSegments.empty()) {
+ // Extra cleanup to speed-up future processing by removing silence.
+ if (audioCopy.size() > params::kVadDeadSamplesRemovalSamples) {
+ std::scoped_lock lock(streamingMutex);
+ size_t cut = std::min(params::kVadDeadSamplesRemovalSamples -
+ kSafetyMarginSamples,
+ audioBuffer_.size());
+ audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut);
+ }
+
+ return {};
+ }
+
+ const auto &lastSegment = speechSegments.back();
+ size_t marginSamples =
+ options.vadDetectionMargin * constants::kSamplesPerMilisecond;
+
+ if (audioCopy.size() - lastSegment.end <= marginSamples) {
+ // Speech is ongoing. Keep last 1s context and trim around current
+ // segment.
+ size_t startWithMargin =
+ std::max(lastSegment.start, constants::kSamplingRate) -
+ constants::kSamplingRate;
+ input = std::span(audioCopy.begin() + startWithMargin,
+ audioCopy.begin() + lastSegment.end);
+ } else {
+ // Speech ended beyond margin. Commit existing transcript and clear
+ // buffer.
+ std::scoped_lock lock(streamingMutex);
+ std::vector committed = std::move(memory_.transcript);
+ memory_.transcript.clear();
+ memory_.eos.clear();
+
+ audioBuffer_.erase(audioBuffer_.begin(),
+ audioBuffer_.begin() +
+ std::min(lastSegment.end, audioBuffer_.size()));
+ return {.committed = std::move(committed), .nonCommitted = {}};
+ }
+ } else {
+ input = std::span(audioCopy.begin(),
+ audioCopy.begin() +
+ std::min(constants::kMaxSamples, audioCopy.size()));
+ }
std::vector transcriptions = asr_->transcribe(input, options);
@@ -135,7 +183,7 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) {
return {.committed = std::move(committed), .nonCommitted = std::move(words)};
}
-std::vector OnlineASR::finish(const DecodingOptions &options) {
+std::vector OnlineASR::finish(const StreamingOptions &options) {
ProcessResult result = process(options);
// Last-tick committed delta + whatever never made it past the commit
@@ -179,8 +227,7 @@ std::vector OnlineASR::commitAndClean(std::vector &transcript) {
// recorded any speech. In this case we can safely cut the maximum amount of
// audio data.
if (memory_.eos.empty()) {
- size_t cut =
- bufferSize - params::kStreamSafetyThreshold * constants::kSamplingRate;
+ size_t cut = bufferSize - kSafetyMarginSamples;
audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut);
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h
index 7547d16bd5..f13fa9ccff 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h
@@ -1,6 +1,7 @@
#pragma once
#include
+#include
#include
#include
@@ -8,9 +9,12 @@
#include "../common/types/ProcessResult.h"
#include "../common/types/Word.h"
#include "ASR.h"
+#include
namespace rnexecutorch::models::speech_to_text::whisper::stream {
+using voice_activity_detection::VoiceActivityDetection;
+
/**
* Online Automatic Speech Recognition (OnlineASR) for Whisper.
* It manages continuous processing of audio stream by maintaining a local
@@ -18,7 +22,7 @@ namespace rnexecutorch::models::speech_to_text::whisper::stream {
*/
class OnlineASR : public schema::OnlineASR {
public:
- OnlineASR(const ASR *asr);
+ OnlineASR(const ASR *asr, const VoiceActivityDetection *vad);
/**
* Checks if the buffer contains enough audio for the next processing step.
@@ -37,13 +41,13 @@ class OnlineASR : public schema::OnlineASR {
* @param options Decoding options for the transcription.
* @return Transcription result containing committed and volatile tokens.
*/
- ProcessResult process(const DecodingOptions &options) override;
+ ProcessResult process(const StreamingOptions &options) override;
/**
* Finalizes the current stream and returns all words.
* @return Vector of detected words.
*/
- std::vector finish(const DecodingOptions &options) override;
+ std::vector finish(const StreamingOptions &options) override;
/**
* Resets the internal state and clears buffers.
@@ -57,6 +61,9 @@ class OnlineASR : public schema::OnlineASR {
// ASR module connection for transcribing the audio
const ASR *asr_;
+ // VAD module connection for selecting processing (optional)
+ const VoiceActivityDetection *vad_;
+
// Audio buffer (input) - accumulates obtained audio samples.
std::vector audioBuffer_ = {};
mutable std::mutex streamingMutex; // Covers both buffer & memory
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h
index 847a22b1e0..f6c6e491b0 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h
@@ -32,7 +32,6 @@ constexpr inline float kStreamSafetyThreshold = 3.F; // [s]
*/
constexpr inline float kStreamSafeBufferDuration =
kStreamMaxDuration - kStreamSafetyThreshold; // [s]
-
/**
* An estimate of the number of words spoken per second.
* Used for estimating transcription progress and buffer management heuristics.
@@ -60,4 +59,24 @@ constexpr inline float kWordsPerSecondLow = 1.5F;
*/
constexpr inline int32_t kChunkBreakBuffer = 2; // [s]
+/**
+ * Multiplier applied to `vadDetectionMargin` when computing the merge gap
+ * passed to `VoiceActivityDetection::generate`. The merge window is widened
+ * slightly so brief intra-utterance silences (just over the detection margin)
+ * still fold into a single speech segment.
+ */
+constexpr inline float kVadGapFactor = 1.2F;
+
+constexpr inline size_t kVadDeadSamplesRemovalSamples =
+ 9 * constants::kSamplingRate; // 9s of audio samples
+
+// `OnlineASR::process` computes the silence-trim cut as
+// `kVadDeadSamplesRemovalSamples - kSafetyMarginSamples` over `size_t`. If the
+// two ever invert, that subtraction wraps and the subsequent erase reads past
+// the buffer. Catch the regression at compile time.
+static_assert(kVadDeadSamplesRemovalSamples >
+ static_cast(kStreamSafetyThreshold *
+ constants::kSamplingRate),
+ "kVadDeadSamplesRemovalSamples must exceed the safety margin");
+
} // namespace rnexecutorch::models::speech_to_text::whisper::params
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h
index eaf34796a4..fccc719eaa 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h
@@ -6,13 +6,14 @@
namespace rnexecutorch::models::voice_activity_detection::constants {
inline constexpr uint32_t kSampleRate = 16000;
-inline constexpr auto kMstoSecond = 0.001f;
+inline constexpr uint32_t kSamplesPerMs = kSampleRate / 1000;
+inline constexpr auto kMsToSeconds = 0.001f;
inline constexpr uint32_t kWindowSizeMs = 25;
inline constexpr uint32_t kHopLengthMs = 10;
inline constexpr auto kWindowSize =
- static_cast(kMstoSecond * kWindowSizeMs * kSampleRate); // 400
+ static_cast(kMsToSeconds * kWindowSizeMs * kSampleRate); // 400
inline constexpr auto kHopLength =
- static_cast(kMstoSecond * kHopLengthMs * kSampleRate); // 160
+ static_cast(kMsToSeconds * kHopLengthMs * kSampleRate); // 160
inline constexpr auto kPreemphasisCoeff = 0.97f;
inline constexpr auto kLeftPadding = (kWindowSize - 1) / 2;
inline constexpr auto kRightPadding = kWindowSize / 2;
@@ -20,8 +21,11 @@ inline constexpr auto kPaddedWindowSize = std::bit_ceil(kWindowSize); // 512
inline constexpr size_t kModelInputMin = 100;
inline constexpr size_t kModelInputMax = 1000;
inline constexpr auto kSpeechThreshold = 0.6f;
-inline constexpr size_t kMinSpeechDuration = 25; // 250 ms
-inline constexpr size_t kMinSilenceDuration = 10; // 100 ms
-inline constexpr size_t kSpeechPad = 3; // 30 ms
+inline constexpr size_t kMinSpeechDuration = 25; // 250 ms
+inline constexpr size_t kMinSilenceDuration = 10; // 100 ms
+inline constexpr size_t kSpeechPad = 3; // 30 ms
+inline constexpr size_t kStreamBufferMaxSize = 10 * kSampleRate; // 10s
+inline constexpr size_t kStreamBufferMinReserve =
+ 1 * kSampleRate; // 1s of audio
-} // namespace rnexecutorch::models::voice_activity_detection::constants
\ No newline at end of file
+} // namespace rnexecutorch::models::voice_activity_detection::constants
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp
index 2ecd0e80cc..ff26372896 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp
@@ -1,4 +1,7 @@
#include "Utils.h"
+#include "Types.h"
+
+#include
namespace rnexecutorch::models::voice_activity_detection::utils {
size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
@@ -12,4 +15,28 @@ size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
return startIdx + size;
}
+std::vector
+mergeSegments(const std::vector &segments, size_t maxMergeGap) {
+ if (segments.empty()) {
+ return segments;
+ }
+
+ std::vector mergedSegments;
+ mergedSegments.push_back(segments[0]);
+
+ for (size_t i = 1; i < segments.size(); ++i) {
+ auto &lastMerged = mergedSegments.back();
+ const auto ¤t = segments[i];
+
+ if (current.start < lastMerged.end ||
+ current.start - lastMerged.end <= maxMergeGap) {
+ lastMerged.end = std::max(lastMerged.end, current.end);
+ } else {
+ mergedSegments.push_back(current);
+ }
+ }
+
+ return mergedSegments;
+}
+
} // namespace rnexecutorch::models::voice_activity_detection::utils
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h
index b016700887..3ec8212448 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h
@@ -1,5 +1,6 @@
#pragma once
+#include "Types.h"
#include
#include
#include
@@ -10,4 +11,16 @@ size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
std::vector &resultVector,
size_t startIdx);
-} // namespace rnexecutorch::models::voice_activity_detection::utils
\ No newline at end of file
+/**
+ * Merges adjacent speech segments which are separated by a gap smaller than or
+ * equal to the specified maximum merge gap.
+ *
+ * @param segments A collection of speech segments to be merged.
+ * @param maxMergeGap The maximum allowed distance between two segments (in
+ * samples) to qualify them for a merge.
+ * @return A new collection containing the merged speech segments.
+ */
+std::vector
+mergeSegments(const std::vector &segments, size_t maxMergeGap);
+
+} // namespace rnexecutorch::models::voice_activity_detection::utils
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
index 12baef7fb1..ff1d228f2f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
@@ -8,60 +8,31 @@
#include
#include
#include
+#include
#include
namespace rnexecutorch::models::voice_activity_detection {
-using namespace constants;
+
namespace ranges = std::ranges;
+using namespace constants;
using executorch::aten::Tensor;
using executorch::extension::TensorPtr;
VoiceActivityDetection::VoiceActivityDetection(
const std::string &modelSource,
std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {}
-
-std::vector>
-VoiceActivityDetection::preprocess(std::span waveform) const {
- auto kHammingWindowArray = dsp::hannWindow(kWindowSize);
-
- const size_t numFrames = (waveform.size() - kWindowSize) / kHopLength;
-
- std::vector> frameBuffer(
- numFrames, std::array{});
-
- constexpr size_t totalPadding = kPaddedWindowSize - kWindowSize;
- constexpr size_t leftPadding = totalPadding / 2;
- for (size_t i = 0; i < numFrames; i++) {
-
- auto windowView = waveform.subspan(i * kHopLength, kWindowSize);
- ranges::copy(windowView, frameBuffer[i].begin() + leftPadding);
- auto frameView =
- std::span{frameBuffer[i].data() + leftPadding, kWindowSize};
- const float sum = std::reduce(frameView.begin(), frameView.end(), 0.0f);
- const float mean = sum / kWindowSize;
- ranges::transform(frameView, frameView.begin(),
- [mean](float value) { return value - mean; });
-
- // apply pre-emphasis filter
- for (auto j = frameView.size() - 1; j > 0; --j) {
- frameView[j] -= kPreemphasisCoeff * frameView[j - 1];
- }
- // apply hamming window to reduce spectral leakage
- ranges::transform(frameView, kHammingWindowArray, frameView.begin(),
- std::multiplies{});
- }
- return frameBuffer;
-}
-
-void VoiceActivityDetection::unload() noexcept {
- std::scoped_lock lock(inference_mutex_);
- BaseModel::unload();
+ : BaseModel(modelSource, callInvoker), callInvoker_(callInvoker) {
+ // Important - preallocate memory for the buffer to avoid any reallocations.
+ audioBuffer_.reserve(2 * constants::kStreamBufferMaxSize);
}
std::vector
-VoiceActivityDetection::generate(std::span waveform) const {
- std::scoped_lock lock(inference_mutex_);
+VoiceActivityDetection::generate(std::span waveform,
+ uint32_t mergeGap) const {
+ // Guard against small buffers to prevent underflow in preprocess
+ if (waveform.size() < kWindowSize) {
+ return {};
+ }
auto windowedInput = preprocess(waveform);
auto [chunksNumber, remainder] = std::div(
@@ -100,12 +71,131 @@ VoiceActivityDetection::generate(std::span waveform) const {
auto tensor = forwardResult->at(0).toTensor();
startIdx = utils::getNonSpeechClassProbabilites(tensor, tensor.size(2),
remainder, scores, startIdx);
- return postprocess(scores, kSpeechThreshold);
+ return postprocess(scores, kSpeechThreshold, mergeGap);
+}
+
+void VoiceActivityDetection::stream(std::shared_ptr callback,
+ uint32_t timeout,
+ uint32_t detectionMargin) {
+ bool expected = false;
+ if (!isStreaming_.compare_exchange_strong(expected, true)) {
+ throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
+ "Streaming is already in progress!");
+ }
+
+ // RAII guard
+ struct StreamingGuard {
+ VoiceActivityDetection *self;
+ ~StreamingGuard() {
+ self->isStreaming_ = false;
+ std::scoped_lock lock(self->audioBufferMutex_);
+ self->audioBuffer_.clear();
+ }
+ } guard{this};
+
+ // Build a native callback which simply sends a boolean signal,
+ // where true corresponds to detected ongoing speech, and false corresponds to
+ // silence.
+ auto nativeCallback = [this, callback](bool speaking) {
+ callInvoker_->invokeAsync([callback, speaking](jsi::Runtime &rt) {
+ callback->call(rt, jsi::Value(speaking));
+ });
+ };
+
+ while (isStreaming_) {
+ // Make sure that audio buffer does not exceed it's max size
+ // BEFORE infering the model, such that potentially save 1 unnecessary
+ // iteration.
+ {
+ std::scoped_lock lock(audioBufferMutex_);
+ if (audioBuffer_.size() > constants::kStreamBufferMaxSize) {
+ // Keep the most recent audio samples in.
+ size_t cut = audioBuffer_.size() - constants::kStreamBufferMinReserve;
+ audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut);
+ }
+ }
+
+ std::vector snapshot;
+ {
+ // Copy under lock so a concurrent streamInsert that grows audioBuffer_
+ // past its reserved capacity (causing a reallocation) can't invalidate
+ // the data we hand to generate().
+ std::scoped_lock lock(audioBufferMutex_);
+ snapshot.assign(audioBuffer_.begin(), audioBuffer_.end());
+ }
+
+ std::vector detection = generate(snapshot);
+
+ bool speaking = false;
+
+ // Get the last detected speech segment (if any exists) and
+ // check if it is recent enough to consider the speech as ongoing.
+ if (!detection.empty()) {
+ auto lastSegment = detection.back();
+ auto speechEnd = lastSegment.end;
+
+ std::scoped_lock lock(audioBufferMutex_);
+ uint32_t diffMs =
+ (audioBuffer_.size() - speechEnd) / constants::kSamplesPerMs; // [ms]
+
+ speaking = diffMs <= detectionMargin;
+ }
+
+ // Send result to the JS side.
+ nativeCallback(speaking);
+
+ std::unique_lock lock(streamCvMutex_);
+ streamCv_.wait_for(lock, std::chrono::milliseconds(timeout),
+ [this] { return !isStreaming_.load(); });
+ }
+}
+
+void VoiceActivityDetection::streamStop() {
+ isStreaming_ = false;
+ streamCv_.notify_all();
+}
+
+void VoiceActivityDetection::streamInsert(std::span audio) {
+ std::scoped_lock lock(audioBufferMutex_);
+ audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end());
+}
+
+std::vector>
+VoiceActivityDetection::preprocess(std::span waveform) const {
+ auto kHammingWindowArray = dsp::hannWindow(kWindowSize);
+
+ const size_t numFrames = (waveform.size() - kWindowSize) / kHopLength;
+
+ std::vector> frameBuffer(
+ numFrames, std::array{});
+
+ constexpr size_t totalPadding = kPaddedWindowSize - kWindowSize;
+ constexpr size_t leftPadding = totalPadding / 2;
+ for (size_t i = 0; i < numFrames; i++) {
+
+ auto windowView = waveform.subspan(i * kHopLength, kWindowSize);
+ ranges::copy(windowView, frameBuffer[i].begin() + leftPadding);
+ auto frameView =
+ std::span{frameBuffer[i].data() + leftPadding, kWindowSize};
+ const float sum = std::reduce(frameView.begin(), frameView.end(), 0.0f);
+ const float mean = sum / kWindowSize;
+ ranges::transform(frameView, frameView.begin(),
+ [mean](float value) { return value - mean; });
+
+ // apply pre-emphasis filter
+ for (auto j = frameView.size() - 1; j > 0; --j) {
+ frameView[j] -= kPreemphasisCoeff * frameView[j - 1];
+ }
+ // apply hamming window to reduce spectral leakage
+ ranges::transform(frameView, kHammingWindowArray, frameView.begin(),
+ std::multiplies{});
+ }
+ return frameBuffer;
}
std::vector
VoiceActivityDetection::postprocess(const std::vector &scores,
- float threshold) const {
+ float threshold, uint32_t mergeGap) const {
bool triggered = false;
std::vector speechSegments{};
ssize_t startSegment = -1;
@@ -154,7 +244,9 @@ VoiceActivityDetection::postprocess(const std::vector &scores,
end = std::min(end + kSpeechPad, scores.size()) * kHopLength;
}
- return speechSegments;
+ // Merge tightly placed segments according to the max allowed gap parameter.
+ size_t maxMergeGap = mergeGap * constants::kSamplesPerMs;
+ return utils::mergeSegments(speechSegments, maxMergeGap);
}
} // namespace rnexecutorch::models::voice_activity_detection
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
index c756bb6d3c..d9d72adbba 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
@@ -1,5 +1,6 @@
#pragma once
+#include
#include
#include
#include
@@ -15,25 +16,59 @@
namespace rnexecutorch {
namespace models::voice_activity_detection {
+
using executorch::extension::TensorPtr;
using executorch::runtime::EValue;
+
class VoiceActivityDetection : public BaseModel {
public:
VoiceActivityDetection(const std::string &modelSource,
std::shared_ptr callInvoker);
+
[[nodiscard("Registered non-void function")]] std::vector
- generate(std::span waveform) const;
+ generate(std::span waveform, uint32_t mergeGap = 0) const;
- void unload() noexcept;
+ /**
+ * Initializes the streaming procedure, which performs
+ * a continuous detection on live expanding audio buffer.
+ *
+ * @param callback JS side callback to pass the results with.
+ * @param timeout Specifies (in miliseconds) how much does streamer wait
+ * between model inferences.
+ * @param detectionMargin Specifies (in miliseconds) how far the last detected
+ * speech segment can be to still be considered as ongoing speech.
+ */
+ void stream(std::shared_ptr callback, uint32_t timeout,
+ uint32_t detectionMargin);
-private:
- mutable std::mutex inference_mutex_;
+ /**
+ * When called, stops the streaming procedure.
+ */
+ void streamStop();
+
+ /**
+ * Adds a new audio chunk to the streaming buffer.
+ */
+ void streamInsert(std::span audio);
+private:
std::vector>
preprocess(std::span waveform) const;
std::vector postprocess(const std::vector &scores,
- float threshold) const;
+ float threshold,
+ uint32_t mergeGap) const;
+
+ std::shared_ptr callInvoker_;
+
+ // Streaming buffer
+ std::vector audioBuffer_ = {};
+ mutable std::mutex audioBufferMutex_;
+ // Streaming state
+ std::atomic isStreaming_ = false;
+ std::condition_variable streamCv_;
+ std::mutex streamCvMutex_;
};
+
} // namespace models::voice_activity_detection
REGISTER_CONSTRUCTOR(models::voice_activity_detection::VoiceActivityDetection,
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
index 1fcad420cc..1f34b3a18e 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
@@ -263,6 +263,8 @@ add_rn_test(SpeechToTextTests integration/SpeechToTextTest.cpp
${RNEXECUTORCH_DIR}/models/speech_to_text/SpeechToText.cpp
${RNEXECUTORCH_DIR}/models/speech_to_text/whisper/ASR.cpp
${RNEXECUTORCH_DIR}/models/speech_to_text/whisper/OnlineASR.cpp
+ ${RNEXECUTORCH_DIR}/models/voice_activity_detection/VoiceActivityDetection.cpp
+ ${RNEXECUTORCH_DIR}/models/voice_activity_detection/Utils.cpp
${RNEXECUTORCH_DIR}/data_processing/gzip.cpp
${TOKENIZER_SOURCES}
${DSP_SOURCES}
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp
index ed9a1c5c0e..cc326d8bc0 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp
@@ -1,8 +1,16 @@
#include "BaseModelTests.h"
#include "utils/TestUtils.h"
+#include
#include
+#include
+#include
#include
#include
+#include
+
+namespace rnexecutorch {
+std::shared_ptr createMockCallInvoker();
+}
using namespace rnexecutorch;
using namespace rnexecutorch::models::speech_to_text;
@@ -11,6 +19,7 @@ using namespace model_tests;
constexpr auto kValidModelPath = "whisper_tiny_en_xnnpack.pte";
constexpr auto kValidTokenizerPath = "whisper_tokenizer.json";
+constexpr auto kValidVadModelPath = "fsmn-vad_xnnpack.pte";
// ============================================================================
// Common tests via typed test suite
@@ -20,12 +29,13 @@ template <> struct ModelTraits {
using ModelType = SpeechToText;
static ModelType createValid() {
- return ModelType("whisper", kValidModelPath, kValidTokenizerPath, nullptr);
+ return ModelType("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", nullptr);
}
static ModelType createInvalid() {
return ModelType("whisper", "nonexistent.pte", kValidTokenizerPath,
- nullptr);
+ /*vadSource=*/"", nullptr);
}
static void callGenerate(ModelType &model) {
@@ -44,24 +54,25 @@ INSTANTIATE_TYPED_TEST_SUITE_P(SpeechToText, CommonModelTest,
// ============================================================================
TEST(S2TCtorTests, InvalidModelNameThrows) {
EXPECT_THROW(SpeechToText("invalid_model", kValidModelPath,
- kValidTokenizerPath, nullptr),
+ kValidTokenizerPath, /*vadSource=*/"", nullptr),
RnExecutorchError);
}
TEST(S2TCtorTests, InvalidModelPathThrows) {
- EXPECT_THROW(
- SpeechToText("whisper", "nonexistent.pte", kValidTokenizerPath, nullptr),
- RnExecutorchError);
+ EXPECT_THROW(SpeechToText("whisper", "nonexistent.pte", kValidTokenizerPath,
+ /*vadSource=*/"", nullptr),
+ RnExecutorchError);
}
TEST(S2TCtorTests, InvalidTokenizerPathThrows) {
- EXPECT_THROW(
- SpeechToText("whisper", kValidModelPath, "nonexistent.json", nullptr),
- rnexecutorch::RnExecutorchError);
+ EXPECT_THROW(SpeechToText("whisper", kValidModelPath, "nonexistent.json",
+ /*vadSource=*/"", nullptr),
+ rnexecutorch::RnExecutorchError);
}
TEST(S2TEncodeTests, EncodeReturnsNonNull) {
- SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, nullptr);
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", nullptr);
auto audio = loadAudioFromFile("test_audio_float.raw");
ASSERT_FALSE(audio.empty());
auto result = model.encode(audio);
@@ -71,7 +82,8 @@ TEST(S2TEncodeTests, EncodeReturnsNonNull) {
TEST(S2TTranscribeTests, TranscribeReturnsValidChars) {
GTEST_SKIP() << "TODO: known failure on this branch; needs investigation.";
- SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, nullptr);
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", nullptr);
auto audio = loadAudioFromFile("test_audio_float.raw");
ASSERT_FALSE(audio.empty());
auto result = model.transcribe(audio, "en", true);
@@ -87,16 +99,73 @@ TEST(S2TTranscribeTests, TranscribeReturnsValidChars) {
}
TEST(S2TTranscribeTests, EmptyResultOnSilence) {
- SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, nullptr);
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", nullptr);
auto audio = generateSilence(16000 * 5);
auto result = model.transcribe(audio, "en", false);
EXPECT_TRUE(result.text.empty());
}
TEST(S2TTranscribeTests, InvalidLanguageThrows) {
- SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, nullptr);
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", nullptr);
auto audio = loadAudioFromFile("test_audio_float.raw");
ASSERT_FALSE(audio.empty());
EXPECT_THROW((void)model.transcribe(audio, "invalid_language_code", false),
RnExecutorchError);
}
+
+// ============================================================================
+// VAD integration tests (vadSource provided => internal VAD module loaded)
+// ============================================================================
+TEST(S2TVadCtorTests, ValidVadSourceConstructs) {
+ EXPECT_NO_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath,
+ kValidVadModelPath, createMockCallInvoker()));
+}
+
+TEST(S2TVadCtorTests, InvalidVadSourceThrows) {
+ EXPECT_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath,
+ "nonexistent_vad.pte", createMockCallInvoker()),
+ RnExecutorchError);
+}
+
+TEST(S2TVadTranscribeTests, TranscribeStillWorksWithVadLoaded) {
+ // The vadSource only affects streaming. The one-shot transcribe() path
+ // must remain unchanged when a VAD is attached.
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ kValidVadModelPath, createMockCallInvoker());
+ auto silence = generateSilence(16000 * 5);
+ auto result = model.transcribe(silence, "en", false);
+ EXPECT_TRUE(result.text.empty());
+}
+
+TEST(S2TVadStreamTests, UseVadWithoutVadInitializedThrows) {
+ // Guard added by the PR: stream() with useVAD=true on a model that was
+ // built with an empty vadSource should fail loudly.
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ /*vadSource=*/"", createMockCallInvoker());
+ EXPECT_THROW(model.stream(std::shared_ptr(), /*language=*/"en",
+ /*verbose=*/false, /*timeout=*/100,
+ /*useVAD=*/true, /*vadDetectionMargin=*/500),
+ RnExecutorchError);
+}
+
+TEST(S2TVadStreamTests, StreamWithVadOnSilenceCompletesCleanly) {
+ // Drives the OnlineASR::process VAD branch with audio that contains no
+ // speech segments. Exercises the "speechSegments.empty()" cleanup path
+ // added by the PR.
+ SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath,
+ kValidVadModelPath, createMockCallInvoker());
+ std::thread streamer([&model] {
+ model.stream(std::shared_ptr(), /*language=*/"en",
+ /*verbose=*/false, /*timeout=*/100, /*useVAD=*/true,
+ /*vadDetectionMargin=*/500);
+ });
+
+ auto silence = generateSilence(16000 * 2); // 2s of silence
+ model.streamInsert(silence);
+ std::this_thread::sleep_for(std::chrono::milliseconds(400));
+
+ model.streamStop();
+ streamer.join();
+}
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp
index 5c19b74f67..e60b375f1d 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp
@@ -1,14 +1,25 @@
#include "BaseModelTests.h"
#include "utils/TestUtils.h"
+#include
#include
+#include
+#include
#include
+#include
#include
+#include
+
+namespace rnexecutorch {
+std::shared_ptr createMockCallInvoker();
+}
using namespace rnexecutorch;
using namespace rnexecutorch::models::voice_activity_detection;
using namespace test_utils;
using namespace model_tests;
+namespace vad_utils = rnexecutorch::models::voice_activity_detection::utils;
+
constexpr auto kValidVadModelPath = "fsmn-vad_xnnpack.pte";
// ============================================================================
@@ -97,3 +108,146 @@ TEST(VADInheritedTests, GetMethodMetaWorks) {
auto result = model.getMethodMeta("forward");
EXPECT_TRUE(result.ok());
}
+
+// ============================================================================
+// utils::mergeSegments unit tests (no model load needed)
+// ============================================================================
+TEST(VADMergeSegmentsTests, EmptyInputReturnsEmpty) {
+ std::vector empty;
+ EXPECT_TRUE(vad_utils::mergeSegments(empty, /*maxMergeGap=*/100).empty());
+}
+
+TEST(VADMergeSegmentsTests, SingleSegmentReturnedUnchanged) {
+ std::vector input = {{100, 500}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/100);
+ ASSERT_EQ(merged.size(), 1u);
+ EXPECT_EQ(merged[0].start, 100u);
+ EXPECT_EQ(merged[0].end, 500u);
+}
+
+TEST(VADMergeSegmentsTests, DistantSegmentsStaySeparate) {
+ // Gap between segments (1000 - 500 = 500) is larger than maxMergeGap (100).
+ std::vector input = {{0, 500}, {1000, 1500}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/100);
+ ASSERT_EQ(merged.size(), 2u);
+ EXPECT_EQ(merged[0].end, 500u);
+ EXPECT_EQ(merged[1].start, 1000u);
+}
+
+TEST(VADMergeSegmentsTests, CloseSegmentsAreMerged) {
+ // Gap (510 - 500 = 10) is within maxMergeGap (100).
+ std::vector input = {{0, 500}, {510, 1000}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/100);
+ ASSERT_EQ(merged.size(), 1u);
+ EXPECT_EQ(merged[0].start, 0u);
+ EXPECT_EQ(merged[0].end, 1000u);
+}
+
+TEST(VADMergeSegmentsTests, AdjacentSegmentsAreMerged) {
+ // Gap of 0 is <= maxMergeGap (0).
+ std::vector input = {{0, 500}, {500, 1000}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/0);
+ ASSERT_EQ(merged.size(), 1u);
+ EXPECT_EQ(merged[0].start, 0u);
+ EXPECT_EQ(merged[0].end, 1000u);
+}
+
+TEST(VADMergeSegmentsTests, OverlappingShorterInnerDoesNotShrink) {
+ // Second segment overlaps inside the first. Merged result must not shrink.
+ std::vector input = {{0, 1000}, {200, 500}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/0);
+ ASSERT_EQ(merged.size(), 1u);
+ EXPECT_EQ(merged[0].start, 0u);
+ EXPECT_EQ(merged[0].end, 1000u);
+}
+
+TEST(VADMergeSegmentsTests, MixedSequenceMergesOnlyAdjacentClose) {
+ // First two are close enough to merge; third is far.
+ std::vector input = {{0, 500}, {520, 800}, {2000, 2500}};
+ auto merged = vad_utils::mergeSegments(input, /*maxMergeGap=*/100);
+ ASSERT_EQ(merged.size(), 2u);
+ EXPECT_EQ(merged[0].start, 0u);
+ EXPECT_EQ(merged[0].end, 800u);
+ EXPECT_EQ(merged[1].start, 2000u);
+ EXPECT_EQ(merged[1].end, 2500u);
+}
+
+// ============================================================================
+// Streaming lifecycle tests
+//
+// VAD::stream() blocks on its own loop until streamStop() flips the flag, so
+// every test below drives it from a background thread and joins after stop.
+// The native callback round-trips through the CallInvoker — using the no-op
+// MockCallInvoker means callbacks are dropped but the streaming loop itself
+// still runs normally.
+// ============================================================================
+TEST(VADStreamTests, StreamStopExitsLoop) {
+ VoiceActivityDetection model(kValidVadModelPath, createMockCallInvoker());
+ std::thread streamer([&model] {
+ model.stream(std::shared_ptr(), /*timeout=*/50,
+ /*detectionMargin=*/100);
+ });
+
+ // Let the loop spin a few iterations before stopping.
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+ const auto stopRequestedAt = std::chrono::steady_clock::now();
+ model.streamStop();
+ streamer.join();
+ const auto elapsedToStop = std::chrono::steady_clock::now() - stopRequestedAt;
+
+ // streamStop() should wake the wait_for() and return quickly (well below
+ // the timeout window — give a generous safety margin).
+ EXPECT_LT(elapsedToStop, std::chrono::milliseconds(500));
+}
+
+TEST(VADStreamTests, StreamInsertWhileStreamingDoesNotCrash) {
+ VoiceActivityDetection model(kValidVadModelPath, createMockCallInvoker());
+ std::thread streamer([&model] {
+ model.stream(std::shared_ptr(), /*timeout=*/50,
+ /*detectionMargin=*/100);
+ });
+
+ auto silence = generateSilence(16000); // 1s
+ auto audio = loadAudioFromFile("test_audio_float.raw");
+ ASSERT_FALSE(audio.empty());
+
+ model.streamInsert(silence);
+ std::this_thread::sleep_for(std::chrono::milliseconds(150));
+ model.streamInsert(audio);
+ std::this_thread::sleep_for(std::chrono::milliseconds(300));
+
+ model.streamStop();
+ streamer.join();
+}
+
+TEST(VADStreamTests, ConcurrentStreamCallThrows) {
+ VoiceActivityDetection model(kValidVadModelPath, createMockCallInvoker());
+ std::thread streamer([&model] {
+ model.stream(std::shared_ptr(), /*timeout=*/50,
+ /*detectionMargin=*/100);
+ });
+
+ // Give the first call time to flip isStreaming_ to true.
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ EXPECT_THROW(model.stream(std::shared_ptr(), /*timeout=*/50,
+ /*detectionMargin=*/100),
+ RnExecutorchError);
+
+ model.streamStop();
+ streamer.join();
+}
+
+TEST(VADStreamTests, StreamCanBeRestartedAfterStop) {
+ VoiceActivityDetection model(kValidVadModelPath, createMockCallInvoker());
+
+ for (int iter = 0; iter < 2; ++iter) {
+ std::thread streamer([&model] {
+ model.stream(std::shared_ptr(), /*timeout=*/50,
+ /*detectionMargin=*/100);
+ });
+ std::this_thread::sleep_for(std::chrono::milliseconds(150));
+ model.streamStop();
+ streamer.join();
+ }
+}
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh
index 16e093fdd8..9fbbaade13 100755
--- a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh
@@ -205,7 +205,7 @@ models_for_test() {
StyleTransferTests) echo "style_transfer_candy_xnnpack_fp32.pte test_image.jpg" ;;
VADTests) echo "fsmn-vad_xnnpack.pte" ;;
TokenizerModuleTests) echo "tokenizer.json" ;;
- SpeechToTextTests) echo "whisper_tiny_en_xnnpack.pte whisper_tokenizer.json" ;;
+ SpeechToTextTests) echo "whisper_tiny_en_xnnpack.pte whisper_tokenizer.json fsmn-vad_xnnpack.pte" ;;
TextToSpeechTests) echo "kokoro_duration_predictor.pte kokoro_synthesizer.pte kokoro_af_heart.bin kokoro_us_lexicon.json kokoro_en_tagger.json kokoro_us_phonemizer.pte" ;;
LLMTests) echo "smolLm2_135M_8da4w.pte smollm_tokenizer.json lfm2_5_vl_quantized_xnnpack_v2.pte lfm2_vl_tokenizer.json lfm2_vl_tokenizer_config.json test_image.jpg" ;;
TextToImageTests) echo "t2i_tokenizer.json t2i_encoder.pte t2i_unet.pte t2i_decoder.pte" ;;
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
index 229bba73e3..0fd0480d3e 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
@@ -18,6 +18,7 @@ import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
*/
export const useSpeechToText = ({
model,
+ vad,
preventLoad = false,
}: SpeechToTextProps): SpeechToTextType => {
const [error, setError] = useState(null);
@@ -42,6 +43,7 @@ export const useSpeechToText = ({
modelSource: model.modelSource,
tokenizerSource: model.tokenizerSource,
},
+ vad,
(p) => {
if (active) setDownloadProgress(p);
}
@@ -73,6 +75,7 @@ export const useSpeechToText = ({
model.isMultilingual,
model.modelSource,
model.tokenizerSource,
+ vad,
preventLoad,
]);
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts
index 82a140288a..15d1d72a26 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts
@@ -1,5 +1,5 @@
import { VADModule } from '../../modules/natural_language_processing/VADModule';
-import { VADType, VADProps } from '../../types/vad';
+import { VADType, VADProps, VADStreamingInput } from '../../types/vad';
import { useModuleFactory } from '../useModuleFactory';
/**
@@ -21,5 +21,29 @@ export const useVAD = ({ model, preventLoad = false }: VADProps): VADType => {
const forward = (waveform: Float32Array) =>
runForward((inst) => inst.forward(waveform));
- return { error, isReady, isGenerating, downloadProgress, forward };
+ const stream = (input: VADStreamingInput) =>
+ runForward((inst) => inst.stream(input));
+
+ const streamInsert = (waveform: Float32Array) =>
+ runForward((inst) => {
+ inst.streamInsert(waveform);
+ return Promise.resolve();
+ });
+
+ const streamStop = () =>
+ runForward((inst) => {
+ inst.streamStop();
+ return Promise.resolve();
+ });
+
+ return {
+ error,
+ isReady,
+ isGenerating,
+ downloadProgress,
+ forward,
+ stream,
+ streamInsert,
+ streamStop,
+ };
};
diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts
index 556d35de02..1f190d41f5 100644
--- a/packages/react-native-executorch/src/index.ts
+++ b/packages/react-native-executorch/src/index.ts
@@ -95,7 +95,8 @@ declare global {
var loadSpeechToText: (
modelName: string,
modelSource: string,
- tokenizerSource: string
+ tokenizerSource: string,
+ vadSource: string
) => Promise;
var loadTextToSpeechKokoro: (
lang: string,
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
index 3890c9ae50..f7ffe52f1c 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
@@ -5,6 +5,7 @@ import {
StreamingOptions,
TranscriptionResult,
} from '../../types/stt';
+import { VADConfig } from '../../types/vad';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { ResourceSource } from '../../types/common';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
@@ -30,6 +31,7 @@ export class SpeechToTextModule {
/**
* Creates a Speech to Text instance for a built-in model.
* @param namedSources - Configuration object containing model name, sources, and multilingual flag.
+ * @param vad - Optional configuration for Voice Activity Detection (VAD).
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
* @returns A Promise resolving to a `SpeechToTextModule` instance.
* @example
@@ -40,11 +42,13 @@ export class SpeechToTextModule {
*/
static async fromModelName(
namedSources: SpeechToTextModelConfig,
+ vad?: VADConfig,
onDownloadProgress: (progress: number) => void = () => {}
): Promise {
try {
const nativeModule = await SpeechToTextModule.loadWhisper(
namedSources,
+ vad,
onDownloadProgress
);
return new SpeechToTextModule(nativeModule, namedSources);
@@ -64,6 +68,7 @@ export class SpeechToTextModule {
* @param modelSource - A fetchable resource pointing to the model binary.
* @param tokenizerSource - A fetchable resource pointing to the tokenizer file.
* @param isMultilingual - Whether the model supports multiple languages.
+ * @param vad - Optional VAD configuration.
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
* @returns A Promise resolving to a `SpeechToTextModule` instance.
*/
@@ -71,6 +76,7 @@ export class SpeechToTextModule {
modelSource: ResourceSource,
tokenizerSource: ResourceSource,
isMultilingual: boolean,
+ vad?: VADConfig,
onDownloadProgress: (progress: number) => void = () => {}
): Promise {
return SpeechToTextModule.fromModelName(
@@ -80,12 +86,14 @@ export class SpeechToTextModule {
tokenizerSource,
isMultilingual,
},
+ vad,
onDownloadProgress
);
}
private static async loadWhisper(
model: SpeechToTextModelConfig,
+ vad: VADConfig | undefined,
onDownloadProgressCallback: (progress: number) => void
): Promise {
const tokenizerLoadPromise = ResourceFetcher.fetch(
@@ -96,9 +104,14 @@ export class SpeechToTextModule {
onDownloadProgressCallback,
model.modelSource
);
- const [tokenizerSources, modelSources] = await Promise.all([
+ const vadPromise = vad
+ ? ResourceFetcher.fetch(undefined, vad.modelSource)
+ : Promise.resolve(undefined);
+
+ const [tokenizerSources, modelSources, vadSources] = await Promise.all([
tokenizerLoadPromise,
modelPromise,
+ vadPromise,
]);
if (!modelSources?.[0] || !tokenizerSources?.[0]) {
throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted);
@@ -107,7 +120,8 @@ export class SpeechToTextModule {
return await global.loadSpeechToText(
'whisper',
modelSources[0],
- tokenizerSources[0]
+ tokenizerSources[0],
+ vadSources?.[0] || ''
);
}
@@ -184,6 +198,8 @@ export class SpeechToTextModule {
const verbose = !!options.verbose;
const language = options.language || '';
const timeout = options.timeout || 100;
+ const useVAD = options.useVAD ?? false;
+ const vadDetectionMargin = options.vadDetectionMargin ?? 500;
const queue: {
committed: TranscriptionResult;
@@ -219,7 +235,9 @@ export class SpeechToTextModule {
},
language,
verbose,
- timeout
+ timeout,
+ useVAD,
+ vadDetectionMargin
);
finished = true;
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts
index 1b34de2e1d..eb74a582c0 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts
@@ -1,6 +1,6 @@
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { ResourceSource } from '../../types/common';
-import { Segment, VADModelName } from '../../types/vad';
+import { Segment, VADModelName, VADStreamingInput } from '../../types/vad';
import { BaseModule } from '../BaseModule';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
@@ -11,6 +11,8 @@ import { Logger } from '../../common/Logger';
* @category Typescript API
*/
export class VADModule extends BaseModule {
+ private isStreaming: boolean = false;
+
private constructor(nativeModule: unknown) {
super();
this.nativeModule = nativeModule;
@@ -70,4 +72,85 @@ export class VADModule extends BaseModule {
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
return await this.nativeModule.generate(waveform);
}
+
+ /**
+ * Starts a streaming VAD session.
+ * @param input - Configuration for streaming, including callbacks for speech begin/end and optional parameters.
+ * @returns A promise that resolves when the streaming session stops.
+ */
+ public async stream({
+ onSpeechBegin,
+ onSpeechEnd,
+ options,
+ }: VADStreamingInput): Promise {
+ if (this.nativeModule == null)
+ throw new RnExecutorchError(
+ RnExecutorchErrorCode.ModuleNotLoaded,
+ 'The model is currently not loaded. Please load the model before calling stream().'
+ );
+
+ const timeout = options?.timeout ?? 100;
+ const detectionMargin = options?.detectionMargin ?? 100;
+
+ let isSpeaking = false;
+ let error: unknown;
+ let finished = false;
+
+ let waiter: (() => void) | null = null;
+ const wake = () => {
+ waiter?.();
+ waiter = null;
+ };
+
+ this.isStreaming = true;
+
+ (async () => {
+ try {
+ await this.nativeModule.stream(
+ async (speaking: boolean) => {
+ if (speaking && !isSpeaking) {
+ isSpeaking = true;
+ await onSpeechBegin?.();
+ } else if (!speaking && isSpeaking) {
+ isSpeaking = false;
+ await onSpeechEnd?.();
+ }
+ },
+ timeout,
+ detectionMargin
+ );
+ } catch (e) {
+ error = e;
+ } finally {
+ this.isStreaming = false;
+ finished = true;
+ wake();
+ }
+ })();
+
+ while (this.isStreaming && !finished) {
+ if (error) throw parseUnknownError(error);
+ await new Promise((r) => (waiter = r));
+ }
+ }
+
+ /**
+ * Inserts a new audio chunk into the streaming VAD session.
+ * @param waveform - The audio chunk to insert.
+ */
+ public streamInsert(waveform: Float32Array): void {
+ if (this.nativeModule == null)
+ throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
+ this.nativeModule.streamInsert(waveform);
+ }
+
+ /**
+ * Stops the current streaming VAD session.
+ */
+ public streamStop(): void {
+ if (this.nativeModule == null)
+ throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
+ this.nativeModule.streamStop();
+ this.isStreaming = false;
+ }
}
diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts
index f9a2fb56d8..7669f6c044 100644
--- a/packages/react-native-executorch/src/types/stt.ts
+++ b/packages/react-native-executorch/src/types/stt.ts
@@ -1,4 +1,5 @@
import { ResourceSource } from './common';
+import { VADConfig } from './vad';
import { RnExecutorchError } from '../errors/errorUtils';
/**
@@ -22,6 +23,10 @@ export interface SpeechToTextProps {
* Configuration object containing model sources.
*/
model: SpeechToTextModelConfig; // | ...
+ /**
+ * An optional VAD model to be utilized to enhance speech-to-text streaming capabilities.
+ */
+ vad?: VADConfig;
/**
* Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
*/
@@ -209,9 +214,15 @@ export interface DecodingOptions {
* Configuration options for the speech-to-text streaming process.
* @category Types
* @property {number} [timeout] - Specifies (in milliseconds) how much does streamer wait between model inferences.
+ * @property {boolean} [useVAD] - When set to true, utilizes the inner VAD submodule (if initialized) and
+ * transcription process runs only when speech is being detected.
+ * @property {number} [vadDetectionMargin] - Specifies (in milliseconds) how far the last detected speech segment can be to still be considered
+ * as ongoing speech. Works only with useVAD set to true.
*/
export interface StreamingOptions extends DecodingOptions {
timeout?: number;
+ useVAD?: boolean;
+ vadDetectionMargin?: number;
}
/**
diff --git a/packages/react-native-executorch/src/types/vad.ts b/packages/react-native-executorch/src/types/vad.ts
index 5c0868c5d7..a266a775da 100644
--- a/packages/react-native-executorch/src/types/vad.ts
+++ b/packages/react-native-executorch/src/types/vad.ts
@@ -7,6 +7,17 @@ import { RnExecutorchError } from '../errors/errorUtils';
*/
export type VADModelName = 'fsmn-vad';
+/**
+ * Configuration for the VAD model.
+ * @category Types
+ * @property {VADModelName} modelName - Unique name identifying the model.
+ * @property {ResourceSource} modelSource - The source of the VAD model binary.
+ */
+export interface VADConfig {
+ modelName: VADModelName;
+ modelSource: ResourceSource;
+}
+
/**
* Props for the useVAD hook.
* @category Types
@@ -16,7 +27,7 @@ export type VADModelName = 'fsmn-vad';
* @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
*/
export interface VADProps {
- model: { modelName: VADModelName; modelSource: ResourceSource };
+ model: VADConfig;
preventLoad?: boolean;
}
@@ -31,6 +42,30 @@ export interface Segment {
end: number;
}
+/**
+ * Configuration options for the VAD streaming process.
+ * @category Types
+ * @property {number} [timeout] - Specifies (in milliseconds) how much does streamer wait between model inferences.
+ * @property {number} [detectionMargin] - Specifies (in milliseconds) how far the last detected speech segment can be to still be considered as ongoing speech.
+ */
+export interface VADStreamingOptions {
+ timeout?: number;
+ detectionMargin?: number;
+}
+
+/**
+ * Input configuration for the VAD streaming hook.
+ * @category Types
+ * @property {() => void | Promise} [onSpeechBegin] - Callback function triggered when speech is detected.
+ * @property {() => void | Promise} [onSpeechEnd] - Callback function triggered when speech end (silence is detected).
+ * @property {VADStreamingOptions} [options] - Optional configuration for the streaming process.
+ */
+export interface VADStreamingInput {
+ onSpeechBegin?: () => void | Promise;
+ onSpeechEnd?: () => void | Promise;
+ options?: VADStreamingOptions;
+}
+
/**
* React hook state and methods for managing a Voice Activity Detection (VAD) model instance.
* @category Types
@@ -63,4 +98,22 @@ export interface VADType {
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another request.
*/
forward(waveform: Float32Array): Promise;
+
+ /**
+ * Starts a streaming Voice Activity Detection session.
+ * @param input - Configuration for streaming, including callbacks for speech begin/end and optional parameters.
+ * @returns A promise that resolves when the streaming session stops.
+ */
+ stream(input: VADStreamingInput): Promise;
+
+ /**
+ * Inserts an audio chunk into the streaming VAD session.
+ * @param waveform - The audio data to add to the buffer.
+ */
+ streamInsert(waveform: Float32Array): void;
+
+ /**
+ * Stops the current streaming VAD session.
+ */
+ streamStop(): void;
}
diff --git a/packages/react-native-executorch/third-party/common/phonemis b/packages/react-native-executorch/third-party/common/phonemis
index 6b09cf1fca..a72ef9092c 160000
--- a/packages/react-native-executorch/third-party/common/phonemis
+++ b/packages/react-native-executorch/third-party/common/phonemis
@@ -1 +1 @@
-Subproject commit 6b09cf1fcabe9295d99f9a6bcce864748a226273
+Subproject commit a72ef9092c80a86d963b590b4aa0a4e813270bc3
diff --git a/yarn.lock b/yarn.lock
index 21ff6c80d5..51d6f00207 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -12753,7 +12753,7 @@ __metadata:
metro-config: "npm:^0.83.0"
react: "npm:19.2.5"
react-native: "npm:0.83.4"
- react-native-audio-api: "npm:0.12.0"
+ react-native-audio-api: "npm:0.12.2"
react-native-device-info: "npm:^15.0.2"
react-native-executorch: "workspace:*"
react-native-executorch-expo-resource-fetcher: "workspace:*"
@@ -15302,20 +15302,6 @@ __metadata:
languageName: node
linkType: hard
-"react-native-audio-api@npm:0.12.0":
- version: 0.12.0
- resolution: "react-native-audio-api@npm:0.12.0"
- dependencies:
- semver: "npm:^7.7.3"
- peerDependencies:
- react: "*"
- react-native: "*"
- bin:
- setup-rn-audio-api-web: scripts/setup-rn-audio-api-web.js
- checksum: 10/3328a36649c30ea24a6c957f89c8f97969f06f530a72147eb2c59d7332b8ce5b1a2a34dd95f2a6cd2f8b279505161e1117d1329bfd5101167971c0189e5ab5f3
- languageName: node
- linkType: hard
-
"react-native-audio-api@npm:0.12.2":
version: 0.12.2
resolution: "react-native-audio-api@npm:0.12.2"