Files
simple-asr-server/app.py
vladislav fcae47cad1 - рефакторинг приложения
- снова изменения в Readme
- работа над валидацией параметров
- большая гибкость и конфигурироемость
2025-09-06 21:05:56 +03:00

177 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import json
import whisper
from pathlib import Path
from typing import Dict, Optional, Set, Literal, List, Union
from threading import Lock
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File, Request, Form
from fastapi.security import APIKeyHeader
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field
import uvicorn
# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic модель для параметров транскрибации
class TranscribeParams(BaseModel):
language: Optional[str] = Field(None, description="Язык аудио (auto-detect по умолчанию)")
task: Optional[str] = Field("transcribe", description="transcribe или translate")
temperature: Optional[float] = Field(0.0, description="Температура для генерации (0.0-1.0)")
beam_size: Optional[int] = Field(None, description="Размер луча для поиска")
best_of: Optional[int] = Field(None, description="Количество кандидатов для выбора лучшего")
compression_ratio_threshold: Optional[float] = Field(None, description="Порог сжатия для фильтрации")
logprob_threshold: Optional[float] = Field(None, description="Порог логарифмической вероятности")
no_speech_threshold: Optional[float] = Field(None, description="Порог детекции отсутствия речи")
condition_on_previous_text: Optional[bool] = Field(True, description="Использовать предыдущий текст как контекст")
initial_prompt: Optional[str] = Field(None, description="Начальная подсказка для модели")
word_timestamps: Optional[bool] = Field(False, description="Временные метки слов")
prepend_punctuations: Optional[str] = Field(None, description="Знаки препинания для добавления в начало")
append_punctuations: Optional[str] = Field(None, description="Знаки препинания для добавления в конец")
clip_timestamps: Optional[List[float]] = Field(None, description="Временные метки для обрезки аудио")
hallucination_silence_threshold: Optional[float] = Field(None, description="Порог тишины для детекции галлюцинаций")
format: Optional[Literal["json", "simple", "text"]] = Field("json", description="Формат ответа")
# Глобальные переменные для модели и ключей
model = None
api_keys: Set[str] = set()
keys_lock = Lock()
keys_file_path = os.getenv("KEYS_FILE", "keys.txt")
# Схема безопасности
api_key_header = APIKeyHeader(name="X-API-Key")
app = FastAPI(title="Whisper ASR Service", version="1.0.0")
def load_api_keys():
"""Загружает API ключи из файла"""
global api_keys
try:
if os.path.exists(keys_file_path):
with open(keys_file_path, 'r') as f:
keys = [line.strip() for line in f.readlines() if line.strip()]
with keys_lock:
api_keys = set(keys)
logger.info(f"Загружено {len(api_keys)} API ключей")
else:
logger.warning(f"Файл ключей {keys_file_path} не найден")
except Exception as e:
logger.error(f"Ошибка загрузки ключей: {e}")
def load_model():
"""Загружает модель Whisper"""
global model
model_name = os.getenv("DEFAULT_MODEL", "turbo")
download_root = os.getenv("MODEL_DOWNLOAD_ROOT", "./models")
device = os.getenv("MODEL_DEVICE", "cpu")
try:
logger.info(f"Загрузка модели Whisper: {model_name}")
model = whisper.load_model(model_name, device=device, download_root=download_root)
logger.info("Модель успешно загружена")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
raise
def verify_api_key(api_key: str = Depends(api_key_header)) -> str:
"""Проверяет API ключ"""
if not api_key:
raise HTTPException(status_code=401, detail="API ключ не предоставлен")
# Перезагружаем ключи для проверки обновлений
load_api_keys()
with keys_lock:
if api_key not in api_keys:
raise HTTPException(status_code=403, detail="Неверный API ключ")
return api_key
@app.on_event("startup")
async def startup_event():
"""Инициализация при запуске"""
load_api_keys()
load_model()
@app.get("/health")
async def health_check():
"""Проверка здоровья сервиса"""
return {"status": "healthy", "model_loaded": model is not None, "current_model": str(model) if model else None}
@app.post("/transcribe")
async def transcribe_audio(
audio_file: UploadFile = File(...),
params: TranscribeParams = Depends(),
api_key: str = Depends(verify_api_key)
):
"""Транскрибирует аудиофайл"""
if model is None:
raise HTTPException(status_code=500, detail="Модель не загружена")
# Готовим параметры для whisper.transcribe()
whisper_params = {}
for field_name, field_value in params.dict(exclude_none=True, exclude={'format'}).items():
whisper_params[field_name] = field_value
# Формат ответа
response_format = params.format
temp_file_path = None
try:
# Сохраняем временный файл
temp_file_path = f"/tmp/{audio_file.filename}"
with open(temp_file_path, "wb") as temp_file:
content = await audio_file.read()
temp_file.write(content)
# Транскрибируем
logger.info(f"Транскрибация файла: {audio_file.filename} с параметрами: {whisper_params}")
result = model.transcribe(temp_file_path, **whisper_params)
# Удаляем временный файл
os.unlink(temp_file_path)
# Возвращаем результат в нужном формате
if response_format == 'text':
return PlainTextResponse(content=result['text'])
elif response_format == 'simple':
return {"text": result['text']}
else: # json - полный ответ по умолчанию
return result
except Exception as e:
logger.error(f"Ошибка транскрибации: {e}")
# Удаляем временный файл в случае ошибки
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)
raise HTTPException(status_code=500, detail=f"Ошибка транскрибации: {str(e)}")
@app.post("/keys/reload")
async def reload_keys(api_key: str = Depends(verify_api_key)):
"""Перезагружает ключи из файла"""
load_api_keys()
with keys_lock:
return {"message": f"Перезагружено {len(api_keys)} ключей"}
@app.get("/keys/count")
async def get_keys_count(api_key: str = Depends(verify_api_key)):
"""Возвращает количество активных ключей"""
with keys_lock:
return {"count": len(api_keys)}
if __name__ == "__main__":
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "9854"))
log_level = os.getenv("LOG_LEVEL", "info")
uvicorn.run(
"app:app",
host=host,
port=port,
log_level=log_level,
reload=False
)