Add jr, sr, and ph.d to punkt abbreviations
This commit is contained in:
@@ -24,7 +24,7 @@ from whisperx.types import (
|
|||||||
)
|
)
|
||||||
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
|
||||||
|
|
||||||
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
|
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d']
|
||||||
|
|
||||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||||
|
|
||||||
@@ -124,14 +124,14 @@ def align(
|
|||||||
"""
|
"""
|
||||||
Align phoneme recognition predictions to known transcription.
|
Align phoneme recognition predictions to known transcription.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not torch.is_tensor(audio):
|
if not torch.is_tensor(audio):
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str):
|
||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
audio = torch.from_numpy(audio)
|
audio = torch.from_numpy(audio)
|
||||||
if len(audio.shape) == 1:
|
if len(audio.shape) == 1:
|
||||||
audio = audio.unsqueeze(0)
|
audio = audio.unsqueeze(0)
|
||||||
|
|
||||||
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
|
||||||
|
|
||||||
model_dictionary = align_model_metadata["dictionary"]
|
model_dictionary = align_model_metadata["dictionary"]
|
||||||
@@ -148,7 +148,7 @@ def align(
|
|||||||
base_progress = ((sdx + 1) / total_segments) * 100
|
base_progress = ((sdx + 1) / total_segments) * 100
|
||||||
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
|
||||||
print(f"Progress: {percent_complete:.2f}%...")
|
print(f"Progress: {percent_complete:.2f}%...")
|
||||||
|
|
||||||
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
|
||||||
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
|
||||||
text = segment["text"]
|
text = segment["text"]
|
||||||
@@ -165,7 +165,7 @@ def align(
|
|||||||
# wav2vec2 models use "|" character to represent spaces
|
# wav2vec2 models use "|" character to represent spaces
|
||||||
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
if model_lang not in LANGUAGES_WITHOUT_SPACES:
|
||||||
char_ = char_.replace(" ", "|")
|
char_ = char_.replace(" ", "|")
|
||||||
|
|
||||||
# ignore whitespace at beginning and end of transcript
|
# ignore whitespace at beginning and end of transcript
|
||||||
if cdx < num_leading:
|
if cdx < num_leading:
|
||||||
pass
|
pass
|
||||||
@@ -187,7 +187,7 @@ def align(
|
|||||||
# index for placeholder
|
# index for placeholder
|
||||||
clean_wdx.append(wdx)
|
clean_wdx.append(wdx)
|
||||||
|
|
||||||
|
|
||||||
punkt_param = PunktParameters()
|
punkt_param = PunktParameters()
|
||||||
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
|
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
|
||||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||||
@@ -199,12 +199,12 @@ def align(
|
|||||||
"clean_wdx": clean_wdx,
|
"clean_wdx": clean_wdx,
|
||||||
"sentence_spans": sentence_spans
|
"sentence_spans": sentence_spans
|
||||||
}
|
}
|
||||||
|
|
||||||
aligned_segments: List[SingleAlignedSegment] = []
|
aligned_segments: List[SingleAlignedSegment] = []
|
||||||
|
|
||||||
# 2. Get prediction matrix from alignment model & align
|
# 2. Get prediction matrix from alignment model & align
|
||||||
for sdx, segment in enumerate(transcript):
|
for sdx, segment in enumerate(transcript):
|
||||||
|
|
||||||
t1 = segment["start"]
|
t1 = segment["start"]
|
||||||
t2 = segment["end"]
|
t2 = segment["end"]
|
||||||
text = segment["text"]
|
text = segment["text"]
|
||||||
@@ -247,7 +247,7 @@ def align(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
lengths = None
|
lengths = None
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
if model_type == "torchaudio":
|
if model_type == "torchaudio":
|
||||||
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
||||||
@@ -304,7 +304,7 @@ def align(
|
|||||||
word_idx += 1
|
word_idx += 1
|
||||||
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
elif cdx == len(text) - 1 or text[cdx+1] == " ":
|
||||||
word_idx += 1
|
word_idx += 1
|
||||||
|
|
||||||
char_segments_arr = pd.DataFrame(char_segments_arr)
|
char_segments_arr = pd.DataFrame(char_segments_arr)
|
||||||
|
|
||||||
aligned_subsegments = []
|
aligned_subsegments = []
|
||||||
@@ -333,7 +333,7 @@ def align(
|
|||||||
word_end = word_chars["end"].max()
|
word_end = word_chars["end"].max()
|
||||||
word_score = round(word_chars["score"].mean(), 3)
|
word_score = round(word_chars["score"].mean(), 3)
|
||||||
|
|
||||||
# -1 indicates unalignable
|
# -1 indicates unalignable
|
||||||
word_segment = {"word": word_text}
|
word_segment = {"word": word_text}
|
||||||
|
|
||||||
if not np.isnan(word_start):
|
if not np.isnan(word_start):
|
||||||
@@ -344,7 +344,7 @@ def align(
|
|||||||
word_segment["score"] = word_score
|
word_segment["score"] = word_score
|
||||||
|
|
||||||
sentence_words.append(word_segment)
|
sentence_words.append(word_segment)
|
||||||
|
|
||||||
aligned_subsegments.append({
|
aligned_subsegments.append({
|
||||||
"text": sentence_text,
|
"text": sentence_text,
|
||||||
"start": sentence_start,
|
"start": sentence_start,
|
||||||
|
|||||||
Reference in New Issue
Block a user