Add jr, sr, and ph.d to punkt abbreviations
This commit is contained in:
@@ -24,7 +24,7 @@ from whisperx.types import (
|
||||
)
|
||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||
|
||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
||||
# Known abbreviations for the Punkt sentence splitter. align() converts this
# list to a set and assigns it to PunktParameters.abbrev_types. Entries are
# written lowercase and without the trailing period (e.g. 'ph.d', not 'Ph.D.').
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d']
|
||||
|
||||
# Language codes for which align() leaves spaces untouched. For any other
# model language, each " " is replaced with "|", the character wav2vec2
# models use to represent spaces.
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
@@ -124,14 +124,14 @@ def align(
|
||||
"""
|
||||
Align phoneme recognition predictions to known transcription.
|
||||
"""
|
||||
|
||||
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
|
||||
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
||||
|
||||
model_dictionary = align_model_metadata["dictionary"]
|
||||
@@ -148,7 +148,7 @@ def align(
|
||||
base_progress = ((sdx + 1) / total_segments) * 100
|
||||
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
||||
print(f"Progress: {percent_complete:.2f}%...")
|
||||
|
||||
|
||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||
text = segment["text"]
|
||||
@@ -165,7 +165,7 @@ def align(
|
||||
# wav2vec2 models use "|" character to represent spaces
|
||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||
char_ = char_.replace(" ", "|")
|
||||
|
||||
|
||||
# ignore whitespace at beginning and end of transcript
|
||||
if cdx < num_leading:
|
||||
pass
|
||||
@@ -187,7 +187,7 @@ def align(
|
||||
# index for placeholder
|
||||
clean_wdx.append(wdx)
|
||||
|
||||
|
||||
|
||||
punkt_param = PunktParameters()
|
||||
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
|
||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||
@@ -199,12 +199,12 @@ def align(
|
||||
"clean_wdx": clean_wdx,
|
||||
"sentence_spans": sentence_spans
|
||||
}
|
||||
|
||||
|
||||
aligned_segments: List[SingleAlignedSegment] = []
|
||||
|
||||
|
||||
# 2. Get prediction matrix from alignment model & align
|
||||
for sdx, segment in enumerate(transcript):
|
||||
|
||||
|
||||
t1 = segment["start"]
|
||||
t2 = segment["end"]
|
||||
text = segment["text"]
|
||||
@@ -247,7 +247,7 @@ def align(
|
||||
)
|
||||
else:
|
||||
lengths = None
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
||||
@@ -304,7 +304,7 @@ def align(
|
||||
word_idx += 1
|
||||
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
||||
word_idx += 1
|
||||
|
||||
|
||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||
|
||||
aligned_subsegments = []
|
||||
@@ -333,7 +333,7 @@ def align(
|
||||
word_end = word_chars["end"].max()
|
||||
word_score = round(word_chars["score"].mean(), 3)
|
||||
|
||||
# -1 indicates unalignable
|
||||
# -1 indicates unalignable
|
||||
word_segment = {"word": word_text}
|
||||
|
||||
if not np.isnan(word_start):
|
||||
@@ -344,7 +344,7 @@ def align(
|
||||
word_segment["score"] = word_score
|
||||
|
||||
sentence_words.append(word_segment)
|
||||
|
||||
|
||||
aligned_subsegments.append({
|
||||
"text": sentence_text,
|
||||
"start": sentence_start,
|
||||
|
||||
Reference in New Issue
Block a user