Files
simple-asr-server/app.py
red 228f67d07f - Поменял всё снова на Whisper
- Добавил предзагрузку модели по-умолчанию
- Убрал метрики
- Добавил скрипты для старта
- Для отчаянных Dockerfile для сборки контейнера на 70ГБ
2025-08-20 23:18:02 +09:00

189 lines
6.5 KiB
Python

import logging
import os
import subprocess
import tempfile
from typing import Optional
from enum import Enum
import whisper
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File, Query
from fastapi.security import APIKeyHeader
from fastapi.responses import PlainTextResponse
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
app = FastAPI(title="Simple ASR Server", description="Audio transcription API using Whisper")
# API key header
api_key_header = APIKeyHeader(name="x-api-key")
# Global model variable
default_model = None
class OutputFormat(str, Enum):
plaintext = "plaintext"
simple = "simple"
json = "json"
def get_keys():
keys_file = os.getenv("KEYS_FILE", "keys.txt")
if not os.path.exists(keys_file):
# Create a new keys file with a default key
default_key = os.urandom(32).hex()
with open(keys_file, "w") as f:
f.write(default_key + "\n")
logger.info(f"Created new keys file with default key: {default_key}")
return [default_key]
else:
# Read keys from the existing file
with open(keys_file, "r") as f:
keys = [line.strip() for line in f if line.strip()]
logger.info(f"Loaded {len(keys)} keys from file")
if not keys:
raise ValueError("No keys found in keys.txt")
return keys
def load_default_model():
"""Load the default model on startup"""
global default_model
model_name = os.getenv("DEFAULT_MODEL", "turbo")
model_download_root = os.getenv("MODEL_DOWNLOAD_ROOT", None)
logger.info(f"Loading default model: {model_name}")
try:
default_model = whisper.load_model(model_name, download_root=model_download_root, in_memory=True)
logger.info(f"Successfully loaded model: {model_name}")
except Exception as e:
logger.error(f"Failed to load default model {model_name}: {e}")
raise
def get_model(model_name: Optional[str] = None):
"""Get model - either default or load new one if specified"""
global default_model
if model_name is None:
return default_model
# If different model requested, load it
if model_name != os.getenv("DEFAULT_MODEL", "turbo"):
model_download_root = os.getenv("MODEL_DOWNLOAD_ROOT", None)
logger.info(f"Loading requested model: {model_name}")
return whisper.load_model(model_name, download_root=model_download_root)
return default_model
def convert_audio(input_path: str, output_path: str, speed: float = 1.0):
"""Convert audio to compatible format and speed up if needed."""
try:
command = [
'ffmpeg', '-i', input_path,
'-filter:a', f'atempo={speed}',
'-ar', '16000',
'-ac', '1',
'-c:a', 'pcm_s16le',
output_path,
'-y'
]
logger.debug(f"Running FFmpeg command: {' '.join(command)}")
result = subprocess.run(command, check=True, capture_output=True, text=True)
return True
except subprocess.CalledProcessError as e:
logger.error(f"FFmpeg conversion failed: {e.stderr}")
return False
@app.post("/transcribe")
async def transcribe_audio(
file: UploadFile = File(...),
token: str = Depends(api_key_header),
model_name: Optional[str] = Query(None, description="Model name to use for transcription"),
output_format: OutputFormat = Query(OutputFormat.json, description="Output format: plaintext, simple, or json"),
speedup: float = Query(1.0, ge=0.25, le=4.0, description="Speed up factor for audio (0.25-4.0)")
):
"""Transcribe audio file with configurable output format"""
# Token validation
if token not in get_keys():
logger.warning(f"Invalid token attempt: {token}")
raise HTTPException(status_code=403, detail="Forbidden")
logger.info(f"Processing file: {file.filename}, model: {model_name or 'default'}, format: {output_format}, speedup: {speedup}")
# Get model
try:
model = get_model(model_name)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
# Create temporary files
with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as temp_input:
temp_input_path = temp_input.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_output:
temp_output_path = temp_output.name
try:
# Save uploaded file
with open(temp_input_path, "wb") as f:
content = await file.read()
f.write(content)
# Convert audio if speedup is not 1.0 or format needs conversion
if speedup != 1.0 or not file.filename.lower().endswith('.wav'):
logger.debug(f"Converting audio file with speedup: {speedup}")
if not convert_audio(temp_input_path, temp_output_path, speedup):
raise HTTPException(status_code=400, detail="Audio conversion failed")
audio_file_path = temp_output_path
else:
audio_file_path = temp_input_path
# Transcribe
logger.info("Starting transcription")
result = model.transcribe(audio_file_path)
# Format output based on requested format
if output_format == OutputFormat.plaintext:
return PlainTextResponse(content=result["text"], media_type="text/plain")
elif output_format == OutputFormat.simple:
return {"text": result["text"]}
else: # json format
return result
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
# Cleanup temporary files
for path in [temp_input_path, temp_output_path]:
if os.path.exists(path):
try:
os.remove(path)
except Exception as e:
logger.warning(f"Failed to remove temp file {path}: {e}")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "model_loaded": default_model is not None}
def main():
import uvicorn
# Load default model and keys
load_default_model()
get_keys()
port = int(os.getenv("PORT", 9854))
host = os.getenv("HOST", "0.0.0.0")
uvicorn.run(app, host=host, port=port, log_level="info")
if __name__ == "__main__":
main()