Add OpenAI-compatible API and Docker deployment
- Add FastAPI-based API in whisperx/api/ - Implement transcription endpoint compatible with OpenAI - Added Dockerfile and docker-compose.yml for easy deployment - Updated README with Docker instructions - Added new script whisperx-serve for running the API
This commit is contained in:
0
whisperx/api/__init__.py
Normal file
0
whisperx/api/__init__.py
Normal file
86
whisperx/api/main.py
Normal file
86
whisperx/api/main.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import tempfile
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import torch
|
||||
import whisperx
|
||||
from whisperx.schema import TranscriptionResult
|
||||
|
||||
|
||||
model = None
|
||||
align_model_metadata = None
|
||||
|
||||
|
||||
def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"):
|
||||
global model, align_model_metadata
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
|
||||
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
||||
# For alignment, load the metadata
|
||||
align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
|
||||
print("Model loaded and ready.")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Load the model at startup
|
||||
model_name = os.getenv("WHISPERX_MODEL", "turbo")
|
||||
device = os.getenv("WHISPERX_DEVICE", "cuda")
|
||||
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")
|
||||
load_transcription_model(model_name, device, compute_type)
|
||||
yield
|
||||
# Cleanup if needed
|
||||
print("Shutting down API")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="WhisperX API",
|
||||
description="OpenAI-compatible API for speech transcription using WhisperX",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "WhisperX API is running"}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
model_name: str = Form("whisper-1"), # OpenAI uses 'whisper-1', we ignore this
|
||||
language: str = Form(None),
|
||||
response_format: str = Form("json"),
|
||||
temperature: float = Form(0.0), # We don't use temperature for now
|
||||
prompt: str = Form(None) # Not used
|
||||
):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')):
|
||||
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||
|
||||
# Save uploaded file to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
||||
temp_file.write(await file.read())
|
||||
audio_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Load audio
|
||||
audio = whisperx.load_audio(audio_path)
|
||||
|
||||
# Transcribe
|
||||
result = model(audio, batch_size=16, language=language)
|
||||
text = " ".join([segment['text'] for segment in result["segments"]]).strip()
|
||||
|
||||
# If we have segments, might want to return more info, but for OpenAI compatibility, just text
|
||||
|
||||
return JSONResponse({"text": text})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||
finally:
|
||||
os.unlink(audio_path)
|
||||
16
whisperx/api/serve.py
Normal file
16
whisperx/api/serve.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import uvicorn
|
||||
|
||||
|
||||
def serve(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
|
||||
"""Run the WhisperX API server"""
|
||||
uvicorn.run(
|
||||
"whisperx.api.main:app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
reload=False # No reload for production
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
Reference in New Issue
Block a user