feat: add centralized logging to replace ad-hoc print statements (#1254)
* feat: add logging utility functions * feat: add logging setup and log level argument to CLI * feat: integrate logging across modules
This commit is contained in:
@@ -29,3 +29,29 @@ def load_audio(*args, **kwargs):
|
|||||||
def assign_word_speakers(*args, **kwargs):
|
def assign_word_speakers(*args, **kwargs):
|
||||||
diarize = _lazy_import("diarize")
|
diarize = _lazy_import("diarize")
|
||||||
return diarize.assign_word_speakers(*args, **kwargs)
|
return diarize.assign_word_speakers(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Configure logging for WhisperX.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (debug, info, warning, error, critical). Default: warning
|
||||||
|
log_file: Optional path to log file. If None, logs only to console.
|
||||||
|
"""
|
||||||
|
logging_module = _lazy_import("log_utils")
|
||||||
|
return logging_module.setup_logging(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Get a logger instance for the given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name (typically __name__ from calling module)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Logger instance configured with WhisperX settings
|
||||||
|
"""
|
||||||
|
logging_module = _lazy_import("log_utils")
|
||||||
|
return logging_module.get_logger(*args, **kwargs)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
|
|
||||||
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
|
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
|
||||||
optional_int, str2bool)
|
optional_int, str2bool)
|
||||||
|
from whisperx.log_utils import setup_logging
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
@@ -23,6 +24,7 @@ def cli():
|
|||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||||
|
parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||||
@@ -80,6 +82,16 @@ def cli():
|
|||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
|
|
||||||
|
log_level = args.get("log_level")
|
||||||
|
verbose = args.get("verbose")
|
||||||
|
|
||||||
|
if log_level is not None:
|
||||||
|
setup_logging(level=log_level)
|
||||||
|
elif verbose:
|
||||||
|
setup_logging(level="info")
|
||||||
|
else:
|
||||||
|
setup_logging(level="warning")
|
||||||
|
|
||||||
from whisperx.transcribe import transcribe_task
|
from whisperx.transcribe import transcribe_task
|
||||||
|
|
||||||
transcribe_task(args, parser)
|
transcribe_task(args, parser)
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ from whisperx.schema import (
|
|||||||
)
|
)
|
||||||
import nltk
|
import nltk
|
||||||
from nltk.data import load as nltk_load
|
from nltk.data import load as nltk_load
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||||
|
|
||||||
@@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
|
|||||||
elif language_code in DEFAULT_ALIGN_MODELS_HF:
|
elif language_code in DEFAULT_ALIGN_MODELS_HF:
|
||||||
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
|
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
|
||||||
else:
|
else:
|
||||||
print(f"There is no default alignment model set for this language ({language_code}).\
|
logger.error(f"No default alignment model for language: {language_code}. "
|
||||||
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
|
f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
|
||||||
|
f"then pass the model name via --align_model [MODEL_NAME]")
|
||||||
raise ValueError(f"No default align-model for language: {language_code}")
|
raise ValueError(f"No default align-model for language: {language_code}")
|
||||||
|
|
||||||
if model_name in torchaudio.pipelines.__all__:
|
if model_name in torchaudio.pipelines.__all__:
|
||||||
@@ -223,12 +227,12 @@ def align(
|
|||||||
|
|
||||||
# check we can align
|
# check we can align
|
||||||
if len(segment_data[sdx]["clean_char"]) == 0:
|
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
|
||||||
aligned_segments.append(aligned_seg)
|
aligned_segments.append(aligned_seg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if t1 >= MAX_DURATION:
|
if t1 >= MAX_DURATION:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
|
logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
|
||||||
aligned_segments.append(aligned_seg)
|
aligned_segments.append(aligned_seg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -270,7 +274,7 @@ def align(
|
|||||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||||
|
|
||||||
if path is None:
|
if path is None:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
|
||||||
aligned_segments.append(aligned_seg)
|
aligned_segments.append(aligned_seg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from transformers.pipelines.pt_utils import PipelineIterator
|
|||||||
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||||
from whisperx.schema import SingleSegment, TranscriptionResult
|
from whisperx.schema import SingleSegment, TranscriptionResult
|
||||||
from whisperx.vads import Vad, Silero, Pyannote
|
from whisperx.vads import Vad, Silero, Pyannote
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def find_numeral_symbol_tokens(tokenizer):
|
def find_numeral_symbol_tokens(tokenizer):
|
||||||
@@ -247,7 +250,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
if self.suppress_numerals:
|
if self.suppress_numerals:
|
||||||
previous_suppress_tokens = self.options.suppress_tokens
|
previous_suppress_tokens = self.options.suppress_tokens
|
||||||
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||||
print(f"Suppressing numeral and symbol tokens")
|
logger.info("Suppressing numeral and symbol tokens")
|
||||||
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
||||||
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
||||||
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
||||||
@@ -285,7 +288,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
def detect_language(self, audio: np.ndarray) -> str:
|
def detect_language(self, audio: np.ndarray) -> str:
|
||||||
if audio.shape[0] < N_SAMPLES:
|
if audio.shape[0] < N_SAMPLES:
|
||||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
|
||||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||||
@@ -294,7 +297,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
results = self.model.model.detect_language(encoder_output)
|
results = self.model.model.detect_language(encoder_output)
|
||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
|
||||||
return language
|
return language
|
||||||
|
|
||||||
|
|
||||||
@@ -344,7 +347,7 @@ def load_model(
|
|||||||
if language is not None:
|
if language is not None:
|
||||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||||
else:
|
else:
|
||||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
logger.info("No language specified, language will be detected for each audio file (increases inference time)")
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|
||||||
default_asr_options = {
|
default_asr_options = {
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ import torch
|
|||||||
|
|
||||||
from whisperx.audio import load_audio, SAMPLE_RATE
|
from whisperx.audio import load_audio, SAMPLE_RATE
|
||||||
from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
|
from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DiarizationPipeline:
|
class DiarizationPipeline:
|
||||||
@@ -18,6 +21,7 @@ class DiarizationPipeline:
|
|||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
||||||
|
logger.info(f"Loading diarization model: {model_config}")
|
||||||
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
67
whisperx/log_utils.py
Normal file
67
whisperx/log_utils.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
|
||||||
|
level: str = "info",
|
||||||
|
log_file: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Configure logging for WhisperX.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (debug, info, warning, error, critical). Default: info
|
||||||
|
log_file: Optional path to log file. If None, logs only to console.
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger("whisperx")
|
||||||
|
|
||||||
|
logger.handlers.clear()
|
||||||
|
|
||||||
|
try:
|
||||||
|
log_level = getattr(logging, level.upper())
|
||||||
|
except AttributeError:
|
||||||
|
log_level = logging.WARNING
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)
|
||||||
|
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(log_level)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
if log_file:
|
||||||
|
try:
|
||||||
|
file_handler = logging.FileHandler(log_file)
|
||||||
|
file_handler.setLevel(log_level)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
except (OSError) as e:
|
||||||
|
logger.warning(f"Failed to create log file '{log_file}': {e}")
|
||||||
|
logger.warning("Continuing with console logging only")
|
||||||
|
|
||||||
|
# Don't propagate to root logger to avoid duplicate messages
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Get a logger instance for the given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name (typically __name__ from calling module)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Logger instance configured with WhisperX settings
|
||||||
|
"""
|
||||||
|
whisperx_logger = logging.getLogger("whisperx")
|
||||||
|
if not whisperx_logger.handlers:
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
logger_name = "whisperx" if name == "__main__" else name
|
||||||
|
return logging.getLogger(logger_name)
|
||||||
@@ -12,6 +12,9 @@ from whisperx.audio import load_audio
|
|||||||
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||||
from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
|
from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
|
||||||
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||||
@@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
|||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
audio = load_audio(audio_path)
|
audio = load_audio(audio_path)
|
||||||
# >> VAD & ASR
|
# >> VAD & ASR
|
||||||
print(">>Performing transcription...")
|
logger.info("Performing transcription...")
|
||||||
result: TranscriptionResult = model.transcribe(
|
result: TranscriptionResult = model.transcribe(
|
||||||
audio,
|
audio,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
|||||||
if align_model is not None and len(result["segments"]) > 0:
|
if align_model is not None and len(result["segments"]) > 0:
|
||||||
if result.get("language", "en") != align_metadata["language"]:
|
if result.get("language", "en") != align_metadata["language"]:
|
||||||
# load new language
|
# load new language
|
||||||
print(
|
logger.info(
|
||||||
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
||||||
)
|
)
|
||||||
align_model, align_metadata = load_align_model(
|
align_model, align_metadata = load_align_model(
|
||||||
result["language"], device
|
result["language"], device
|
||||||
)
|
)
|
||||||
print(">>Performing alignment...")
|
logger.info("Performing alignment...")
|
||||||
result: AlignedTranscriptionResult = align(
|
result: AlignedTranscriptionResult = align(
|
||||||
result["segments"],
|
result["segments"],
|
||||||
align_model,
|
align_model,
|
||||||
@@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
|||||||
# >> Diarize
|
# >> Diarize
|
||||||
if diarize:
|
if diarize:
|
||||||
if hf_token is None:
|
if hf_token is None:
|
||||||
print(
|
logger.warning(
|
||||||
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
|
"No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
|
||||||
)
|
)
|
||||||
tmp_results = results
|
tmp_results = results
|
||||||
print(">>Performing diarization...")
|
logger.info("Performing diarization...")
|
||||||
print(">>Using model:", diarize_model_name)
|
logger.info(f"Using model: {diarize_model_name}")
|
||||||
results = []
|
results = []
|
||||||
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||||
for result, input_audio_path in tmp_results:
|
for result, input_audio_path in tmp_results:
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ from pyannote.core import Segment
|
|||||||
|
|
||||||
from whisperx.diarize import Segment as SegmentX
|
from whisperx.diarize import Segment as SegmentX
|
||||||
from whisperx.vads.vad import Vad
|
from whisperx.vads.vad import Vad
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||||
@@ -232,7 +235,7 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
|
|||||||
class Pyannote(Vad):
|
class Pyannote(Vad):
|
||||||
|
|
||||||
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
||||||
print(">>Performing voice activity detection using Pyannote...")
|
logger.info("Performing voice activity detection using Pyannote...")
|
||||||
super().__init__(kwargs['vad_onset'])
|
super().__init__(kwargs['vad_onset'])
|
||||||
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
|
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
|
||||||
|
|
||||||
@@ -257,7 +260,7 @@ class Pyannote(Vad):
|
|||||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||||
|
|
||||||
if len(segments_list) == 0:
|
if len(segments_list) == 0:
|
||||||
print("No active speech found in audio")
|
logger.warning("No active speech found in audio")
|
||||||
return []
|
return []
|
||||||
assert segments_list, "segments_list is empty."
|
assert segments_list, "segments_list is empty."
|
||||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import torch
|
|||||||
|
|
||||||
from whisperx.diarize import Segment as SegmentX
|
from whisperx.diarize import Segment as SegmentX
|
||||||
from whisperx.vads.vad import Vad
|
from whisperx.vads.vad import Vad
|
||||||
|
from whisperx.log_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
AudioFile = Union[Text, Path, IOBase, Mapping]
|
AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||||
|
|
||||||
@@ -15,7 +18,7 @@ AudioFile = Union[Text, Path, IOBase, Mapping]
|
|||||||
class Silero(Vad):
|
class Silero(Vad):
|
||||||
# check again default values
|
# check again default values
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
print(">>Performing voice activity detection using Silero...")
|
logger.info("Performing voice activity detection using Silero...")
|
||||||
super().__init__(kwargs['vad_onset'])
|
super().__init__(kwargs['vad_onset'])
|
||||||
|
|
||||||
self.vad_onset = kwargs['vad_onset']
|
self.vad_onset = kwargs['vad_onset']
|
||||||
@@ -60,7 +63,7 @@ class Silero(Vad):
|
|||||||
):
|
):
|
||||||
assert chunk_size > 0
|
assert chunk_size > 0
|
||||||
if len(segments_list) == 0:
|
if len(segments_list) == 0:
|
||||||
print("No active speech found in audio")
|
logger.warning("No active speech found in audio")
|
||||||
return []
|
return []
|
||||||
assert segments_list, "segments_list is empty."
|
assert segments_list, "segments_list is empty."
|
||||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
||||||
|
|||||||
Reference in New Issue
Block a user