From a8b1d1c14b86c2bf3347575f956db33777f2b4c5 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 07:12:27 +0100 Subject: [PATCH 1/7] feat(transcribe): add speaker diarization support Add speaker diarization as a post-processing step for transcription using pyannote-audio. This identifies and labels different speakers in the transcript, useful for meetings, interviews, or multi-speaker audio. Features: - New `--diarize` flag to enable speaker diarization - `--diarize-format` option for inline (default) or JSON output - `--hf-token` for HuggingFace authentication (required for pyannote models) - `--min-speakers` and `--max-speakers` hints for improved accuracy - Works with any ASR provider (Wyoming, OpenAI, Gemini) - New optional dependency: `pip install agent-cli[diarization]` Output formats: - Inline: `[SPEAKER_00]: Hello, how are you?` - JSON: structured with speaker, timestamps, and text --- agent_cli/agents/transcribe.py | 94 ++++++- agent_cli/config.py | 13 + agent_cli/core/diarization.py | 209 ++++++++++++++++ agent_cli/opts.py | 33 +++ docs/commands/transcribe.md | 69 ++++++ pyproject.toml | 4 + tests/agents/test_transcribe_recovery.py | 25 ++ tests/test_diarization.py | 302 +++++++++++++++++++++++ 8 files changed, 748 insertions(+), 1 deletion(-) create mode 100644 agent_cli/core/diarization.py create mode 100644 tests/test_diarization.py diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index 7e1bee7e1..8dab7a833 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -256,6 +256,7 @@ async def _async_main( # noqa: PLR0912, PLR0915, C901 audio_file_path: Path | None = None, save_recording: bool = True, process_name: str | None = None, + diarization_cfg: config.Diarization | None = None, ) -> None: """Unified async entry point for both live and file-based transcription.""" start_time = time.monotonic() @@ -336,6 +337,63 @@ async def _async_main( # noqa: PLR0912, PLR0915, C901 elapsed = 
time.monotonic() - start_time + # Apply diarization if enabled + if diarization_cfg and diarization_cfg.diarize and transcript: + # Determine audio file path for diarization + diarize_audio_path = audio_file_path + if not diarize_audio_path and save_recording: + # For live recordings, get the most recently saved file + diarize_audio_path = get_last_recording(1) + + if diarize_audio_path and diarize_audio_path.exists(): + try: + from agent_cli.core.diarization import ( # noqa: PLC0415 + SpeakerDiarizer, + align_transcript_with_speakers, + format_diarized_output, + ) + + if not general_cfg.quiet: + print_with_style("🎙️ Running speaker diarization...", style="blue") + + # hf_token is validated in CLI before calling _async_main + assert diarization_cfg.hf_token is not None + diarizer = SpeakerDiarizer( + hf_token=diarization_cfg.hf_token, + min_speakers=diarization_cfg.min_speakers, + max_speakers=diarization_cfg.max_speakers, + ) + segments = diarizer.diarize(diarize_audio_path) + + if segments: + # Align transcript with speaker segments + segments = align_transcript_with_speakers(transcript, segments) + # Format output + transcript = format_diarized_output( + segments, + output_format=diarization_cfg.diarize_format, + ) + if not general_cfg.quiet: + print_with_style( + f"✅ Identified {len({s.speaker for s in segments})} speaker(s)", + style="green", + ) + else: + LOGGER.warning("Diarization returned no segments") + except ImportError as e: + print_with_style( + f"❌ Diarization failed: {e}", + style="red", + ) + except Exception as e: + LOGGER.exception("Diarization failed") + print_with_style( + f"❌ Diarization error: {e}", + style="red", + ) + else: + LOGGER.warning("No audio file available for diarization") + if llm_enabled and transcript: if not general_cfg.quiet: print_input_panel( @@ -433,7 +491,7 @@ async def _async_main( # noqa: PLR0912, PLR0915, C901 @app.command("transcribe") -def transcribe( # noqa: PLR0912 +def transcribe( # noqa: PLR0912, PLR0911 *, 
extra_instructions: str | None = typer.Option( None, @@ -478,6 +536,12 @@ def transcribe( # noqa: PLR0912 config_file: str | None = opts.CONFIG_FILE, print_args: bool = opts.PRINT_ARGS, transcription_log: Path | None = opts.TRANSCRIPTION_LOG, + # --- Diarization Options --- + diarize: bool = opts.DIARIZE, + diarize_format: str = opts.DIARIZE_FORMAT, + hf_token: str | None = opts.HF_TOKEN, + min_speakers: int | None = opts.MIN_SPEAKERS, + max_speakers: int | None = opts.MAX_SPEAKERS, ) -> None: """Wyoming ASR Client for streaming microphone audio to a transcription server.""" if print_args: @@ -488,6 +552,32 @@ def transcribe( # noqa: PLR0912 if transcription_log: transcription_log = transcription_log.expanduser() + # Validate diarization options + if diarize: + if not hf_token: + print_with_style( + "❌ --hf-token required for diarization. " + "Set HF_TOKEN env var or pass --hf-token. " + "Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1", + style="red", + ) + return + if not save_recording and not from_file and last_recording == 0: + print_with_style( + "❌ Diarization requires audio file. 
Use --save-recording (default) " + "or --from-file/--last-recording.", + style="red", + ) + return + + diarization_cfg = config.Diarization( + diarize=diarize, + diarize_format=diarize_format, + hf_token=hf_token, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + # Handle recovery options if last_recording and from_file: print_with_style("❌ Cannot use both --last-recording and --from-file", style="red") @@ -576,6 +666,7 @@ def transcribe( # noqa: PLR0912 gemini_llm_cfg=gemini_llm_cfg, llm_enabled=llm, transcription_log=transcription_log, + diarization_cfg=diarization_cfg, ), ) return @@ -622,5 +713,6 @@ def transcribe( # noqa: PLR0912 transcription_log=transcription_log, save_recording=save_recording, process_name=process_name, + diarization_cfg=diarization_cfg, ), ) diff --git a/agent_cli/config.py b/agent_cli/config.py index 65c078dfa..d938403cb 100644 --- a/agent_cli/config.py +++ b/agent_cli/config.py @@ -224,6 +224,19 @@ def _expand_user_path(cls, v: str | None) -> Path | None: return None +# --- Panel: Diarization Options --- + + +class Diarization(BaseModel): + """Configuration for speaker diarization.""" + + diarize: bool = False + diarize_format: str = "inline" + hf_token: str | None = None + min_speakers: int | None = None + max_speakers: int | None = None + + def _config_path(config_path_str: str | None = None) -> Path | None: """Return a usable config path, expanding user directories.""" if config_path_str: diff --git a/agent_cli/core/diarization.py b/agent_cli/core/diarization.py new file mode 100644 index 000000000..3a9060ed7 --- /dev/null +++ b/agent_cli/core/diarization.py @@ -0,0 +1,209 @@ +"""Speaker diarization using pyannote-audio.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path # noqa: TC003 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyannote.core import Annotation + + +def _check_pyannote_installed() -> None: + """Check if pyannote-audio is 
installed, raise ImportError with helpful message if not.""" + try: + import pyannote.audio # noqa: F401, PLC0415 + except ImportError as e: + msg = ( + "pyannote-audio is required for speaker diarization. " + "Install it with: `pip install agent-cli[diarization]` or `uv sync --extra diarization`." + ) + raise ImportError(msg) from e + + +@dataclass +class DiarizedSegment: + """A segment of speech attributed to a specific speaker.""" + + speaker: str + start: float + end: float + text: str = "" + + +class SpeakerDiarizer: + """Wrapper for pyannote speaker diarization pipeline. + + Requires a HuggingFace token with access to pyannote/speaker-diarization-3.1. + Users must accept the license at: https://huggingface.co/pyannote/speaker-diarization-3.1 + """ + + def __init__( + self, + hf_token: str, + min_speakers: int | None = None, + max_speakers: int | None = None, + ) -> None: + """Initialize the diarization pipeline. + + Args: + hf_token: HuggingFace token for accessing pyannote models. + min_speakers: Minimum number of speakers (optional hint). + max_speakers: Maximum number of speakers (optional hint). + + """ + _check_pyannote_installed() + from pyannote.audio import Pipeline # noqa: PLC0415 + + self.pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=hf_token, + ) + self.min_speakers = min_speakers + self.max_speakers = max_speakers + + def diarize(self, audio_path: Path) -> list[DiarizedSegment]: + """Run diarization on audio file, return speaker segments. + + Args: + audio_path: Path to the audio file (WAV format recommended). + + Returns: + List of DiarizedSegment with speaker labels and timestamps. 
+ + """ + # Build kwargs for speaker count hints + kwargs: dict[str, int] = {} + if self.min_speakers is not None: + kwargs["min_speakers"] = self.min_speakers + if self.max_speakers is not None: + kwargs["max_speakers"] = self.max_speakers + + # Run the pipeline + diarization: Annotation = self.pipeline(str(audio_path), **kwargs) + + # Convert to our dataclass format + segments: list[DiarizedSegment] = [] + for turn, _, speaker in diarization.itertracks(yield_label=True): + segments.append( + DiarizedSegment( + speaker=speaker, + start=turn.start, + end=turn.end, + ), + ) + + return segments + + +def align_transcript_with_speakers( + transcript: str, + segments: list[DiarizedSegment], +) -> list[DiarizedSegment]: + """Align transcript text with speaker segments using simple word distribution. + + This is a basic alignment that distributes words proportionally based on + segment duration. For more accurate word-level alignment, consider using + WhisperX or similar tools. + + Args: + transcript: The full transcript text. + segments: List of speaker segments with timestamps. + + Returns: + List of DiarizedSegment with text filled in. 
+ + """ + if not segments: + return segments + + words = transcript.split() + if not words: + return segments + + # Calculate total duration + total_duration = sum(seg.end - seg.start for seg in segments) + if total_duration <= 0: + # Fallback: distribute words evenly + words_per_segment = len(words) // len(segments) + result = [] + word_idx = 0 + for i, seg in enumerate(segments): + # Last segment gets remaining words + if i == len(segments) - 1: + seg_words = words[word_idx:] + else: + seg_words = words[word_idx : word_idx + words_per_segment] + word_idx += words_per_segment + result.append( + DiarizedSegment( + speaker=seg.speaker, + start=seg.start, + end=seg.end, + text=" ".join(seg_words), + ), + ) + return result + + # Distribute words based on segment duration + result = [] + word_idx = 0 + for i, seg in enumerate(segments): + seg_duration = seg.end - seg.start + # Calculate proportion of words for this segment + if i == len(segments) - 1: + # Last segment gets all remaining words + seg_words = words[word_idx:] + else: + proportion = seg_duration / total_duration + word_count = max(1, round(proportion * len(words))) + seg_words = words[word_idx : word_idx + word_count] + word_idx += word_count + # Adjust total_duration for remaining segments + total_duration -= seg_duration + + result.append( + DiarizedSegment( + speaker=seg.speaker, + start=seg.start, + end=seg.end, + text=" ".join(seg_words), + ), + ) + + return result + + +def format_diarized_output( + segments: list[DiarizedSegment], + output_format: str = "inline", +) -> str: + """Format diarized segments for output. + + Args: + segments: List of DiarizedSegment with speaker labels and text. + output_format: "inline" for human-readable, "json" for structured output. + + Returns: + Formatted string representation of the diarized transcript. 
+ + """ + if output_format == "json": + data = { + "segments": [ + { + "speaker": seg.speaker, + "start": round(seg.start, 2), + "end": round(seg.end, 2), + "text": seg.text, + } + for seg in segments + ], + } + return json.dumps(data, indent=2) + + # Inline format: [Speaker X]: text + lines = [f"[{seg.speaker}]: {seg.text}" for seg in segments if seg.text] + return "\n".join(lines) diff --git a/agent_cli/opts.py b/agent_cli/opts.py index af1573f8d..b73b3842d 100644 --- a/agent_cli/opts.py +++ b/agent_cli/opts.py @@ -408,3 +408,36 @@ def _conf_callback(ctx: typer.Context, param: typer.CallbackParam, value: str) - help="Save the audio recording to disk for recovery.", rich_help_panel="Audio Recovery", ) + +# --- Diarization Options --- +DIARIZE: bool = typer.Option( + False, # noqa: FBT003 + "--diarize/--no-diarize", + help="Enable speaker diarization (requires pyannote-audio). Install with: pip install agent-cli[diarization]", + rich_help_panel="Diarization", +) +DIARIZE_FORMAT: str = typer.Option( + "inline", + "--diarize-format", + help="Output format for diarization ('inline' for [Speaker N]: text, 'json' for structured output).", + rich_help_panel="Diarization", +) +HF_TOKEN: str | None = typer.Option( + None, + "--hf-token", + help="HuggingFace token for pyannote models. Required for diarization. 
Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1", + envvar="HF_TOKEN", + rich_help_panel="Diarization", +) +MIN_SPEAKERS: int | None = typer.Option( + None, + "--min-speakers", + help="Minimum number of speakers (optional hint for diarization).", + rich_help_panel="Diarization", +) +MAX_SPEAKERS: int | None = typer.Option( + None, + "--max-speakers", + help="Maximum number of speakers (optional hint for diarization).", + rich_help_panel="Diarization", +) diff --git a/docs/commands/transcribe.md b/docs/commands/transcribe.md index e5252f5b7..ba4f3001a 100644 --- a/docs/commands/transcribe.md +++ b/docs/commands/transcribe.md @@ -45,6 +45,15 @@ agent-cli transcribe --from-file voice_memo.m4a --asr-provider gemini # Re-transcribe most recent recording agent-cli transcribe --last-recording 1 + +# Transcribe with speaker diarization (identifies different speakers) +agent-cli transcribe --diarize --hf-token YOUR_HF_TOKEN + +# Diarization with JSON output format +agent-cli transcribe --diarize --diarize-format json --hf-token YOUR_HF_TOKEN + +# Diarize a file with known number of speakers +agent-cli transcribe --from-file meeting.wav --diarize --min-speakers 2 --max-speakers 4 --hf-token YOUR_HF_TOKEN ``` ## Supported Audio Formats @@ -161,6 +170,16 @@ The `--from-file` option supports multiple audio formats: | `--print-args` | `false` | Print the command line arguments, including variables taken from the configuration file. | | `--transcription-log` | - | Path to log transcription results with timestamps, hostname, model, and raw output. | +### Diarization + +| Option | Default | Description | +|--------|---------|-------------| +| `--diarize/--no-diarize` | `false` | Enable speaker diarization (requires pyannote-audio). Install with: pip install agent-cli[diarization] | +| `--diarize-format` | `inline` | Output format for diarization ('inline' for [Speaker N]: text, 'json' for structured output). 
| +| `--hf-token` | - | HuggingFace token for pyannote models. Required for diarization. Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1 | +| `--min-speakers` | - | Minimum number of speakers (optional hint for diarization). | +| `--max-speakers` | - | Maximum number of speakers (optional hint for diarization). | + @@ -197,3 +216,53 @@ agent-cli transcribe --transcription-log ~/.config/agent-cli/transcriptions.log - Use `--list-devices` to find your microphone's index - Enable `--llm` for cleaner output with proper punctuation - Use `--last-recording 1` to re-transcribe if you need to adjust settings + +## Speaker Diarization + +Speaker diarization identifies and labels different speakers in the transcript. This is useful for meeting recordings, interviews, or any multi-speaker audio. + +### Requirements + +1. **Install the diarization extra**: + ```bash + pip install agent-cli[diarization] + # or with uv + uv sync --extra diarization + ``` + +2. **HuggingFace token**: The pyannote-audio models are gated. You need to: + - Accept the license at [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) + - Get your token from [HuggingFace settings](https://huggingface.co/settings/tokens) + - Provide it via `--hf-token` or the `HF_TOKEN` environment variable + +### Output Formats + +**Inline format** (default): +``` +[SPEAKER_00]: Hello, how are you today? +[SPEAKER_01]: I'm doing well, thanks for asking! +[SPEAKER_00]: Great to hear. 
+``` + +**JSON format** (`--diarize-format json`): +```json +{ + "segments": [ + {"speaker": "SPEAKER_00", "start": 0.0, "end": 2.5, "text": "Hello, how are you today?"}, + {"speaker": "SPEAKER_01", "start": 2.7, "end": 4.1, "text": "I'm doing well, thanks for asking!"}, + {"speaker": "SPEAKER_00", "start": 4.3, "end": 5.2, "text": "Great to hear."} + ] +} +``` + +### Speaker Hints + +If you know how many speakers are in the recording, use `--min-speakers` and `--max-speakers` to improve accuracy: + +```bash +# For a two-person interview +agent-cli transcribe --from-file interview.wav --diarize --min-speakers 2 --max-speakers 2 --hf-token YOUR_TOKEN +``` + +> [!NOTE] +> Diarization requires the audio file to be saved. When using live recording with `--diarize`, ensure `--save-recording` is enabled (it's enabled by default). diff --git a/pyproject.toml b/pyproject.toml index 3967e6954..339f97fa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,10 @@ memory = [ vad = [ "silero-vad>=5.1", ] +diarization = [ + "pyannote-audio>=3.3", + "torch>=2.0", +] test = [ "pytest>=7.0.0", "pytest-asyncio>=0.20.0", diff --git a/tests/agents/test_transcribe_recovery.py b/tests/agents/test_transcribe_recovery.py index f8e880e47..038a38d1f 100644 --- a/tests/agents/test_transcribe_recovery.py +++ b/tests/agents/test_transcribe_recovery.py @@ -470,6 +470,11 @@ def test_transcribe_command_last_recording_option( config_file=None, print_args=False, transcription_log=None, + diarize=False, + diarize_format="inline", + hf_token=None, + min_speakers=None, + max_speakers=None, ) # Verify _async_main_from_file was called @@ -526,6 +531,11 @@ def test_transcribe_command_from_file_option(tmp_path: Path): config_file=None, print_args=False, transcription_log=None, + diarize=False, + diarize_format="inline", + hf_token=None, + min_speakers=None, + max_speakers=None, ) # Verify _async_main_from_file was called with the right file @@ -594,6 +604,11 @@ def 
test_transcribe_command_last_recording_with_index( config_file=None, print_args=False, transcription_log=None, + diarize=False, + diarize_format="inline", + hf_token=None, + min_speakers=None, + max_speakers=None, ) # Verify _async_main_from_file was called @@ -660,6 +675,11 @@ def test_transcribe_command_last_recording_disabled( config_file=None, print_args=False, transcription_log=None, + diarize=False, + diarize_format="inline", + hf_token=None, + min_speakers=None, + max_speakers=None, ) # Verify _async_main was called for normal recording (not from file) @@ -709,6 +729,11 @@ def test_transcribe_command_conflicting_options() -> None: config_file=None, print_args=False, transcription_log=None, + diarize=False, + diarize_format="inline", + hf_token=None, + min_speakers=None, + max_speakers=None, ) # Verify error message diff --git a/tests/test_diarization.py b/tests/test_diarization.py new file mode 100644 index 000000000..b628b1e46 --- /dev/null +++ b/tests/test_diarization.py @@ -0,0 +1,302 @@ +"""Tests for the speaker diarization module.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +from agent_cli.core.diarization import ( + DiarizedSegment, + align_transcript_with_speakers, + format_diarized_output, +) + +if TYPE_CHECKING: + from pathlib import Path + + +class TestDiarizedSegment: + """Tests for the DiarizedSegment dataclass.""" + + def test_create_segment(self): + """Test creating a diarized segment.""" + segment = DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.5, text="Hello") + assert segment.speaker == "SPEAKER_00" + assert segment.start == 0.0 + assert segment.end == 2.5 + assert segment.text == "Hello" + + def test_segment_default_text(self): + """Test that text defaults to empty string.""" + segment = DiarizedSegment(speaker="SPEAKER_01", start=1.0, end=3.0) + assert segment.text == "" + + +class TestAlignTranscriptWithSpeakers: + 
"""Tests for the align_transcript_with_speakers function.""" + + def test_empty_segments(self): + """Test with empty segment list.""" + result = align_transcript_with_speakers("Hello world", []) + assert result == [] + + def test_empty_transcript(self): + """Test with empty transcript.""" + segments = [DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.0)] + result = align_transcript_with_speakers("", segments) + assert len(result) == 1 + assert result[0].text == "" + + def test_single_segment(self): + """Test alignment with a single segment.""" + segments = [DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=5.0)] + result = align_transcript_with_speakers("Hello world", segments) + assert len(result) == 1 + assert result[0].text == "Hello world" + assert result[0].speaker == "SPEAKER_00" + + def test_multiple_segments_proportional(self): + """Test word distribution based on segment duration.""" + segments = [ + DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.0), # 2s + DiarizedSegment(speaker="SPEAKER_01", start=2.0, end=4.0), # 2s + ] + result = align_transcript_with_speakers("one two three four", segments) + assert len(result) == 2 + # With equal durations, words should be split roughly evenly + # Last segment gets remaining words + assert result[0].speaker == "SPEAKER_00" + assert result[1].speaker == "SPEAKER_01" + # Total words should equal original + all_words = result[0].text.split() + result[1].text.split() + assert all_words == ["one", "two", "three", "four"] + + def test_zero_duration_fallback(self): + """Test fallback when total duration is zero.""" + segments = [ + DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=0.0), + DiarizedSegment(speaker="SPEAKER_01", start=0.0, end=0.0), + ] + result = align_transcript_with_speakers("one two three four", segments) + assert len(result) == 2 + # Words should be distributed evenly + all_words = result[0].text.split() + result[1].text.split() + assert all_words == ["one", "two", "three", "four"] + 
+ +class TestFormatDiarizedOutput: + """Tests for the format_diarized_output function.""" + + def test_inline_format(self): + """Test inline format output.""" + segments = [ + DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.0, text="Hello"), + DiarizedSegment(speaker="SPEAKER_01", start=2.0, end=4.0, text="Hi there"), + ] + result = format_diarized_output(segments, output_format="inline") + expected = "[SPEAKER_00]: Hello\n[SPEAKER_01]: Hi there" + assert result == expected + + def test_inline_skips_empty_text(self): + """Test that inline format skips segments with empty text.""" + segments = [ + DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.0, text="Hello"), + DiarizedSegment(speaker="SPEAKER_01", start=2.0, end=4.0, text=""), + DiarizedSegment(speaker="SPEAKER_00", start=4.0, end=6.0, text="Goodbye"), + ] + result = format_diarized_output(segments, output_format="inline") + expected = "[SPEAKER_00]: Hello\n[SPEAKER_00]: Goodbye" + assert result == expected + + def test_json_format(self): + """Test JSON format output.""" + segments = [ + DiarizedSegment(speaker="SPEAKER_00", start=0.0, end=2.5, text="Hello"), + DiarizedSegment(speaker="SPEAKER_01", start=2.7, end=4.1, text="Hi there"), + ] + result = format_diarized_output(segments, output_format="json") + parsed = json.loads(result) + assert "segments" in parsed + assert len(parsed["segments"]) == 2 + assert parsed["segments"][0]["speaker"] == "SPEAKER_00" + assert parsed["segments"][0]["start"] == 0.0 + assert parsed["segments"][0]["end"] == 2.5 + assert parsed["segments"][0]["text"] == "Hello" + assert parsed["segments"][1]["speaker"] == "SPEAKER_01" + assert parsed["segments"][1]["start"] == 2.7 + assert parsed["segments"][1]["end"] == 4.1 + assert parsed["segments"][1]["text"] == "Hi there" + + def test_json_rounds_timestamps(self): + """Test that JSON format rounds timestamps to 2 decimal places.""" + segments = [ + DiarizedSegment( + speaker="SPEAKER_00", + start=0.123456, + end=2.987654, + 
text="Hello", + ), + ] + result = format_diarized_output(segments, output_format="json") + parsed = json.loads(result) + assert parsed["segments"][0]["start"] == 0.12 + assert parsed["segments"][0]["end"] == 2.99 + + def test_empty_segments(self): + """Test with empty segment list.""" + result_inline = format_diarized_output([], output_format="inline") + result_json = format_diarized_output([], output_format="json") + assert result_inline == "" + parsed = json.loads(result_json) + assert parsed["segments"] == [] + + +class TestCheckPyannoteInstalled: + """Tests for the pyannote installation check.""" + + def test_check_raises_when_not_installed(self): + """Test that ImportError is raised when pyannote is not installed.""" + from agent_cli.core.diarization import _check_pyannote_installed # noqa: PLC0415 + + with ( + patch.dict("sys.modules", {"pyannote.audio": None}), + patch( + "builtins.__import__", + side_effect=ImportError("No module named 'pyannote'"), + ), + pytest.raises(ImportError) as exc_info, + ): + _check_pyannote_installed() + assert "pyannote-audio is required" in str(exc_info.value) + assert "pip install agent-cli[diarization]" in str(exc_info.value) + + +class TestSpeakerDiarizer: + """Tests for the SpeakerDiarizer class.""" + + def test_diarizer_init_without_pyannote(self): + """Test that SpeakerDiarizer raises ImportError when pyannote not installed.""" + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 + + with ( + patch( + "agent_cli.core.diarization._check_pyannote_installed", + side_effect=ImportError("pyannote-audio is required"), + ), + pytest.raises(ImportError), + ): + SpeakerDiarizer(hf_token="test_token") # noqa: S106 + + def test_diarizer_init_with_mock_pyannote(self): + """Test SpeakerDiarizer initialization with mocked pyannote.""" + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 + + mock_pipeline = MagicMock() + mock_pipeline_class = MagicMock() + 
mock_pipeline_class.from_pretrained.return_value = mock_pipeline + + with ( + patch( + "agent_cli.core.diarization._check_pyannote_installed", + ), + patch.dict( + "sys.modules", + {"pyannote.audio": MagicMock(Pipeline=mock_pipeline_class)}, + ), + ): + diarizer = SpeakerDiarizer( + hf_token="test_token", # noqa: S106 + min_speakers=2, + max_speakers=4, + ) + assert diarizer.min_speakers == 2 + assert diarizer.max_speakers == 4 + mock_pipeline_class.from_pretrained.assert_called_once_with( + "pyannote/speaker-diarization-3.1", + use_auth_token="test_token", # noqa: S106 + ) + + def test_diarizer_diarize(self, tmp_path: Path): + """Test diarization with mocked pipeline.""" + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 + + # Create a mock diarization result + mock_turn1 = MagicMock() + mock_turn1.start = 0.0 + mock_turn1.end = 2.5 + mock_turn2 = MagicMock() + mock_turn2.start = 2.5 + mock_turn2.end = 5.0 + + mock_annotation = MagicMock() + mock_annotation.itertracks.return_value = [ + (mock_turn1, None, "SPEAKER_00"), + (mock_turn2, None, "SPEAKER_01"), + ] + + mock_pipeline = MagicMock() + mock_pipeline.return_value = mock_annotation + + mock_pipeline_class = MagicMock() + mock_pipeline_class.from_pretrained.return_value = mock_pipeline + + with ( + patch("agent_cli.core.diarization._check_pyannote_installed"), + patch.dict( + "sys.modules", + {"pyannote.audio": MagicMock(Pipeline=mock_pipeline_class)}, + ), + ): + diarizer = SpeakerDiarizer(hf_token="test_token") # noqa: S106 + audio_file = tmp_path / "test.wav" + audio_file.touch() + + segments = diarizer.diarize(audio_file) + + assert len(segments) == 2 + assert segments[0].speaker == "SPEAKER_00" + assert segments[0].start == 0.0 + assert segments[0].end == 2.5 + assert segments[1].speaker == "SPEAKER_01" + assert segments[1].start == 2.5 + assert segments[1].end == 5.0 + mock_pipeline.assert_called_once_with(str(audio_file)) + + def test_diarizer_diarize_with_speaker_hints(self, 
tmp_path: Path): + """Test diarization passes speaker hints to pipeline.""" + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 + + mock_annotation = MagicMock() + mock_annotation.itertracks.return_value = [] + + mock_pipeline = MagicMock() + mock_pipeline.return_value = mock_annotation + + mock_pipeline_class = MagicMock() + mock_pipeline_class.from_pretrained.return_value = mock_pipeline + + with ( + patch("agent_cli.core.diarization._check_pyannote_installed"), + patch.dict( + "sys.modules", + {"pyannote.audio": MagicMock(Pipeline=mock_pipeline_class)}, + ), + ): + diarizer = SpeakerDiarizer( + hf_token="test_token", # noqa: S106 + min_speakers=2, + max_speakers=4, + ) + audio_file = tmp_path / "test.wav" + audio_file.touch() + + diarizer.diarize(audio_file) + + mock_pipeline.assert_called_once_with( + str(audio_file), + min_speakers=2, + max_speakers=4, + ) From be3ad09655687cfd9673b451ed7829b4c6b1e811 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:07:33 +0100 Subject: [PATCH 2/7] chore: let pyannote-audio manage torch dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 339f97fa0..7be90a8e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ vad = [ ] diarization = [ "pyannote-audio>=3.3", - "torch>=2.0", ] test = [ "pytest>=7.0.0", From 07d722b21052408be3f261960c27fe830a17190c Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:15:59 +0100 Subject: [PATCH 3/7] fix: use 'token' instead of deprecated 'use_auth_token' for pyannote --- agent_cli/core/diarization.py | 2 +- tests/test_diarization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agent_cli/core/diarization.py b/agent_cli/core/diarization.py index 3a9060ed7..9d3b0d1a1 100644 --- a/agent_cli/core/diarization.py +++ b/agent_cli/core/diarization.py @@ -59,7 +59,7 @@ def __init__( self.pipeline = Pipeline.from_pretrained( 
"pyannote/speaker-diarization-3.1", - use_auth_token=hf_token, + token=hf_token, ) self.min_speakers = min_speakers self.max_speakers = max_speakers diff --git a/tests/test_diarization.py b/tests/test_diarization.py index b628b1e46..276f040eb 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -216,7 +216,7 @@ def test_diarizer_init_with_mock_pyannote(self): assert diarizer.max_speakers == 4 mock_pipeline_class.from_pretrained.assert_called_once_with( "pyannote/speaker-diarization-3.1", - use_auth_token="test_token", # noqa: S106 + token="test_token", # noqa: S106 ) def test_diarizer_diarize(self, tmp_path: Path): From ecea29073178dff47a9ea27d4503660b472e11e1 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:20:55 +0100 Subject: [PATCH 4/7] docs: add all required model licenses and token permission info --- agent_cli/agents/transcribe.py | 4 +++- agent_cli/opts.py | 7 ++++++- docs/commands/transcribe.md | 6 +++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index 8dab7a833..62ab5d1f2 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -558,7 +558,9 @@ def transcribe( # noqa: PLR0912, PLR0911 print_with_style( "❌ --hf-token required for diarization. " "Set HF_TOKEN env var or pass --hf-token. " - "Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1", + "Token must have 'Read access to contents of all public gated repos you can access' permission. 
" + "Accept licenses at: https://hf.co/pyannote/speaker-diarization-3.1, " + "https://hf.co/pyannote/segmentation-3.0, https://hf.co/pyannote/wespeaker-voxceleb-resnet34-LM", style="red", ) return diff --git a/agent_cli/opts.py b/agent_cli/opts.py index b73b3842d..5f3460ce0 100644 --- a/agent_cli/opts.py +++ b/agent_cli/opts.py @@ -425,7 +425,12 @@ def _conf_callback(ctx: typer.Context, param: typer.CallbackParam, value: str) - HF_TOKEN: str | None = typer.Option( None, "--hf-token", - help="HuggingFace token for pyannote models. Required for diarization. Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1", + help=( + "HuggingFace token for pyannote models. Required for diarization. " + "Token must have 'Read access to contents of all public gated repos you can access' permission. " + "Accept licenses at: https://hf.co/pyannote/speaker-diarization-3.1, " + "https://hf.co/pyannote/segmentation-3.0, https://hf.co/pyannote/wespeaker-voxceleb-resnet34-LM" + ), envvar="HF_TOKEN", rich_help_panel="Diarization", ) diff --git a/docs/commands/transcribe.md b/docs/commands/transcribe.md index ba4f3001a..0ec659c92 100644 --- a/docs/commands/transcribe.md +++ b/docs/commands/transcribe.md @@ -231,8 +231,12 @@ Speaker diarization identifies and labels different speakers in the transcript. ``` 2. **HuggingFace token**: The pyannote-audio models are gated. 
You need to: - - Accept the license at [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) + - Accept the license for all three models: + - [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) + - [pyannote/segmentation-3.0](https://huggingface.co/pyannote/segmentation-3.0) + - [pyannote/wespeaker-voxceleb-resnet34-LM](https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM) - Get your token from [HuggingFace settings](https://huggingface.co/settings/tokens) + - Token must have **"Read access to contents of all public gated repos you can access"** permission - Provide it via `--hf-token` or the `HF_TOKEN` environment variable ### Output Formats From 441c6dc1b499b4bec709b32267e132ed9a453e37 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:25:54 +0100 Subject: [PATCH 5/7] fix: pre-load audio with torchaudio to avoid torchcodec/FFmpeg issues --- agent_cli/core/diarization.py | 8 +++++++- tests/test_diarization.py | 30 ++++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/agent_cli/core/diarization.py b/agent_cli/core/diarization.py index 9d3b0d1a1..968b7f94f 100644 --- a/agent_cli/core/diarization.py +++ b/agent_cli/core/diarization.py @@ -74,6 +74,8 @@ def diarize(self, audio_path: Path) -> list[DiarizedSegment]: List of DiarizedSegment with speaker labels and timestamps. 
""" + import torchaudio # noqa: PLC0415 + # Build kwargs for speaker count hints kwargs: dict[str, int] = {} if self.min_speakers is not None: @@ -81,8 +83,12 @@ def diarize(self, audio_path: Path) -> list[DiarizedSegment]: if self.max_speakers is not None: kwargs["max_speakers"] = self.max_speakers + # Pre-load audio to avoid torchcodec/FFmpeg issues + waveform, sample_rate = torchaudio.load(str(audio_path)) + audio_input = {"waveform": waveform, "sample_rate": sample_rate} + # Run the pipeline - diarization: Annotation = self.pipeline(str(audio_path), **kwargs) + diarization: Annotation = self.pipeline(audio_input, **kwargs) # Convert to our dataclass format segments: list[DiarizedSegment] = [] diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 276f040eb..77f783fe2 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -221,6 +221,8 @@ def test_diarizer_init_with_mock_pyannote(self): def test_diarizer_diarize(self, tmp_path: Path): """Test diarization with mocked pipeline.""" + import torch # noqa: PLC0415 + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 # Create a mock diarization result @@ -243,12 +245,17 @@ def test_diarizer_diarize(self, tmp_path: Path): mock_pipeline_class = MagicMock() mock_pipeline_class.from_pretrained.return_value = mock_pipeline + # Mock torchaudio.load + mock_waveform = torch.zeros(1, 16000) + mock_sample_rate = 16000 + with ( patch("agent_cli.core.diarization._check_pyannote_installed"), patch.dict( "sys.modules", {"pyannote.audio": MagicMock(Pipeline=mock_pipeline_class)}, ), + patch("torchaudio.load", return_value=(mock_waveform, mock_sample_rate)), ): diarizer = SpeakerDiarizer(hf_token="test_token") # noqa: S106 audio_file = tmp_path / "test.wav" @@ -263,10 +270,16 @@ def test_diarizer_diarize(self, tmp_path: Path): assert segments[1].speaker == "SPEAKER_01" assert segments[1].start == 2.5 assert segments[1].end == 5.0 - 
mock_pipeline.assert_called_once_with(str(audio_file)) + # Pipeline should be called with audio dict, not file path + mock_pipeline.assert_called_once() + call_args = mock_pipeline.call_args[0][0] + assert "waveform" in call_args + assert "sample_rate" in call_args def test_diarizer_diarize_with_speaker_hints(self, tmp_path: Path): """Test diarization passes speaker hints to pipeline.""" + import torch # noqa: PLC0415 + from agent_cli.core.diarization import SpeakerDiarizer # noqa: PLC0415 mock_annotation = MagicMock() @@ -278,12 +291,17 @@ def test_diarizer_diarize_with_speaker_hints(self, tmp_path: Path): mock_pipeline_class = MagicMock() mock_pipeline_class.from_pretrained.return_value = mock_pipeline + # Mock torchaudio.load + mock_waveform = torch.zeros(1, 16000) + mock_sample_rate = 16000 + with ( patch("agent_cli.core.diarization._check_pyannote_installed"), patch.dict( "sys.modules", {"pyannote.audio": MagicMock(Pipeline=mock_pipeline_class)}, ), + patch("torchaudio.load", return_value=(mock_waveform, mock_sample_rate)), ): diarizer = SpeakerDiarizer( hf_token="test_token", # noqa: S106 @@ -295,8 +313,8 @@ def test_diarizer_diarize_with_speaker_hints(self, tmp_path: Path): diarizer.diarize(audio_file) - mock_pipeline.assert_called_once_with( - str(audio_file), - min_speakers=2, - max_speakers=4, - ) + # Check speaker hints were passed + mock_pipeline.assert_called_once() + call_kwargs = mock_pipeline.call_args[1] + assert call_kwargs["min_speakers"] == 2 + assert call_kwargs["max_speakers"] == 4 From 0465090ee14026c71ea79773c87679ed12ae8a4b Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:27:25 +0100 Subject: [PATCH 6/7] fix: handle new DiarizeOutput API from pyannote-audio --- agent_cli/core/diarization.py | 10 +++++++++- tests/test_diarization.py | 6 +++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/agent_cli/core/diarization.py b/agent_cli/core/diarization.py index 968b7f94f..ce688654f 100644 --- 
a/agent_cli/core/diarization.py +++ b/agent_cli/core/diarization.py @@ -88,7 +88,15 @@ def diarize(self, audio_path: Path) -> list[DiarizedSegment]: audio_input = {"waveform": waveform, "sample_rate": sample_rate} # Run the pipeline - diarization: Annotation = self.pipeline(audio_input, **kwargs) + output = self.pipeline(audio_input, **kwargs) + + # Handle both old (Annotation) and new (DiarizeOutput) API + if hasattr(output, "speaker_diarization"): + # New API: DiarizeOutput dataclass + diarization: Annotation = output.speaker_diarization + else: + # Old API: returns Annotation directly + diarization = output # Convert to our dataclass format segments: list[DiarizedSegment] = [] diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 77f783fe2..772501e1a 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -239,8 +239,12 @@ def test_diarizer_diarize(self, tmp_path: Path): (mock_turn2, None, "SPEAKER_01"), ] + # Mock the legacy pipeline output (old API: Annotation returned directly) - empty spec avoids auto-creating a speaker_diarization attribute + mock_output = MagicMock(spec=[]) # Empty spec means hasattr returns False + mock_output.itertracks = mock_annotation.itertracks + mock_pipeline = MagicMock() - mock_pipeline.return_value = mock_annotation + mock_pipeline.return_value = mock_output mock_pipeline_class = MagicMock() mock_pipeline_class.from_pretrained.return_value = mock_pipeline From 17fd7bf1ee9a253692d527fbf5472b5d93af45bf Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Sat, 10 Jan 2026 22:40:40 +0100 Subject: [PATCH 7/7] fix: show all required model URLs on gated repo access error --- agent_cli/agents/transcribe.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index 62ab5d1f2..9a2da71ed 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -387,10 +387,24 @@ async def _async_main( # noqa: PLR0912, PLR0915, C901 ) except Exception as e: 
LOGGER.exception("Diarization failed") - print_with_style( - f"❌ Diarization error: {e}", - style="red", - ) + error_msg = str(e) + # Check if it's a gated repo access error + if "403" in error_msg or "gated" in error_msg.lower(): + print_with_style( + "❌ Diarization failed: HuggingFace model access denied.\n" + "Accept licenses for ALL required models:\n" + " • https://hf.co/pyannote/speaker-diarization-3.1\n" + " • https://hf.co/pyannote/segmentation-3.0\n" + " • https://hf.co/pyannote/wespeaker-voxceleb-resnet34-LM\n" + " • https://hf.co/pyannote/speaker-diarization-community-1\n" + "Token must have 'Read access to public gated repos' permission.", + style="red", + ) + else: + print_with_style( + f"❌ Diarization error: {e}", + style="red", + ) else: LOGGER.warning("No audio file available for diarization")