- Добавлены параметры модели в гет эндпоинт

This commit is contained in:
red
2025-08-20 23:25:05 +09:00
parent 228f67d07f
commit ce41cf4a09

74
app.py
View File

@@ -2,7 +2,7 @@ import logging
import os import os
import subprocess import subprocess
import tempfile import tempfile
from typing import Optional from typing import Optional, Union, List, Tuple
from enum import Enum from enum import Enum
import whisper import whisper
@@ -102,9 +102,23 @@ async def transcribe_audio(
token: str = Depends(api_key_header), token: str = Depends(api_key_header),
model_name: Optional[str] = Query(None, description="Model name to use for transcription"), model_name: Optional[str] = Query(None, description="Model name to use for transcription"),
output_format: OutputFormat = Query(OutputFormat.json, description="Output format: plaintext, simple, or json"), output_format: OutputFormat = Query(OutputFormat.json, description="Output format: plaintext, simple, or json"),
speedup: float = Query(1.0, ge=0.25, le=4.0, description="Speed up factor for audio (0.25-4.0)") speedup: float = Query(1.0, ge=0.25, le=4.0, description="Speed up factor for audio (0.25-4.0)"),
# Whisper model parameters
verbose: Optional[bool] = Query(None, description="Whether to print out the progress and debug messages"),
temperature: Union[float, str] = Query("0.0,0.2,0.4,0.6,0.8,1.0", description="Temperature for sampling (single float or comma-separated values)"),
compression_ratio_threshold: Optional[float] = Query(2.4, description="If the gzip compression ratio is above this value, treat as failed"),
logprob_threshold: Optional[float] = Query(-1.0, description="If the average log probability over sampled tokens is below this value, treat as failed"),
no_speech_threshold: Optional[float] = Query(0.6, description="If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below logprob_threshold, consider the segment as silent"),
condition_on_previous_text: bool = Query(True, description="If True, the previous output of the model is provided as a prompt for the next window"),
initial_prompt: Optional[str] = Query(None, description="Optional text to provide as a prompt for the first window"),
carry_initial_prompt: bool = Query(False, description="If True, the initial prompt is carried over to the next window"),
word_timestamps: bool = Query(False, description="Extract word-level timestamps using the cross-attention pattern and dynamic time warping"),
prepend_punctuations: str = Query("\"'([{-", description="If word_timestamps is True, merge these punctuation marks with the next word"),
append_punctuations: str = Query("\"'.,:;!?)]}", description="If word_timestamps is True, merge these punctuation marks with the previous word"),
clip_timestamps: Union[str, List[float]] = Query("0", description="Comma-separated list of clip timestamps to use for transcription"),
hallucination_silence_threshold: Optional[float] = Query(None, description="When word_timestamps is True, skip silent periods longer than this threshold (in seconds)"),
): ):
"""Transcribe audio file with configurable output format""" """Transcribe audio file with configurable output format and comprehensive Whisper parameters"""
# Token validation # Token validation
if token not in get_keys(): if token not in get_keys():
@@ -142,9 +156,59 @@ async def transcribe_audio(
else: else:
audio_file_path = temp_input_path audio_file_path = temp_input_path
# Prepare transcription parameters
transcribe_params = {}
# Handle temperature parameter (can be single value or tuple)
if isinstance(temperature, str) and "," in temperature:
try:
temp_values = [float(x.strip()) for x in temperature.split(",")]
transcribe_params["temperature"] = tuple(temp_values)
except ValueError:
transcribe_params["temperature"] = 0.0
else:
try:
transcribe_params["temperature"] = float(temperature)
except (ValueError, TypeError):
transcribe_params["temperature"] = 0.0
# Handle clip_timestamps parameter
if isinstance(clip_timestamps, str) and clip_timestamps != "0":
try:
if "," in clip_timestamps:
transcribe_params["clip_timestamps"] = [float(x.strip()) for x in clip_timestamps.split(",")]
else:
transcribe_params["clip_timestamps"] = clip_timestamps
except ValueError:
transcribe_params["clip_timestamps"] = "0"
else:
transcribe_params["clip_timestamps"] = clip_timestamps
# Add other parameters if they are not None
if verbose is not None:
transcribe_params["verbose"] = verbose
if compression_ratio_threshold is not None:
transcribe_params["compression_ratio_threshold"] = compression_ratio_threshold
if logprob_threshold is not None:
transcribe_params["logprob_threshold"] = logprob_threshold
if no_speech_threshold is not None:
transcribe_params["no_speech_threshold"] = no_speech_threshold
transcribe_params["condition_on_previous_text"] = condition_on_previous_text
transcribe_params["carry_initial_prompt"] = carry_initial_prompt
transcribe_params["word_timestamps"] = word_timestamps
transcribe_params["prepend_punctuations"] = prepend_punctuations
transcribe_params["append_punctuations"] = append_punctuations
if initial_prompt is not None:
transcribe_params["initial_prompt"] = initial_prompt
if hallucination_silence_threshold is not None:
transcribe_params["hallucination_silence_threshold"] = hallucination_silence_threshold
# Transcribe # Transcribe
logger.info("Starting transcription") logger.info("Starting transcription")
result = model.transcribe(audio_file_path) logger.debug(f"Transcription parameters: {transcribe_params}")
result = model.transcribe(audio_file_path, **transcribe_params)
# Format output based on requested format # Format output based on requested format
if output_format == OutputFormat.plaintext: if output_format == OutputFormat.plaintext:
@@ -170,7 +234,7 @@ async def transcribe_audio(
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
return {"status": "healthy", "model_loaded": default_model is not None} return {"status": "healthy", "model_loaded": default_model is not None, "model_name": default_model.__str__()}
def main(): def main():
import uvicorn import uvicorn