diff --git a/whisperx/__main__.py b/whisperx/__main__.py index e7f80be..d1fd16d 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -44,6 +44,7 @@ def cli(): parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use") + parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)") parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 26f33e4..12bb6ba 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd from pyannote.audio import Pipeline -from typing import Optional, Union +from typing import Optional, Union, Tuple, Dict, List, Any import torch from whisperx.audio import load_audio, SAMPLE_RATE @@ -26,25 +26,81 @@ class DiarizationPipeline: num_speakers: Optional[int] = None, min_speakers: Optional[int] = None, max_speakers: Optional[int] = None, - ): + return_embeddings: bool = False, + ) -> Union[Tuple[pd.DataFrame, Optional[Dict[str, List[float]]]], pd.DataFrame]: + """ + Perform speaker diarization on audio. + + Args: + audio: Path to audio file or audio array + num_speakers: Exact number of speakers (if known) + min_speakers: Minimum number of speakers to detect + max_speakers: Maximum number of speakers to detect + return_embeddings: Whether to return speaker embeddings + + Returns: + If return_embeddings is True: + Tuple of (diarization dataframe, speaker embeddings dictionary) + Otherwise: + Just the diarization dataframe + """ if isinstance(audio, str): audio = load_audio(audio) audio_data = { 'waveform': torch.from_numpy(audio[None, :]), 'sample_rate': SAMPLE_RATE } - segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) - diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) + + if return_embeddings: + diarization, embeddings = self.model( + audio_data, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + return_embeddings=True + ) + else: + diarization = self.model( + audio_data, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers + ) + embeddings = None + + diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) - return diarize_df + + if return_embeddings and embeddings is not None: + speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())} + return diarize_df, speaker_embeddings + + # For backwards compatibility + if return_embeddings: + return diarize_df, None + else: + return diarize_df def assign_word_speakers( diarize_df: pd.DataFrame, transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult], - fill_nearest=False, -) -> dict: + speaker_embeddings: Optional[Dict[str, List[float]]] = None, + fill_nearest: bool = False, +) -> Union[AlignedTranscriptionResult, TranscriptionResult]: + """ + Assign speakers to words and segments in the transcript. + + Args: + diarize_df: Diarization dataframe from DiarizationPipeline + transcript_result: Transcription result to augment with speaker labels + speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors + fill_nearest: If True, assign speakers even when there's no direct time overlap + + Returns: + Updated transcript_result with speaker assignments and optionally embeddings + """ transcript_segments = transcript_result["segments"] for seg in transcript_segments: # assign speaker to segment (if any) @@ -75,7 +131,11 @@ def assign_word_speakers( # sum over speakers speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] word["speaker"] = speaker - + + # Add speaker embeddings to the result if provided + if speaker_embeddings is not None: + transcript_result["speaker_embeddings"] = speaker_embeddings + return transcript_result diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 867b378..c1b599a 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -59,6 +59,10 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): max_speakers: int = args.pop("max_speakers") diarize_model_name: str = args.pop("diarize_model") print_progress: bool = args.pop("print_progress") + return_speaker_embeddings: bool = args.pop("speaker_embeddings") + + if return_speaker_embeddings and not diarize: + warnings.warn("--speaker_embeddings has no effect without --diarize") if args["language"] is not None: args["language"] = args["language"].lower() @@ -209,10 +213,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): results = [] diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: - diarize_segments = diarize_model( - input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers + diarize_segments, speaker_embeddings = diarize_model( + input_audio_path, + min_speakers=min_speakers, + max_speakers=max_speakers, + return_embeddings=return_speaker_embeddings ) - result = assign_word_speakers(diarize_segments, result) + result = assign_word_speakers(diarize_segments, result, speaker_embeddings) results.append((result, input_audio_path)) # >> Write for result, audio_path in results: