From a51ae7a81a67a33ad0d7de6e0f5654c92954e5d2 Mon Sep 17 00:00:00 2001 From: Barabazs <31799121+Barabazs@users.noreply.github.com> Date: Fri, 10 Oct 2025 08:41:06 +0200 Subject: [PATCH] feat: add centralized logging to replace ad-hoc print statements (#1254) * feat: add logging utility functions * feat: add logging setup and log level argument to CLI * feat: integrate logging across modules --- whisperx/__init__.py | 26 +++++++++++++++ whisperx/__main__.py | 12 +++++++ whisperx/alignment.py | 14 +++++--- whisperx/asr.py | 11 ++++--- whisperx/diarize.py | 4 +++ whisperx/log_utils.py | 67 +++++++++++++++++++++++++++++++++++++++ whisperx/transcribe.py | 17 ++++++---- whisperx/vads/pyannote.py | 7 ++-- whisperx/vads/silero.py | 7 ++-- 9 files changed, 145 insertions(+), 20 deletions(-) create mode 100644 whisperx/log_utils.py diff --git a/whisperx/__init__.py b/whisperx/__init__.py index ace17eb..b8f93fe 100644 --- a/whisperx/__init__.py +++ b/whisperx/__init__.py @@ -29,3 +29,29 @@ def load_audio(*args, **kwargs): def assign_word_speakers(*args, **kwargs): diarize = _lazy_import("diarize") return diarize.assign_word_speakers(*args, **kwargs) + + +def setup_logging(*args, **kwargs): + """ + Configure logging for WhisperX. + + Args: + level: Logging level (debug, info, warning, error, critical). Default: info + log_file: Optional path to log file. If None, logs only to console. + """ + logging_module = _lazy_import("log_utils") + return logging_module.setup_logging(*args, **kwargs) + + +def get_logger(*args, **kwargs): + """ + Get a logger instance for the given module. 
+ + Args: + name: Logger name (typically __name__ from calling module) + + Returns: + Logger instance configured with WhisperX settings + """ + logging_module = _lazy_import("log_utils") + return logging_module.get_logger(*args, **kwargs) diff --git a/whisperx/__main__.py b/whisperx/__main__.py index d1fd16d..5102bc0 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -6,6 +6,7 @@ import torch from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float, optional_int, str2bool) +from whisperx.log_utils import setup_logging def cli(): @@ -23,6 +24,7 @@ def cli(): parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") + parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") @@ -80,6 +82,16 @@ def cli(): args = parser.parse_args().__dict__ + log_level = args.get("log_level") + verbose = args.get("verbose") + + if log_level is not None: + setup_logging(level=log_level) + elif verbose: + setup_logging(level="info") + else: + setup_logging(level="warning") + from whisperx.transcribe import transcribe_task transcribe_task(args, 
parser) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 7637683..9034826 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -24,6 +24,9 @@ from whisperx.schema import ( ) import nltk from nltk.data import load as nltk_load +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str] elif language_code in DEFAULT_ALIGN_MODELS_HF: model_name = DEFAULT_ALIGN_MODELS_HF[language_code] else: - print(f"There is no default alignment model set for this language ({language_code}).\ - Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]") + logger.error(f"No default alignment model for language: {language_code}. " + f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, " + f"then pass the model name via --align_model [MODEL_NAME]") raise ValueError(f"No default align-model for language: {language_code}") if model_name in torchaudio.pipelines.__all__: @@ -223,12 +227,12 @@ def align( # check we can align if len(segment_data[sdx]["clean_char"]) == 0: - print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') + logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original') aligned_segments.append(aligned_seg) continue if t1 >= MAX_DURATION: - print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...') + logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping') aligned_segments.append(aligned_seg) continue @@ -270,7 +274,7 @@ def align( path = backtrack_beam(trellis, emission, tokens, 
blank_id, beam_width=2) if path is None: - print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') + logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original') aligned_segments.append(aligned_seg) continue diff --git a/whisperx/asr.py b/whisperx/asr.py index 1ad3408..c35900c 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -14,6 +14,9 @@ from transformers.pipelines.pt_utils import PipelineIterator from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from whisperx.schema import SingleSegment, TranscriptionResult from whisperx.vads import Vad, Silero, Pyannote +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) def find_numeral_symbol_tokens(tokenizer): @@ -247,7 +250,7 @@ class FasterWhisperPipeline(Pipeline): if self.suppress_numerals: previous_suppress_tokens = self.options.suppress_tokens numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer) - print(f"Suppressing numeral and symbol tokens") + logger.info("Suppressing numeral and symbol tokens") new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens new_suppressed_tokens = list(set(new_suppressed_tokens)) self.options = replace(self.options, suppress_tokens=new_suppressed_tokens) @@ -285,7 +288,7 @@ class FasterWhisperPipeline(Pipeline): def detect_language(self, audio: np.ndarray) -> str: if audio.shape[0] < N_SAMPLES: - print("Warning: audio is shorter than 30s, language detection may be inaccurate.") + logger.warning("Audio is shorter than 30s, language detection may be inaccurate") model_n_mels = self.model.feat_kwargs.get("feature_size") segment = log_mel_spectrogram(audio[: N_SAMPLES], n_mels=model_n_mels if model_n_mels is not None else 80, @@ -294,7 +297,7 @@ class FasterWhisperPipeline(Pipeline): results = self.model.model.detect_language(encoder_output) language_token, language_probability = results[0][0] language = 
language_token[2:-2] - print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") + logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio") return language @@ -344,7 +347,7 @@ def load_model( if language is not None: tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: - print("No language specified, language will be first be detected for each audio file (increases inference time).") + logger.info("No language specified, language will be detected for each audio file (increases inference time)") tokenizer = None default_asr_options = { diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 6bab799..9f46b02 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -6,6 +6,9 @@ import torch from whisperx.audio import load_audio, SAMPLE_RATE from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) class DiarizationPipeline: @@ -18,6 +21,7 @@ class DiarizationPipeline: if isinstance(device, str): device = torch.device(device) model_config = model_name or "pyannote/speaker-diarization-3.1" + logger.info(f"Loading diarization model: {model_config}") self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device) def __call__( diff --git a/whisperx/log_utils.py b/whisperx/log_utils.py new file mode 100644 index 0000000..3015f53 --- /dev/null +++ b/whisperx/log_utils.py @@ -0,0 +1,67 @@ +import logging +import sys +from typing import Optional + +_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def setup_logging( + level: str = "info", + log_file: Optional[str] = None, +) -> None: + """ + Configure logging for WhisperX. + + Args: + level: Logging level (debug, info, warning, error, critical). Default: info + log_file: Optional path to log file. 
If None, logs only to console. + """ + logger = logging.getLogger("whisperx") + + logger.handlers.clear() + + try: + log_level = getattr(logging, level.upper()) + except AttributeError: + log_level = logging.WARNING + logger.setLevel(log_level) + + formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(log_level) + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + + if log_file: + try: + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + except (OSError) as e: + logger.warning(f"Failed to create log file '{log_file}': {e}") + logger.warning("Continuing with console logging only") + + # Don't propagate to root logger to avoid duplicate messages + logger.propagate = False + + +def get_logger(name: str) -> logging.Logger: + """ + Get a logger instance for the given module. 
+ + Args: + name: Logger name (typically __name__ from calling module) + + Returns: + Logger instance configured with WhisperX settings + """ + whisperx_logger = logging.getLogger("whisperx") + if not whisperx_logger.handlers: + setup_logging() + + logger_name = "whisperx" if name == "__main__" else name + return logging.getLogger(logger_name) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 0b94c13..11110c6 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -12,6 +12,9 @@ from whisperx.audio import load_audio from whisperx.diarize import DiarizationPipeline, assign_word_speakers from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) def transcribe_task(args: dict, parser: argparse.ArgumentParser): @@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): for audio_path in args.pop("audio"): audio = load_audio(audio_path) # >> VAD & ASR - print(">>Performing transcription...") + logger.info("Performing transcription...") result: TranscriptionResult = model.transcribe( audio, batch_size=batch_size, @@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): if align_model is not None and len(result["segments"]) > 0: if result.get("language", "en") != align_metadata["language"]: # load new language - print( + logger.info( f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..." 
) align_model, align_metadata = load_align_model( result["language"], device ) - print(">>Performing alignment...") + logger.info("Performing alignment...") result: AlignedTranscriptionResult = align( result["segments"], align_model, @@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): # >> Diarize if diarize: if hf_token is None: - print( - "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..." + logger.warning( + "No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model" ) tmp_results = results - print(">>Performing diarization...") - print(">>Using model:", diarize_model_name) + logger.info("Performing diarization...") + logger.info(f"Using model: {diarize_model_name}") results = [] diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: diff --git a/whisperx/vads/pyannote.py b/whisperx/vads/pyannote.py index 62225f0..0e806bb 100644 --- a/whisperx/vads/pyannote.py +++ b/whisperx/vads/pyannote.py @@ -13,6 +13,9 @@ from pyannote.core import Segment from whisperx.diarize import Segment as SegmentX from whisperx.vads.vad import Vad +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): @@ -232,7 +235,7 @@ class VoiceActivitySegmentation(VoiceActivityDetection): class Pyannote(Vad): def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs): - print(">>Performing voice activity detection using Pyannote...") + logger.info("Performing voice activity detection using Pyannote...") super().__init__(kwargs['vad_onset']) self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp) @@ -257,7 +260,7 @@ class Pyannote(Vad): 
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")) if len(segments_list) == 0: - print("No active speech found in audio") + logger.warning("No active speech found in audio") return [] assert segments_list, "segments_list is empty." return Vad.merge_chunks(segments_list, chunk_size, onset, offset) diff --git a/whisperx/vads/silero.py b/whisperx/vads/silero.py index 88c54b8..24a64d4 100644 --- a/whisperx/vads/silero.py +++ b/whisperx/vads/silero.py @@ -8,6 +8,9 @@ import torch from whisperx.diarize import Segment as SegmentX from whisperx.vads.vad import Vad +from whisperx.log_utils import get_logger + +logger = get_logger(__name__) AudioFile = Union[Text, Path, IOBase, Mapping] @@ -15,7 +18,7 @@ AudioFile = Union[Text, Path, IOBase, Mapping] class Silero(Vad): # check again default values def __init__(self, **kwargs): - print(">>Performing voice activity detection using Silero...") + logger.info("Performing voice activity detection using Silero...") super().__init__(kwargs['vad_onset']) self.vad_onset = kwargs['vad_onset'] @@ -60,7 +63,7 @@ class Silero(Vad): ): assert chunk_size > 0 if len(segments_list) == 0: - print("No active speech found in audio") + logger.warning("No active speech found in audio") return [] assert segments_list, "segments_list is empty." return Vad.merge_chunks(segments_list, chunk_size, onset, offset)