- рефакторинг приложения

- снова изменения в Readme
- работа над валидацией параметров
- большая гибкость и конфигурируемость
This commit is contained in:
2025-09-06 21:05:56 +03:00
parent 3f97810f89
commit fcae47cad1
8 changed files with 475 additions and 334 deletions

389
app.py
View File

@@ -1,279 +1,176 @@
import logging
import os
import subprocess
import time
from os import getenv
from typing import Dict
import json
import whisper
from pathlib import Path
from typing import Dict, Optional, Set, Literal, List, Union
from threading import Lock
import gigaam
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File, Request, Form
from fastapi.security import APIKeyHeader
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Logging setup
# NOTE(review): a second basicConfig() call normally has no effect once the
# root logger already has handlers, so only one of these two calls can be the
# intended one — confirm which should remain.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# NOTE(review): `app` is assigned again further down with a title and version;
# confirm which instantiation should remain.
app = FastAPI()
# Pydantic model describing the transcription request parameters
class TranscribeParams(BaseModel):
    """Optional tuning parameters for a transcription request.

    Every field is optional. The transcription endpoint in this file forwards
    each field that is not ``None`` (except ``format``) as a keyword argument
    to ``model.transcribe``; ``format`` only selects the response shape.
    """

    language: Optional[str] = Field(None, description="Язык аудио (auto-detect по умолчанию)")
    # Limited to the two values the description names, so a typo is rejected
    # during validation instead of being passed on to the model.
    task: Optional[Literal["transcribe", "translate"]] = Field("transcribe", description="transcribe или translate")
    # The documented 0.0-1.0 range is enforced rather than only described.
    temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="Температура для генерации (0.0-1.0)")
    # Candidate counts must be positive to be meaningful.
    beam_size: Optional[int] = Field(None, ge=1, description="Размер луча для поиска")
    best_of: Optional[int] = Field(None, ge=1, description="Количество кандидатов для выбора лучшего")
    compression_ratio_threshold: Optional[float] = Field(None, description="Порог сжатия для фильтрации")
    logprob_threshold: Optional[float] = Field(None, description="Порог логарифмической вероятности")
    no_speech_threshold: Optional[float] = Field(None, description="Порог детекции отсутствия речи")
    condition_on_previous_text: Optional[bool] = Field(True, description="Использовать предыдущий текст как контекст")
    initial_prompt: Optional[str] = Field(None, description="Начальная подсказка для модели")
    word_timestamps: Optional[bool] = Field(False, description="Временные метки слов")
    prepend_punctuations: Optional[str] = Field(None, description="Знаки препинания для добавления в начало")
    append_punctuations: Optional[str] = Field(None, description="Знаки препинания для добавления в конец")
    clip_timestamps: Optional[List[float]] = Field(None, description="Временные метки для обрезки аудио")
    hallucination_silence_threshold: Optional[float] = Field(None, description="Порог тишины для детекции галлюцинаций")
    # Response shape selector; not forwarded to the model.
    format: Optional[Literal["json", "simple", "text"]] = Field("json", description="Формат ответа")
# NOTE(review): `model` is bound here to a gigaam model and immediately
# rebound to None two lines below — confirm which initialisation should remain.
model = gigaam.load_model("v2_ctc", device=getenv("ASR_DEVICE"), download_root=getenv("ASR_MODELS_ROOT"))
# Global state for the model and the API keys
model = None
api_keys: Set[str] = set()
# Guards reads and writes of `api_keys`.
keys_lock = Lock()
# Path of the keys file; overridable through the KEYS_FILE environment variable.
keys_file_path = os.getenv("KEYS_FILE", "keys.txt")
# API key header
api_key_header = APIKeyHeader(name="x-api-key")
# Security scheme
# NOTE(review): `api_key_header` and `app` are both assigned a second time
# here; only the later assignment takes effect at import time.
api_key_header = APIKeyHeader(name="X-API-Key")
app = FastAPI(title="Whisper ASR Service", version="1.0.0")
def get_keys():  # author's note: "please don't judge me for this one"
    """Return the API keys stored in ``keys.txt``, creating the file if absent.

    If the file does not exist, a random 32-byte key (hex encoded) is
    generated, written to a new ``keys.txt`` and returned as a one-element
    list. Otherwise every non-blank line of the file is returned, stripped of
    surrounding whitespace.

    Raises:
        ValueError: If the file exists but contains no non-blank lines.
    """
    keys_file = "keys.txt"

    if os.path.exists(keys_file):
        # Existing file: collect every non-blank line as a key.
        keys = []
        with open(keys_file, "r") as fh:
            for raw_line in fh:
                stripped = raw_line.strip()
                if stripped:
                    keys.append(stripped)
        logger.info(f"Loaded {len(keys)} keys from file")
        logger.debug(f"Keys: {keys}")
        if not keys:
            raise ValueError("No keys found in keys.txt")
        return keys

    # No file yet: bootstrap it with a single freshly generated key.
    default_key = os.urandom(32).hex()
    with open(keys_file, "w") as fh:
        fh.write(default_key + "\n")
    logger.info(f"Created new keys file with default key: {default_key}")
    return [default_key]
# NOTE(review): this span interleaves two definitions — the header of
# load_api_keys() sits between the docstring and the body of convert_audio().
def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
"""
Convert audio to compatible format and speed up

Runs ffmpeg on input_path and writes the result to output_path.
Returns True on success and False when ffmpeg exits with an error.
"""
def load_api_keys():
"""Load the API keys from the file."""
global api_keys
try:
# ffmpeg arguments: tempo filter with the requested speed, 16 kHz sample
# rate, one channel, 16-bit PCM codec; '-y' overwrites an existing output.
command = [
'ffmpeg', '-i', input_path,
'-filter:a', f'atempo={speed}',
'-ar', '16000',
'-ac', '1',
'-c:a', 'pcm_s16le',
output_path,
'-y'
]
logger.debug(f"Running FFmpeg command: {' '.join(command)}")
# check=True raises CalledProcessError on a non-zero exit status;
# capture_output=True keeps stderr available for the error log below.
subprocess.run(command, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
return False
class TranscriptionMetrics:
    """Wall-clock timing and throughput figures for one transcription run.

    The clock starts when the instance is created and stops when
    :meth:`stop` is called.
    """

    def __init__(self):
        self.start_time = time.time()
        self.end_time = None
        self.text_length = 0
        self.audio_duration = 0

    def stop(self, text: str, audio_duration: float):
        """Stop the clock and record the size of the produced result.

        Args:
            text: The transcribed text; only its length is stored.
            audio_duration: Duration of the source audio, in seconds.
        """
        self.end_time = time.time()
        self.text_length = len(text)
        self.audio_duration = audio_duration

    def get_metrics(self) -> Dict[str, float]:
        """Return the collected figures, rounded to two decimals.

        If :meth:`stop` has not been called yet, the time elapsed so far is
        used. When no measurable time has elapsed, the two rate figures are
        reported as ``0.0`` instead of raising ``ZeroDivisionError``.

        Returns:
            A dict with ``processing_time_seconds``, ``characters_per_second``,
            ``audio_realtime_ratio``, ``audio_duration`` and ``text_length``.
        """
        end_time = self.end_time if self.end_time is not None else time.time()
        processing_time = end_time - self.start_time

        def per_second(amount: float) -> float:
            # A coarse clock can report zero (or negative) elapsed time.
            if processing_time <= 0:
                return 0.0
            return round(amount / processing_time, 2)

        return {
            "processing_time_seconds": round(processing_time, 2),
            "characters_per_second": per_second(self.text_length),
            "audio_realtime_ratio": per_second(self.audio_duration),
            "audio_duration": round(self.audio_duration, 2),
            "text_length": self.text_length,
        }
def get_audio_duration(file_path: str) -> float:
    """Return the duration of *file_path* in seconds, as reported by ffprobe.

    This is best-effort: ``0.0`` is returned when ffprobe cannot be started,
    exits with an error, or prints something that is not a number.

    Args:
        file_path: Path of the media file to inspect.

    Returns:
        The duration in seconds, or ``0.0`` if it could not be determined.
    """
    cmd = [
        'ffprobe',
        '-v', 'quiet',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        file_path
    ]
    try:
        output = subprocess.check_output(cmd).decode().strip()
        return float(output)
    except (subprocess.CalledProcessError, OSError, ValueError):
        # Only the expected failures are absorbed: non-zero exit status,
        # ffprobe missing or not executable, and undecodable or non-numeric
        # output. KeyboardInterrupt and SystemExit are no longer swallowed.
        return 0.0
# Endpoint that returns the transcription as plain concatenated text.
# NOTE(review): lines belonging to load_api_keys() (the keys-file reading,
# its warning and its error log) are interleaved with this function below.
@app.post("/transcribe/simple")
async def transcribe_simple(
file: UploadFile = File(...),
token: str = Depends(api_key_header),
model_name: str = "turbo"
):
# Token validation
if token not in get_keys():
logger.warning(f"Invalid token attempt: {token}")
if token == "" or token is None:
raise HTTPException(status_code=401, detail="Forbidden. x-api-key header is missing or empty.")
raise HTTPException(status_code=403, detail="Forbidden. Invalid API key.")
logger.info(f"Processing file: {file.filename} with model: {model_name}")
# NOTE(review): int(os.getenv(...)) raises TypeError when MAX_UPLOAD_SIZE_MB
# is unset, and the '$' in the message below is emitted literally.
if file.size > int(os.getenv("MAX_UPLOAD_SIZE_MB")) * 1024 * 1024:
raise HTTPException(status_code=400, detail=f'File size exceeds ${os.getenv("MAX_UPLOAD_SIZE_MB")}MB limit')
# Save uploaded file
# NOTE(review): the upload's filename is interpolated into the paths
# unchanged — confirm it cannot contain path separators.
temp_input_path = f"/tmp/input_{file.filename}"
temp_output_path = f"/tmp/converted_{file.filename}.wav"
try:
with open(temp_input_path, "wb") as f:
f.write(await file.read())
# Convert audio if needed
logger.debug("Converting audio file")
# NOTE(review): this HTTPException(400) is raised inside the try block, so the
# `except Exception` handler below catches it and re-raises it as a 500.
if not convert_audio(temp_input_path, temp_output_path):
raise HTTPException(status_code=400, detail="Audio conversion failed")
# Get audio duration before speed up
original_duration = get_audio_duration(temp_input_path)
# Transcribe
logger.info("Starting transcription")
if original_duration > 30:
logger.info("Audio duration > 30 seconds, using transcribe_longform")
transcription_result = model.transcribe_longform(
temp_output_path
)
# (load_api_keys) read the keys file, one key per non-blank line, and
# replace the shared key set while holding the lock.
if os.path.exists(keys_file_path):
with open(keys_file_path, 'r') as f:
keys = [line.strip() for line in f.readlines() if line.strip()]
with keys_lock:
api_keys = set(keys)
logger.info(f"Загружено {len(api_keys)} API ключей")
else:
logger.info("Audio duration <= 30 seconds, using transcribe")
transcription_result = model.transcribe(
temp_output_path
)
# Join the non-empty "transcription" parts, separated by single spaces.
full_text = ""
for part in transcription_result:
if part["transcription"].strip() != "":
full_text += part["transcription"].strip() + " "
result = full_text
return result
# (load_api_keys) the keys file does not exist.
logger.warning(f"Файл ключей {keys_file_path} не найден")
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# (load_api_keys) error log of the keys-loading handler.
logger.error(f"Ошибка загрузки ключей: {e}")
finally:
# Cleanup temporary files
if os.path.exists(temp_input_path):
os.remove(temp_input_path)
if os.path.exists(temp_output_path):
os.remove(temp_output_path)
def load_model():
    """Load the Whisper model into the module-level ``model`` variable.

    Settings are read from the environment:

    * ``DEFAULT_MODEL`` - model name, default ``"turbo"``
    * ``MODEL_DOWNLOAD_ROOT`` - download directory, default ``"./models"``
    * ``MODEL_DEVICE`` - device to load onto, default ``"cpu"``

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
    """
    global model

    model_name = os.getenv("DEFAULT_MODEL", "turbo")
    load_options = {
        "device": os.getenv("MODEL_DEVICE", "cpu"),
        "download_root": os.getenv("MODEL_DOWNLOAD_ROOT", "./models"),
    }

    try:
        logger.info(f"Загрузка модели Whisper: {model_name}")
        model = whisper.load_model(model_name, **load_options)
        logger.info("Модель успешно загружена")
    except Exception as e:
        logger.error(f"Ошибка загрузки модели: {e}")
        raise
def verify_api_key(api_key: str = Depends(api_key_header)) -> str:
    """Validate the API key taken from the request header.

    Args:
        api_key: Value supplied through the ``api_key_header`` dependency.

    Returns:
        The same key, once it has been accepted.

    Raises:
        HTTPException: 401 when the key is empty, 403 when it is not among
            the loaded keys.
    """
    if not api_key:
        raise HTTPException(status_code=401, detail="API ключ не предоставлен")

    # Reload the keys on every check so that updates to the file are seen.
    load_api_keys()

    with keys_lock:
        is_known = api_key in api_keys
    if not is_known:
        raise HTTPException(status_code=403, detail="Неверный API ключ")
    return api_key
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the API keys first, then the model."""
    for initialise in (load_api_keys, load_model):
        initialise()
@app.get("/health")
async def health_check():
    """Report service health and whether a model is currently loaded.

    Returns:
        A dict with a fixed ``status`` of ``"healthy"``, a ``model_loaded``
        flag, and ``current_model`` holding ``str(model)`` or ``None``.
    """
    current_model = str(model) if model else None
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "current_model": current_model,
    }
# Main transcription endpoint.
# NOTE(review): this span interleaves two versions of the function. One
# takes `file`/`token`/`model_name`, converts the audio with ffmpeg and
# returns metrics. The other takes `audio_file`/`params`/`api_key` and calls
# model.transcribe() directly. Confirm which lines are meant to remain
# before changing any code here.
@app.post("/transcribe")
async def transcribe_audio(
file: UploadFile = File(...),
token: str = Depends(api_key_header),
model_name: str = "turbo"
audio_file: UploadFile = File(...),
params: TranscribeParams = Depends(),
api_key: str = Depends(verify_api_key)
):
# Token validation
if token not in get_keys():
logger.warning(f"Invalid token attempt: {token}")
raise HTTPException(status_code=403, detail="Forbidden")
"""Transcribe an audio file."""
if model is None:
raise HTTPException(status_code=500, detail="Модель не загружена")
logger.info(f"Processing file: {file.filename} with model: {model_name}")
# Prepare the parameters for whisper.transcribe(): every field that is not
# None, except `format`, is forwarded as a keyword argument.
whisper_params = {}
for field_name, field_value in params.dict(exclude_none=True, exclude={'format'}).items():
whisper_params[field_name] = field_value
if file.size > int(os.getenv("MAX_UPLOAD_SIZE_MB")) * 1024 * 1024:
raise HTTPException(status_code=400, detail=f'File size exceeds ${os.getenv("MAX_UPLOAD_SIZE_MB")}MB limit')
metrics = TranscriptionMetrics()
# Save uploaded file
temp_input_path = f"/tmp/input_{file.filename}"
temp_output_path = f"/tmp/converted_{file.filename}.wav"
# Response format
response_format = params.format
temp_file_path = None
try:
with open(temp_input_path, "wb") as f:
f.write(await file.read())
# Save a temporary file
# NOTE(review): the upload's filename is interpolated into the path
# unchanged — confirm it cannot contain path separators, or use tempfile.
temp_file_path = f"/tmp/{audio_file.filename}"
with open(temp_file_path, "wb") as temp_file:
content = await audio_file.read()
temp_file.write(content)
# Convert audio if needed
logger.debug("Converting audio file")
if not convert_audio(temp_input_path, temp_output_path):
raise HTTPException(status_code=400, detail="Audio conversion failed")
# Transcribe
logger.info(f"Транскрибация файла: {audio_file.filename} с параметрами: {whisper_params}")
result = model.transcribe(temp_file_path, **whisper_params)
# Get audio duration before speed up
original_duration = get_audio_duration(temp_input_path)
# Remove the temporary file
os.unlink(temp_file_path)
# Transcribe
logger.info("Starting transcription")
if original_duration > 30:
logger.info("Audio duration > 30 seconds, using transcribe_longform")
# Speed factor comes from AUDIO_SPEEDUP (default 1.25); the remaining
# arguments request 16 kHz, one channel, 16-bit PCM, overwriting the output.
cmd = [
'ffmpeg', '-i', temp_input_path,
'-filter:a', f'atempo={os.getenv("AUDIO_SPEEDUP", 1.25)}',
'-ar', '16000',
'-ac', '1',
'-c:a', 'pcm_s16le',
temp_output_path,
'-y'
]
log = subprocess.run(cmd, check=True, capture_output=True)
logger.debug(f"Running FFmpeg command: {' '.join(cmd)}")
logger.info("Audio sped up for longform transcription")
# Any stderr output is logged at error level, even though check=True means
# the command exited successfully at this point.
if log.stderr:
logger.error(f"FFmpeg err log: {log.stderr.decode()}")
logger.debug(f"FFmpeg log: {log.stdout.decode()}")
else:
logger.debug(f"FFmpeg log: {log.stdout.decode()}")
transcription_result = model.transcribe_longform(
temp_output_path
)
else:
logger.info("Audio duration <= 30 seconds, using transcribe")
transcription_result = model.transcribe(
temp_output_path
)
# Join the non-empty "transcription" parts, separated by single spaces.
full_text = ""
for part in transcription_result:
if part["transcription"].strip() != "":
full_text += part["transcription"].strip() + " "
result = {
"transcription": transcription_result,
"text": full_text
}
# Calculate metrics
metrics.stop(full_text, original_duration)
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
# Add metrics to result
result["metrics"] = metrics.get_metrics()
return result
# Return the result in the requested format
if response_format == 'text':
return PlainTextResponse(content=result['text'])
elif response_format == 'simple':
return {"text": result['text']}
else: # json - the full response by default
return result
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
logger.error(f"Ошибка транскрибации: {e}")
# Remove the temporary file on error
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Ошибка транскрибации: {str(e)}")
finally:
# Cleanup temporary files
if os.path.exists(temp_input_path):
os.remove(temp_input_path)
if os.path.exists(temp_output_path):
os.remove(temp_output_path)
def main():
    """Validate the key store, then serve the application with uvicorn.

    The bind address and port default to ``0.0.0.0`` and ``9854`` and can be
    overridden with the ``HOST`` and ``PORT`` environment variables. The
    uvicorn log level comes from ``LOG_LEVEL`` (default ``"info"``).

    Raises:
        ValueError: If ``PORT`` is not an integer, or if the keys file exists
            but is empty (raised by ``get_keys``).
    """
    import uvicorn

    # Called for its side effects before the server starts: creates the keys
    # file when it is missing and fails fast when it contains no keys.
    get_keys()
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "9854")),
        log_level=os.getenv("LOG_LEVEL", "info"),
    )
@app.post("/keys/reload")
async def reload_keys(api_key: str = Depends(verify_api_key)):
    """Reload the keys from the file and report how many are now loaded.

    Args:
        api_key: Key accepted by ``verify_api_key``; not used in the body.

    Returns:
        A dict with a single ``message`` entry containing the key count.
    """
    load_api_keys()
    with keys_lock:
        message = f"Перезагружено {len(api_keys)} ключей"
    return {"message": message}
@app.get("/keys/count")
async def get_keys_count(api_key: str = Depends(verify_api_key)):
    """Return the number of currently loaded API keys.

    Args:
        api_key: Key accepted by ``verify_api_key``; not used in the body.

    Returns:
        A dict with a single ``count`` entry.
    """
    with keys_lock:
        count = len(api_keys)
    return {"count": count}
# Script entry point.
# NOTE(review): this span contains both a call to main() and an inline
# uvicorn.run(...) launch — confirm which of the two should remain.
if __name__ == "__main__":
main()
# Server settings come from the environment, with defaults for each.
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "9854"))
log_level = os.getenv("LOG_LEVEL", "info")
# The application is passed as the import string "app:app"; reload is off.
uvicorn.run(
"app:app",
host=host,
port=port,
log_level=log_level,
reload=False
)