diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 8b10530..3632cd8 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -30,16 +30,16 @@ class DiarizationPipeline: ) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]: """ Perform speaker diarization on audio. - + Args: audio: Path to audio file or audio array num_speakers: Exact number of speakers (if known) min_speakers: Minimum number of speakers to detect max_speakers: Maximum number of speakers to detect return_embeddings: Whether to return speaker embeddings - + Returns: - If return_embeddings is True: + If return_embeddings is True: Tuple of (diarization dataframe, speaker embeddings dictionary) Otherwise: Just the diarization dataframe @@ -53,18 +53,18 @@ class DiarizationPipeline: if return_embeddings: diarization, embeddings = self.model( - audio_data, + audio_data, num_speakers=num_speakers, - min_speakers=min_speakers, - max_speakers=max_speakers, - return_embeddings=True + min_speakers=min_speakers, + max_speakers=max_speakers, + return_embeddings=True, ) else: diarization = self.model( - audio_data, + audio_data, num_speakers=num_speakers, - min_speakers=min_speakers, - max_speakers=max_speakers + min_speakers=min_speakers, + max_speakers=max_speakers, ) embeddings = None @@ -91,13 +91,13 @@ def assign_word_speakers( ) -> Union[AlignedTranscriptionResult, TranscriptionResult]: """ Assign speakers to words and segments in the transcript. - + Args: diarize_df: Diarization dataframe from DiarizationPipeline transcript_result: Transcription result to augment with speaker labels speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors fill_nearest: If True, assign speakers even when there's no direct time overlap - + Returns: Updated transcript_result with speaker assignments and optionally embeddings """ @@ -131,12 +131,12 @@ def assign_word_speakers( # sum over speakers speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] word["speaker"] = speaker - + # Add speaker embeddings to the result if provided if speaker_embeddings is not None: transcript_result["speaker_embeddings"] = speaker_embeddings - - return transcript_result + + return transcript_result class Segment: