diff --git a/pyproject.toml b/pyproject.toml
index ed9ab048e..5139ff444 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "aiofiles>=24,<25",
     "appdirs>=1.4.0",
     "art>=6.5.0",
+    "av>=14.0.0",
    "azure-core>=1.38.0",
     "azure-identity>=1.19.0",
     "azure-ai-contentsafety>=1.0.0",
@@ -135,6 +136,7 @@ speech = [
 # all includes all functional dependencies excluding the ones from the "dev" extra
 all = [
     "accelerate>=1.7.0",
+    "av>=14.0.0",
     "azure-ai-ml>=1.27.1",
     "azure-cognitiveservices-speech>=1.44.0",
     "azureml-mlflow>=1.60.0",
diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py
index b0d0ad2a9..1395e3b96 100644
--- a/pyrit/score/audio_transcript_scorer.py
+++ b/pyrit/score/audio_transcript_scorer.py
@@ -8,6 +8,8 @@
 from abc import ABC
 from typing import Optional
 
+import av
+
 from pyrit.memory import CentralMemory
 from pyrit.models import MessagePiece, Score
 from pyrit.prompt_converter import AzureSpeechAudioToTextConverter
@@ -16,6 +18,76 @@
 logger = logging.getLogger(__name__)
 
+
+def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bool:
+    """
+    Check if the audio file is already a compliant WAV with the target format.
+
+    Args:
+        input_path (str): Path to the audio file.
+        sample_rate (int): Expected sample rate in Hz.
+        channels (int): Expected number of channels.
+
+    Returns:
+        bool: True if the file is already compliant, False otherwise.
+    """
+    try:
+        with av.open(input_path) as container:
+            if container.format.name != "wav" or not container.streams.audio:  # must be a real WAV container, not just a PCM stream in any container
+                return False
+            stream = container.streams.audio[0]
+            codec_name = stream.codec_context.name
+            is_pcm_s16 = codec_name == "pcm_s16le"
+            is_correct_rate = stream.rate == sample_rate
+            is_correct_channels = stream.channels == channels
+            return is_pcm_s16 and is_correct_rate and is_correct_channels
+    except Exception:
+        return False
+
+
+def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str:
+    """
+    Convert any audio or video file to a normalised PCM WAV using PyAV.
+
+    If the input is already a compliant WAV (correct sample rate, channels, and codec),
+    returns the original path without re-encoding.
+
+    Args:
+        input_path (str): Source audio or video file.
+        sample_rate (int): Target sample rate in Hz.
+        channels (int): Target number of channels (1 = mono).
+
+    Returns:
+        str: Path to the WAV file (original if compliant, otherwise a temporary file).
+    """
+    # Skip conversion if already compliant
+    if _is_compliant_wav(input_path, sample_rate=sample_rate, channels=channels):
+        logger.debug(f"Audio file already compliant, skipping conversion: {input_path}")
+        return input_path
+
+    layout = "mono" if channels == 1 else "stereo"
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+        output_path = tmp.name
+
+    with av.open(input_path) as in_container:
+        with av.open(output_path, "w", format="wav") as out_container:
+            out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate, layout=layout)
+            resampler = av.AudioResampler(format="s16", layout=layout, rate=sample_rate)
+
+            for frame in in_container.decode(audio=0):
+                for out_frame in resampler.resample(frame):
+                    for packet in out_stream.encode(out_frame):
+                        out_container.mux(packet)
+
+            for out_frame in resampler.resample(None):
+                for packet in out_stream.encode(out_frame):
+                    out_container.mux(packet)
+
+            for packet in out_stream.encode(None):
+                out_container.mux(packet)
+
+    return output_path
+
 
 class AudioTranscriptHelper(ABC):  # noqa: B024
     """
     Abstract base class for audio scorers that process audio by transcribing and scoring the text.
@@ -29,12 +101,12 @@ class AudioTranscriptHelper(ABC):  # noqa: B024
     _DEFAULT_SAMPLE_RATE = 16000  # 16kHz - Azure Speech optimal rate
     _DEFAULT_CHANNELS = 1  # Mono - Azure Speech prefers mono
     _DEFAULT_SAMPLE_WIDTH = 2  # 16-bit audio (2 bytes per sample)
-    _DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"]  # 16-bit PCM for best compatibility
 
     def __init__(
         self,
         *,
         text_capable_scorer: Scorer,
+        use_entra_auth: Optional[bool] = None,
     ) -> None:
         """
         Initialize the base audio scorer.
@@ -42,12 +114,15 @@ def __init__(
         Args:
             text_capable_scorer (Scorer): A scorer capable of processing text that will be used
                 to score the transcribed audio content.
+            use_entra_auth (bool, Optional): Whether to use Entra ID authentication for Azure Speech.
+                Defaults to True if None.
 
         Raises:
             ValueError: If text_capable_scorer does not support text data type.
         """
         self._validate_text_scorer(text_capable_scorer)
         self.text_scorer = text_capable_scorer
+        self._use_entra_auth = use_entra_auth if use_entra_auth is not None else True
 
     @staticmethod
     def _validate_text_scorer(scorer: Scorer) -> None:
@@ -149,7 +224,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str:
         logger.info(f"Audio transcription: WAV file size = {file_size} bytes")
 
         try:
-            converter = AzureSpeechAudioToTextConverter()
+            converter = AzureSpeechAudioToTextConverter(use_entra_auth=self._use_entra_auth)
             logger.info("Audio transcription: Starting Azure Speech transcription...")
             result = await converter.convert_async(prompt=wav_path, input_type="audio_path")
             logger.info(f"Audio transcription: Result = '{result.output_text}'")
@@ -171,25 +246,12 @@ def _ensure_wav_format(self, audio_path: str) -> str:
 
         Returns:
             str: Path to WAV file (original if already WAV, or converted temporary file).
-
-        Raises:
-            ModuleNotFoundError: If pydub is not installed.
         """
-        try:
-            from pydub import AudioSegment
-        except ModuleNotFoundError as e:
-            logger.error("Could not import pydub. Install it via 'pip install pydub'")
-            raise e
-
-        audio = AudioSegment.from_file(audio_path)
-        audio = (
-            audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE)
-            .set_channels(self._DEFAULT_CHANNELS)
-            .set_sample_width(self._DEFAULT_SAMPLE_WIDTH)
+        return _audio_to_wav(
+            audio_path,
+            sample_rate=self._DEFAULT_SAMPLE_RATE,
+            channels=self._DEFAULT_CHANNELS,
         )
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
-            audio.export(temp_wav.name, format="wav")
-            return temp_wav.name
 
     def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
         """
@@ -201,9 +263,6 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
 
         Returns:
             str: a path to the extracted audio file (WAV format)
                 or returns None if extraction fails.
-
-        Raises:
-            ModuleNotFoundError: If pydub/ffmpeg is not installed.
         """
         return AudioTranscriptHelper.extract_audio_from_video(video_path)
@@ -218,57 +277,16 @@ def extract_audio_from_video(video_path: str) -> Optional[str]:
 
         Returns:
             str: a path to the extracted audio file (WAV format) or returns None if extraction fails.
-
-        Raises:
-            ModuleNotFoundError: If pydub/ffmpeg is not installed.
         """
         try:
-            from pydub import AudioSegment
-        except ModuleNotFoundError as e:
-            logger.error("Could not import pydub. Install it via 'pip install pydub'")
-            raise e
-
-        try:
-            # Extract audio from video using pydub (requires ffmpeg)
             logger.info(f"Extracting audio from video: {video_path}")
-            audio = AudioSegment.from_file(video_path)
-            logger.info(
-                f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, "
-                f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}"
+            output_path = _audio_to_wav(
+                video_path,
+                sample_rate=AudioTranscriptHelper._DEFAULT_SAMPLE_RATE,
+                channels=AudioTranscriptHelper._DEFAULT_CHANNELS,
             )
-
-            # Optimize for Azure Speech recognition:
-            # Azure Speech works best with 16kHz mono audio (same as Azure TTS output)
-            if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE:
-                logger.info(
-                    f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz"
-                )
-                audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE)
-
-            # Ensure 16-bit audio
-            if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH:
-                logger.info(
-                    f"Converting sample width from {audio.sample_width * 8}-bit"
-                    f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit"
-                )
-                audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH)
-
-            # Convert to mono (Azure Speech prefers mono)
-            if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS:
-                logger.info(f"Converting from {audio.channels} channels to mono")
-                audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS)
-
-            # Create temporary WAV file with PCM encoding for best compatibility
-            with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio:
-                audio.export(
-                    temp_audio.name,
-                    format="wav",
-                    parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS,
-                )
-                logger.info(
-                    f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)"
-                )
-            return temp_audio.name
+            logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)")
+            return output_path
         except Exception as e:
             logger.warning(f"Failed to extract audio from video {video_path}: {e}")
             return None
diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py
index a44988d5a..528784109 100644
--- a/pyrit/score/float_scale/audio_float_scale_scorer.py
+++ b/pyrit/score/float_scale/audio_float_scale_scorer.py
@@ -25,6 +25,7 @@ def __init__(
         *,
         text_capable_scorer: FloatScaleScorer,
         validator: Optional[ScorerPromptValidator] = None,
+        use_entra_auth: Optional[bool] = None,
     ) -> None:
         """
         Initialize the AudioFloatScaleScorer.
@@ -33,12 +34,17 @@ def __init__(
             text_capable_scorer: A FloatScaleScorer capable of processing text. This scorer
                 will be used to evaluate the transcribed audio content.
             validator: Validator for the scorer. Defaults to audio_path data type validator.
+            use_entra_auth: Whether to use Entra ID authentication for Azure Speech.
+                Defaults to True if None.
 
         Raises:
             ValueError: If text_capable_scorer does not support text data type.
         """
         super().__init__(validator=validator or self._default_validator)
-        self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer)
+        self._audio_helper = AudioTranscriptHelper(
+            text_capable_scorer=text_capable_scorer,
+            use_entra_auth=use_entra_auth,
+        )
 
     def _build_identifier(self) -> ComponentIdentifier:
         """
diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py
index 1c7a5de17..0148cc604 100644
--- a/pyrit/score/true_false/audio_true_false_scorer.py
+++ b/pyrit/score/true_false/audio_true_false_scorer.py
@@ -25,6 +25,7 @@ def __init__(
         *,
         text_capable_scorer: TrueFalseScorer,
         validator: Optional[ScorerPromptValidator] = None,
+        use_entra_auth: Optional[bool] = None,
     ) -> None:
         """
         Initialize the AudioTrueFalseScorer.
@@ -33,12 +34,17 @@ def __init__(
             text_capable_scorer: A TrueFalseScorer capable of processing text. This scorer
                 will be used to evaluate the transcribed audio content.
             validator: Validator for the scorer. Defaults to audio_path data type validator.
+            use_entra_auth: Whether to use Entra ID authentication for Azure Speech.
+                Defaults to True if None.
 
         Raises:
             ValueError: If text_capable_scorer does not support text data type.
         """
         super().__init__(validator=validator or self._DEFAULT_VALIDATOR)
-        self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer)
+        self._audio_helper = AudioTranscriptHelper(
+            text_capable_scorer=text_capable_scorer,
+            use_entra_auth=use_entra_auth,
+        )
 
     def _build_identifier(self) -> ComponentIdentifier:
         """
diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py
index 543323aa3..4fdd7ce48 100644
--- a/tests/unit/score/test_audio_scorer.py
+++ b/tests/unit/score/test_audio_scorer.py
@@ -227,3 +227,98 @@ async def test_score_piece_empty_transcript(self, audio_message_piece):
 
         # Empty transcript returns empty list
         assert len(scores) == 0
+
+
+class TestPyAVAudioConversion:
+    """Tests for PyAV audio conversion functions"""
+
+    @pytest.fixture
+    def compliant_wav_file(self):
+        """Create a compliant 16kHz mono PCM WAV file"""
+        import av
+        import numpy as np
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            output_path = tmp.name
+
+        sample_rate = 16000
+        duration = 0.5
+        t = np.linspace(0, duration, int(sample_rate * duration), dtype=np.float32)
+        audio_data = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
+
+        with av.open(output_path, "w", format="wav") as container:
+            stream = container.add_stream("pcm_s16le", rate=sample_rate, layout="mono")
+            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format="s16", layout="mono")
+            frame.rate = sample_rate
+            for packet in stream.encode(frame):
+                container.mux(packet)
+            for packet in stream.encode(None):
+                container.mux(packet)
+
+        yield output_path
+        if os.path.exists(output_path):
+            os.remove(output_path)
+
+    @pytest.fixture
+    def non_compliant_wav_file(self):
+        """Create a 44100Hz mono WAV (wrong sample rate)"""
+        import av
+        import numpy as np
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            output_path = tmp.name
+
+        sample_rate = 44100  # Wrong sample rate
+        duration = 0.5
+        t = np.linspace(0, duration, int(sample_rate * duration), dtype=np.float32)
+        audio_data = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
+
+        with av.open(output_path, "w", format="wav") as container:
+            stream = container.add_stream("pcm_s16le", rate=sample_rate, layout="mono")
+            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format="s16", layout="mono")
+            frame.rate = sample_rate
+            for packet in stream.encode(frame):
+                container.mux(packet)
+            for packet in stream.encode(None):
+                container.mux(packet)
+
+        yield output_path
+        if os.path.exists(output_path):
+            os.remove(output_path)
+
+    def test_is_compliant_wav_true(self, compliant_wav_file):
+        """Test that _is_compliant_wav returns True for compliant files"""
+        from pyrit.score.audio_transcript_scorer import _is_compliant_wav
+
+        assert _is_compliant_wav(compliant_wav_file, sample_rate=16000, channels=1) is True
+
+    def test_is_compliant_wav_false_wrong_rate(self, non_compliant_wav_file):
+        """Test that _is_compliant_wav returns False for wrong sample rate"""
+        from pyrit.score.audio_transcript_scorer import _is_compliant_wav
+
+        assert _is_compliant_wav(non_compliant_wav_file, sample_rate=16000, channels=1) is False
+
+    def test_is_compliant_wav_nonexistent_file(self):
+        """Test that _is_compliant_wav returns False for nonexistent files"""
+        from pyrit.score.audio_transcript_scorer import _is_compliant_wav
+
+        assert _is_compliant_wav("/nonexistent/file.wav", sample_rate=16000, channels=1) is False
+
+    def test_audio_to_wav_returns_original_for_compliant(self, compliant_wav_file):
+        """Test that _audio_to_wav returns the original path for compliant files"""
+        from pyrit.score.audio_transcript_scorer import _audio_to_wav
+
+        result = _audio_to_wav(compliant_wav_file, sample_rate=16000, channels=1)
+        assert result == compliant_wav_file
+
+    def test_audio_to_wav_converts_non_compliant(self, non_compliant_wav_file):
+        """Test that _audio_to_wav converts non-compliant files"""
+        from pyrit.score.audio_transcript_scorer import _audio_to_wav, _is_compliant_wav
+
+        result = _audio_to_wav(non_compliant_wav_file, sample_rate=16000, channels=1)
+        try:
+            assert result != non_compliant_wav_file
+            assert _is_compliant_wav(result, sample_rate=16000, channels=1) is True
+        finally:
+            if result != non_compliant_wav_file and os.path.exists(result):
+                os.remove(result)