- Change default device to CPU in docker-compose and main.py
- Set compute_type=int8 for CPU inference
- Auto-detect the device when WHISPERX_DEVICE is not set
81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
import asyncio
import os
import tempfile
import threading
from contextlib import asynccontextmanager

import torch
import whisperx
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse, PlainTextResponse
|
|
|
|
|
|
# Shared WhisperX pipeline used by the request handlers. It stays None until
# load_transcription_model() has run (the application lifespan calls it at startup).
model = None


def load_transcription_model(
    model_name: str = "large-v2",
    device: "str | None" = None,
    compute_type: "str | None" = None,
):
    """Load a WhisperX model and store it in the module-level ``model``.

    Args:
        model_name: Model identifier passed to ``whisperx.load_model``.
        device: Target device. When None, ``WHISPERX_DEVICE`` is used if it is
            set to a non-blank value (normalized to lower case); otherwise
            ``"cuda"`` if ``torch.cuda.is_available()`` else ``"cpu"``.
        compute_type: Inference precision. When None, ``WHISPERX_COMPUTE_TYPE``
            is used if set to a non-blank value; otherwise ``"int8"`` on CPU and
            ``"float16"`` on any other device.
    """
    global model

    if device is None:
        # Blank env values (e.g. docker-compose "${VAR:-}") count as unset;
        # lower-casing keeps the "cpu" comparison below working for "CPU".
        device = os.getenv("WHISPERX_DEVICE", "").strip().lower() or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    if compute_type is None:
        # Honor the same override the lifespan hook reads, so both entry points
        # resolve configuration identically.
        compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip() or (
            "int8" if device == "cpu" else "float16"
        )

    print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
    model = whisperx.load_model(model_name, device, compute_type=compute_type)
    print("Model loaded and ready.")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the WhisperX model at startup and release it at shutdown.

    Configuration is read from environment variables. Unset or blank values
    fall back to these defaults:

    * ``WHISPERX_MODEL``: ``"large-v2"``.
    * ``WHISPERX_DEVICE``: ``"cuda"`` if ``torch.cuda.is_available()``, else
      ``"cpu"`` (normalized to lower case).
    * ``WHISPERX_COMPUTE_TYPE``: ``"int8"`` on CPU, ``"float16"`` otherwise.
    """
    global model

    model_name = os.getenv("WHISPERX_MODEL", "").strip() or "large-v2"
    device = os.getenv("WHISPERX_DEVICE", "").strip().lower() or (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip() or (
        "int8" if device == "cpu" else "float16"
    )

    # load_transcription_model is synchronous (and may be slow). Running it in
    # the default executor keeps the event loop free, e.g. to react to shutdown
    # signals while the model loads.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None, load_transcription_model, model_name, device, compute_type
    )

    yield

    # Drop the module-level reference so the model can be garbage-collected.
    model = None
|
|
|
|
|
|
# The ASGI application. Its lifespan hook loads the transcription model before
# requests are served; the remaining arguments are OpenAPI metadata.
app = FastAPI(
    lifespan=lifespan,
    title="WhisperX API",
    version="1.0.0",
    description="OpenAI-compatible API for speech transcription using WhisperX",
)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a fixed status payload showing that the API process is serving."""
    payload = {"message": "WhisperX API is running"}
    return payload
|
|
|
|
|
|
# File extensions accepted by the transcription endpoint (compared in lower case).
_SUPPORTED_AUDIO_EXTENSIONS = (
    ".wav", ".mp3", ".flac", ".m4a", ".webm", ".mp4", ".mpga", ".ogg", ".opus",
)

# Serializes calls into the shared WhisperX pipeline from worker threads.
# NOTE(review): the pipeline is assumed not to be thread-safe (its transcribe()
# appears to update per-call state such as the tokenizer); confirm before
# relaxing this lock.
_model_lock = threading.Lock()


def _transcribe_path(pipeline, audio_path: str, language) -> str:
    """Decode *audio_path* and transcribe it with *pipeline* (blocking).

    Intended to run in a worker thread. Only the transcription call holds
    ``_model_lock``; audio decoding may run concurrently. Returns the
    non-empty segment texts joined by single spaces.
    """
    audio = whisperx.load_audio(audio_path)
    with _model_lock:
        result = pipeline.transcribe(audio, batch_size=16, language=language)
    # Strip each piece so segments that carry their own leading/trailing
    # whitespace do not produce doubled spaces when joined.
    pieces = (segment["text"].strip() for segment in result["segments"])
    return " ".join(piece for piece in pieces if piece)


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model_name: str = Form("whisper-1"),  # OpenAI uses 'whisper-1', we ignore this
    language: str = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0.0),  # We don't use temperature for now
    prompt: str = Form(None),  # Not used
):
    """Transcribe an uploaded audio file (OpenAI ``/v1/audio/transcriptions`` shape).

    Args:
        file: Uploaded audio. It is validated by file extension only.
        model_name: Accepted for OpenAI client compatibility and ignored; the
            module-level ``model`` is always used.
        language: Optional language code forwarded to WhisperX. None leaves
            the choice to WhisperX (presumably auto-detection).
        response_format: ``"text"`` returns a plain-text body. Any other value
            returns JSON ``{"text": ...}``.
        temperature: Accepted for compatibility; unused.
        prompt: Accepted for compatibility; unused.

    Returns:
        A PlainTextResponse or JSONResponse containing the transcript.

    Raises:
        HTTPException: 500 if no model is loaded or transcription fails;
            400 for a missing filename or an unsupported file extension.
    """
    # Take a local reference so a concurrent reset of the global cannot swap
    # the model out from under this request.
    pipeline = model
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # UploadFile.filename may be None; treat that as an unsupported format
    # instead of crashing with AttributeError.
    lowered_name = (file.filename or "").lower()
    suffix = next(
        (ext for ext in _SUPPORTED_AUDIO_EXTENSIONS if lowered_name.endswith(ext)),
        None,
    )
    if suffix is None:
        raise HTTPException(status_code=400, detail="Unsupported audio format")

    # Keep the real extension on the temp file rather than always ".wav".
    # The path is captured before writing so the finally clause removes it
    # even if reading the upload fails.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    audio_path = temp_file.name
    try:
        with temp_file:
            temp_file.write(await file.read())

        # Decoding and inference are blocking; run them off the event loop so
        # other requests (e.g. "/") are still served during a transcription.
        loop = asyncio.get_running_loop()
        try:
            text = await loop.run_in_executor(
                None, _transcribe_path, pipeline, audio_path, language
            )
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Transcription failed: {str(e)}"
            ) from e
    finally:
        os.unlink(audio_path)

    if response_format == "text":
        return PlainTextResponse(text)
    # Default (and any other format): OpenAI-style JSON with just the text.
    return JSONResponse({"text": text})