Add jr, sr, and ph.d to punkt abbreviations

This commit is contained in:
Alex Cannan
2025-02-18 12:25:57 -05:00
committed by Barabazs
parent 83afb81ac7
commit c7d31883bc

View File

@@ -24,7 +24,7 @@ from whisperx.types import (
) )
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -124,14 +124,14 @@ def align(
""" """
Align phoneme recognition predictions to known transcription. Align phoneme recognition predictions to known transcription.
""" """
if not torch.is_tensor(audio): if not torch.is_tensor(audio):
if isinstance(audio, str): if isinstance(audio, str):
audio = load_audio(audio) audio = load_audio(audio)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)
if len(audio.shape) == 1: if len(audio.shape) == 1:
audio = audio.unsqueeze(0) audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"] model_dictionary = align_model_metadata["dictionary"]
@@ -148,7 +148,7 @@ def align(
base_progress = ((sdx + 1) / total_segments) * 100 base_progress = ((sdx + 1) / total_segments) * 100
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...") print(f"Progress: {percent_complete:.2f}%...")
num_leading = len(segment["text"]) - len(segment["text"].lstrip()) num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"] text = segment["text"]
@@ -165,7 +165,7 @@ def align(
# wav2vec2 models use "|" character to represent spaces # wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES: if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|") char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript # ignore whitespace at beginning and end of transcript
if cdx < num_leading: if cdx < num_leading:
pass pass
@@ -187,7 +187,7 @@ def align(
# index for placeholder # index for placeholder
clean_wdx.append(wdx) clean_wdx.append(wdx)
punkt_param = PunktParameters() punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param) sentence_splitter = PunktSentenceTokenizer(punkt_param)
@@ -199,12 +199,12 @@ def align(
"clean_wdx": clean_wdx, "clean_wdx": clean_wdx,
"sentence_spans": sentence_spans "sentence_spans": sentence_spans
} }
aligned_segments: List[SingleAlignedSegment] = [] aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align # 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
t1 = segment["start"] t1 = segment["start"]
t2 = segment["end"] t2 = segment["end"]
text = segment["text"] text = segment["text"]
@@ -247,7 +247,7 @@ def align(
) )
else: else:
lengths = None lengths = None
with torch.inference_mode(): with torch.inference_mode():
if model_type == "torchaudio": if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device), lengths=lengths) emissions, _ = model(waveform_segment.to(device), lengths=lengths)
@@ -304,7 +304,7 @@ def align(
word_idx += 1 word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ": elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1 word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr) char_segments_arr = pd.DataFrame(char_segments_arr)
aligned_subsegments = [] aligned_subsegments = []
@@ -333,7 +333,7 @@ def align(
word_end = word_chars["end"].max() word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3) word_score = round(word_chars["score"].mean(), 3)
# -1 indicates unalignable # -1 indicates unalignable
word_segment = {"word": word_text} word_segment = {"word": word_text}
if not np.isnan(word_start): if not np.isnan(word_start):
@@ -344,7 +344,7 @@ def align(
word_segment["score"] = word_score word_segment["score"] = word_score
sentence_words.append(word_segment) sentence_words.append(word_segment)
aligned_subsegments.append({ aligned_subsegments.append({
"text": sentence_text, "text": sentence_text,
"start": sentence_start, "start": sentence_start,