Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 67 additions & 61 deletions pyrit/score/audio_transcript_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import logging
import os
import shutil
import subprocess
import tempfile
import uuid
from abc import ABC
Expand All @@ -16,6 +18,16 @@
logger = logging.getLogger(__name__)


def _check_ffmpeg_installed() -> bool:
"""
Check if ffmpeg is installed and available on PATH.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to check this first, but ffmpeg is likely not installed on most machines. @romanlutz let me know what you think, but in my opinion we should use something like PyAV as an "all" dependency. It's relatively big, but we'll likely need a lot of audio/video editing, and that seems like a good approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I suspect consumers may be less excited, but having a million classes that each check whether or not av is installed is silly. I imagine we'll have A LOT of converters that need this over time.

Returns:
bool: True if ffmpeg is installed, False otherwise.
"""
return shutil.which("ffmpeg") is not None


class AudioTranscriptHelper(ABC): # noqa: B024
"""
Abstract base class for audio scorers that process audio by transcribing and scoring the text.
Expand All @@ -29,7 +41,6 @@ class AudioTranscriptHelper(ABC): # noqa: B024
_DEFAULT_SAMPLE_RATE = 16000 # 16kHz - Azure Speech optimal rate
_DEFAULT_CHANNELS = 1 # Mono - Azure Speech prefers mono
_DEFAULT_SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample)
_DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"] # 16-bit PCM for best compatibility

def __init__(
self,
Expand Down Expand Up @@ -173,23 +184,35 @@ def _ensure_wav_format(self, audio_path: str) -> str:
str: Path to WAV file (original if already WAV, or converted temporary file).

Raises:
ModuleNotFoundError: If pydub is not installed.
RuntimeError: If ffmpeg is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e

audio = AudioSegment.from_file(audio_path)
audio = (
audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE)
.set_channels(self._DEFAULT_CHANNELS)
.set_sample_width(self._DEFAULT_SAMPLE_WIDTH)
)
if not _check_ffmpeg_installed():
raise RuntimeError(
"ffmpeg is required for audio processing but was not found on PATH. "
"Install it via: apt install ffmpeg / brew install ffmpeg / "
"https://ffmpeg.org/download.html"
)

with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
audio.export(temp_wav.name, format="wav")
return temp_wav.name
output_path = temp_wav.name
subprocess.run(
[
"ffmpeg",
"-i",
audio_path,
"-ar",
str(self._DEFAULT_SAMPLE_RATE),
"-ac",
str(self._DEFAULT_CHANNELS),
"-acodec",
"pcm_s16le", # 16-bit PCM
output_path,
"-y",
],
check=True,
capture_output=True,
)
return output_path

def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
"""
Expand All @@ -203,7 +226,7 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
RuntimeError: If ffmpeg is not installed.
"""
return AudioTranscriptHelper.extract_audio_from_video(video_path)

Expand All @@ -220,55 +243,38 @@ def extract_audio_from_video(video_path: str) -> Optional[str]:
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
RuntimeError: If ffmpeg is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e
if not _check_ffmpeg_installed():
raise RuntimeError(
"ffmpeg is required for audio processing but was not found on PATH. "
"Install it via: apt install ffmpeg / brew install ffmpeg / "
"https://ffmpeg.org/download.html"
)

try:
# Extract audio from video using pydub (requires ffmpeg)
logger.info(f"Extracting audio from video: {video_path}")
audio = AudioSegment.from_file(video_path)
logger.info(
f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, "
f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}"
)

# Optimize for Azure Speech recognition:
# Azure Speech works best with 16kHz mono audio (same as Azure TTS output)
if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE:
logger.info(
f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz"
)
audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE)

# Ensure 16-bit audio
if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH:
logger.info(
f"Converting sample width from {audio.sample_width * 8}-bit"
f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit"
)
audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH)

# Convert to mono (Azure Speech prefers mono)
if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS:
logger.info(f"Converting from {audio.channels} channels to mono")
audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS)

# Create temporary WAV file with PCM encoding for best compatibility
with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio:
audio.export(
temp_audio.name,
format="wav",
parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS,
)
logger.info(
f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)"
)
return temp_audio.name
output_path = temp_audio.name
subprocess.run(
[
"ffmpeg",
"-i",
video_path,
"-ar",
str(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE),
"-ac",
str(AudioTranscriptHelper._DEFAULT_CHANNELS),
"-acodec",
"pcm_s16le", # 16-bit PCM
output_path,
"-y",
],
check=True,
capture_output=True,
)
logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)")
return output_path
except Exception as e:
logger.warning(f"Failed to extract audio from video {video_path}: {e}")
return None