- Add FastAPI-based API in whisperx/api/ - Implement transcription endpoint compatible with OpenAI - Added Dockerfile and docker-compose.yml for easy deployment - Updated README with Docker instructions - Added new script whisperx-serve for running the API
86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import asyncio
import os
import tempfile
import threading
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, PlainTextResponse

import whisperx
from whisperx.schema import TranscriptionResult
|
|
|
|
|
|
# Process-wide state, filled in by load_transcription_model() during startup.
# model: the whisperx transcription pipeline; stays None until loading succeeds.
# align_model_metadata: whisperx's default alignment-model table (set at load time).
model, align_model_metadata = None, None
|
|
|
|
|
|
def resolve_device_and_compute_type(device: Optional[str] = None, compute_type: str = "float16"):
    """Pick the inference device and a compute type that device can run.

    Args:
        device: "cuda", "cpu", or None to use CUDA only when torch reports it available.
        compute_type: requested CTranslate2 compute type (e.g. "float16", "int8").

    Returns:
        A ``(device, compute_type)`` tuple. A "float16" request on CPU is downgraded
        to "int8", since CTranslate2 (the faster-whisper backend whisperx loads)
        has no efficient float16 path on CPU and refuses to load the model.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu" and compute_type == "float16":
        print("float16 is not supported on CPU; falling back to int8")
        compute_type = "int8"
    return device, compute_type


def load_transcription_model(model_name: str = "turbo", device: Optional[str] = None, compute_type: str = "float16"):
    """Load the whisperx pipeline into the module-level ``model`` global.

    Args:
        model_name: whisperx/Whisper model identifier (e.g. "turbo").
        device: "cuda", "cpu", or None to auto-detect via torch.
        compute_type: requested compute type; adjusted by
            resolve_device_and_compute_type() when the device cannot run it.

    Side effects:
        Rebinds the globals ``model`` and ``align_model_metadata``.
    """
    global model, align_model_metadata

    device, compute_type = resolve_device_and_compute_type(device, compute_type)

    print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")

    model = whisperx.load_model(model_name, device, compute_type=compute_type)

    # Import the submodule explicitly: `import whisperx` alone does not guarantee
    # that `whisperx.alignment` is already bound as a package attribute.
    from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF

    # For alignment, keep the default alignment-model metadata table.
    align_model_metadata = DEFAULT_ALIGN_MODELS_HF

    print("Model loaded and ready.")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transcription model at startup; report shutdown when the app stops.

    Configuration is read from the environment:
        WHISPERX_MODEL: model name (default "turbo").
        WHISPERX_DEVICE: "cuda" or "cpu"; when unset or empty, the device is
            auto-detected by load_transcription_model().
        WHISPERX_COMPUTE_TYPE: compute type (default "float16").
    """
    model_name = os.getenv("WHISPERX_MODEL", "turbo")
    # Pass None (not a hard-coded "cuda") when unset so that
    # load_transcription_model() selects CUDA only if torch reports it available.
    device = os.getenv("WHISPERX_DEVICE") or None
    compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")

    load_transcription_model(model_name, device, compute_type)

    try:
        yield
    finally:
        # Runs on normal shutdown and when the app exits with an error.
        print("Shutting down API")
|
|
|
|
|
|
# ASGI application object; the `lifespan` hook loads the model before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="WhisperX API",
    version="1.0.0",
    description="OpenAI-compatible API for speech transcription using WhisperX",
)
|
|
|
|
|
|
@app.get("/")
async def root():
    # Liveness check: always answers with a fixed payload, whether or not the model is loaded.
    status_payload = {"message": "WhisperX API is running"}
    return status_payload
|
|
|
|
|
|
# Upload extensions accepted by transcribe_audio (matched case-insensitively).
_SUPPORTED_AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')

# Values of the OpenAI-style `response_format` field this endpoint can produce.
_SUPPORTED_RESPONSE_FORMATS = ("json", "text", "verbose_json")

# Uploads are copied to disk in chunks of this many bytes instead of being read whole.
_UPLOAD_CHUNK_SIZE = 1 << 20

# The pipeline in the `model` global is shared by every request. It is not
# documented as thread-safe, so inference calls made from worker threads are
# serialized. This matches the one-at-a-time behavior of running inline.
_MODEL_LOCK = threading.Lock()


def _run_transcription(audio, language):
    """Transcribe decoded *audio* with the shared pipeline (called from a worker thread)."""
    with _MODEL_LOCK:
        return model.transcribe(audio, batch_size=16, language=language)


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model_name: str = Form("whisper-1"),  # OpenAI uses 'whisper-1', we ignore this
    language: str = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0.0),  # We don't use temperature for now
    prompt: str = Form(None),  # Not used
):
    """Transcribe an uploaded audio file, mirroring OpenAI's /v1/audio/transcriptions.

    Supported ``response_format`` values:
        json: {"text": ...}
        text: the transcript as text/plain
        verbose_json: text, detected language and per-segment start/end/text

    Raises:
        HTTPException 503: the model has not been loaded yet.
        HTTPException 400: unsupported file extension or response_format, or the
            upload could not be decoded as audio.
        HTTPException 500: transcription itself failed.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # UploadFile.filename may be None, so normalize it before matching extensions.
    lowered_name = (file.filename or "").lower()
    suffix = next((ext for ext in _SUPPORTED_AUDIO_EXTENSIONS if lowered_name.endswith(ext)), None)
    if suffix is None:
        raise HTTPException(status_code=400, detail="Unsupported audio format")

    if response_format not in _SUPPORTED_RESPONSE_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported response_format '{response_format}'; expected one of: json, text, verbose_json",
        )

    # Keep the real extension so the decoder sees a suffix that matches the content.
    fd, audio_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as temp_file:
            while chunk := await file.read(_UPLOAD_CHUNK_SIZE):
                temp_file.write(chunk)

        # Decoding and inference block, so they run off the event loop.
        try:
            audio = await asyncio.to_thread(whisperx.load_audio, audio_path)
        except RuntimeError as e:
            # Presumably a decoder (ffmpeg) failure on the uploaded bytes: a client error.
            raise HTTPException(status_code=400, detail=f"Could not decode audio: {e}") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e

        try:
            # An empty form value means "not specified": let the model detect the language.
            result = await asyncio.to_thread(_run_transcription, audio, language or None)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e
    finally:
        os.unlink(audio_path)

    segments = result["segments"]
    # Strip each piece first: segment texts can carry leading spaces, and joining
    # them unstripped would leave double spaces in the transcript.
    text = " ".join(segment["text"].strip() for segment in segments).strip()

    if response_format == "text":
        return PlainTextResponse(text)
    if response_format == "verbose_json":
        return JSONResponse({
            "task": "transcribe",
            "language": result.get("language", language),
            "text": text,
            "segments": [
                {
                    "id": index,
                    "start": float(segment["start"]),
                    "end": float(segment["end"]),
                    "text": segment["text"].strip(),
                }
                for index, segment in enumerate(segments)
            ],
        })
    return JSONResponse({"text": text})