Add jr, sr, and ph.d to punkt abbreviations

This commit is contained in:
Alex Cannan
2025-02-18 12:25:57 -05:00
committed by Barabazs
parent 83afb81ac7
commit c7d31883bc

View File

@@ -24,7 +24,7 @@ from whisperx.types import (
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d']
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -124,14 +124,14 @@ def align(
"""
Align phoneme recognition predictions to known transcription.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
@@ -148,7 +148,7 @@ def align(
base_progress = ((sdx + 1) / total_segments) * 100
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
@@ -165,7 +165,7 @@ def align(
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
@@ -187,7 +187,7 @@ def align(
# index for placeholder
clean_wdx.append(wdx)
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param)
@@ -199,12 +199,12 @@ def align(
"clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
@@ -247,7 +247,7 @@ def align(
)
else:
lengths = None
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
@@ -304,7 +304,7 @@ def align(
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr)
aligned_subsegments = []
@@ -333,7 +333,7 @@ def align(
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
# -1 indicates unalignable
# -1 indicates unalignable
word_segment = {"word": word_text}
if not np.isnan(word_start):
@@ -344,7 +344,7 @@ def align(
word_segment["score"] = word_score
sentence_words.append(word_segment)
aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,