From c7d31883bcce818ed78264a05cdc666ef7d022d2 Mon Sep 17 00:00:00 2001 From: Alex Cannan Date: Tue, 18 Feb 2025 12:25:57 -0500 Subject: [PATCH] Add jr, sr, and ph.d to punkt abbreviations --- whisperx/alignment.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 34fbbbb..3e19292 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -24,7 +24,7 @@ from whisperx.types import ( ) from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters -PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] +PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d'] LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -124,14 +124,14 @@ def align( """ Align phoneme recognition predictions to known transcription. """ - + if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) audio = torch.from_numpy(audio) if len(audio.shape) == 1: audio = audio.unsqueeze(0) - + MAX_DURATION = audio.shape[1] / SAMPLE_RATE model_dictionary = align_model_metadata["dictionary"] @@ -148,7 +148,7 @@ def align( base_progress = ((sdx + 1) / total_segments) * 100 percent_complete = (50 + base_progress / 2) if combined_progress else base_progress print(f"Progress: {percent_complete:.2f}%...") - + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) text = segment["text"] @@ -165,7 +165,7 @@ def align( # wav2vec2 models use "|" character to represent spaces if model_lang not in LANGUAGES_WITHOUT_SPACES: char_ = char_.replace(" ", "|") - + # ignore whitespace at beginning and end of transcript if cdx < num_leading: pass @@ -187,7 +187,7 @@ def align( # index for placeholder clean_wdx.append(wdx) - + punkt_param = PunktParameters() punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) sentence_splitter = PunktSentenceTokenizer(punkt_param) @@ -199,12 +199,12 @@ def align( "clean_wdx": clean_wdx, "sentence_spans": sentence_spans } - + aligned_segments: List[SingleAlignedSegment] = [] - + # 2. Get prediction matrix from alignment model & align for sdx, segment in enumerate(transcript): - + t1 = segment["start"] t2 = segment["end"] text = segment["text"] @@ -247,7 +247,7 @@ def align( ) else: lengths = None - + with torch.inference_mode(): if model_type == "torchaudio": emissions, _ = model(waveform_segment.to(device), lengths=lengths) @@ -304,7 +304,7 @@ def align( word_idx += 1 elif cdx == len(text) - 1 or text[cdx+1] == " ": word_idx += 1 - + char_segments_arr = pd.DataFrame(char_segments_arr) aligned_subsegments = [] @@ -333,7 +333,7 @@ def align( word_end = word_chars["end"].max() word_score = round(word_chars["score"].mean(), 3) - # -1 indicates unalignable + # -1 indicates unalignable word_segment = {"word": word_text} if not np.isnan(word_start): @@ -344,7 +344,7 @@ def align( word_segment["score"] = word_score sentence_words.append(word_segment) - + aligned_subsegments.append({ "text": sentence_text, "start": sentence_start,