diff --git a/app.py b/app.py index 4893492..54a955c 100644 --- a/app.py +++ b/app.py @@ -2,7 +2,7 @@ import logging import os import subprocess import tempfile -from typing import Optional +from typing import Optional, Union, List, Tuple from enum import Enum import whisper @@ -102,9 +102,23 @@ async def transcribe_audio( token: str = Depends(api_key_header), model_name: Optional[str] = Query(None, description="Model name to use for transcription"), output_format: OutputFormat = Query(OutputFormat.json, description="Output format: plaintext, simple, or json"), - speedup: float = Query(1.0, ge=0.25, le=4.0, description="Speed up factor for audio (0.25-4.0)") + speedup: float = Query(1.0, ge=0.25, le=4.0, description="Speed up factor for audio (0.25-4.0)"), + # Whisper model parameters + verbose: Optional[bool] = Query(None, description="Whether to print out the progress and debug messages"), + temperature: Union[float, str] = Query("0.0,0.2,0.4,0.6,0.8,1.0", description="Temperature for sampling (single float or comma-separated values)"), + compression_ratio_threshold: Optional[float] = Query(2.4, description="If the gzip compression ratio is above this value, treat as failed"), + logprob_threshold: Optional[float] = Query(-1.0, description="If the average log probability over sampled tokens is below this value, treat as failed"), + no_speech_threshold: Optional[float] = Query(0.6, description="If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below logprob_threshold, consider the segment as silent"), + condition_on_previous_text: bool = Query(True, description="If True, the previous output of the model is provided as a prompt for the next window"), + initial_prompt: Optional[str] = Query(None, description="Optional text to provide as a prompt for the first window"), + carry_initial_prompt: bool = Query(False, description="If True, the initial prompt is carried over to the next window"), + word_timestamps: bool = Query(False, description="Extract word-level timestamps using the cross-attention pattern and dynamic time warping"), + prepend_punctuations: str = Query("\"'([{-", description="If word_timestamps is True, merge these punctuation marks with the next word"), + append_punctuations: str = Query("\"'.,:;!?)]}", description="If word_timestamps is True, merge these punctuation marks with the previous word"), + clip_timestamps: Union[str, List[float]] = Query("0", description="Comma-separated list of clip timestamps to use for transcription"), + hallucination_silence_threshold: Optional[float] = Query(None, description="When word_timestamps is True, skip silent periods longer than this threshold (in seconds)"), ): - """Transcribe audio file with configurable output format""" + """Transcribe audio file with configurable output format and comprehensive Whisper parameters""" # Token validation if token not in get_keys(): @@ -142,9 +156,59 @@ async def transcribe_audio( else: audio_file_path = temp_input_path + # Prepare transcription parameters + transcribe_params = {} + + # Handle temperature parameter (can be single value or tuple) + if isinstance(temperature, str) and "," in temperature: + try: + temp_values = [float(x.strip()) for x in temperature.split(",")] + transcribe_params["temperature"] = tuple(temp_values) + except ValueError: + transcribe_params["temperature"] = 0.0 + else: + try: + transcribe_params["temperature"] = float(temperature) + except (ValueError, TypeError): + transcribe_params["temperature"] = 0.0 + + # Handle clip_timestamps parameter + if isinstance(clip_timestamps, str) and clip_timestamps != "0": + try: + if "," in clip_timestamps: + transcribe_params["clip_timestamps"] = [float(x.strip()) for x in clip_timestamps.split(",")] + else: + transcribe_params["clip_timestamps"] = clip_timestamps + except ValueError: + transcribe_params["clip_timestamps"] = "0" + else: + transcribe_params["clip_timestamps"] = clip_timestamps + + # Add other parameters if they are not None + if verbose is not None: + transcribe_params["verbose"] = verbose + if compression_ratio_threshold is not None: + transcribe_params["compression_ratio_threshold"] = compression_ratio_threshold + if logprob_threshold is not None: + transcribe_params["logprob_threshold"] = logprob_threshold + if no_speech_threshold is not None: + transcribe_params["no_speech_threshold"] = no_speech_threshold + + transcribe_params["condition_on_previous_text"] = condition_on_previous_text + transcribe_params["carry_initial_prompt"] = carry_initial_prompt + transcribe_params["word_timestamps"] = word_timestamps + transcribe_params["prepend_punctuations"] = prepend_punctuations + transcribe_params["append_punctuations"] = append_punctuations + + if initial_prompt is not None: + transcribe_params["initial_prompt"] = initial_prompt + if hallucination_silence_threshold is not None: + transcribe_params["hallucination_silence_threshold"] = hallucination_silence_threshold + # Transcribe logger.info("Starting transcription") - result = model.transcribe(audio_file_path) + logger.debug(f"Transcription parameters: {transcribe_params}") + result = model.transcribe(audio_file_path, **transcribe_params) # Format output based on requested format if output_format == OutputFormat.plaintext: @@ -170,7 +234,7 @@ async def transcribe_audio( @app.get("/health") async def health_check(): """Health check endpoint""" - return {"status": "healthy", "model_loaded": default_model is not None} + return {"status": "healthy", "model_loaded": default_model is not None, "model_name": default_model.__str__()} def main(): import uvicorn