"""OpenAI-compatible speech-transcription HTTP API backed by WhisperX.

Configuration is read from the environment:

* ``WHISPERX_MODEL`` -- model name passed to ``whisperx.load_model``
  (default ``large-v2``).
* ``WHISPERX_DEVICE`` -- device string (default ``cuda`` when
  ``torch.cuda.is_available()``, otherwise ``cpu``).
* ``WHISPERX_COMPUTE_TYPE`` -- compute type (default ``int8`` on ``cpu``,
  ``float16`` otherwise).
"""
import os
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, PlainTextResponse
import torch
import whisperx

# The loaded WhisperX pipeline; set by load_transcription_model() at startup.
model = None

# Every request shares the single `model` above. Inference runs in worker
# threads (so the event loop stays responsive); this lock keeps those threads
# from driving the shared pipeline concurrently, since it is not documented
# as thread-safe.
_inference_lock = threading.Lock()

# Upload filename extensions accepted by the endpoint (matched case-insensitively).
SUPPORTED_EXTENSIONS = (
    ".wav", ".mp3", ".flac", ".m4a", ".webm", ".mp4", ".mpga", ".ogg", ".opus",
)

# Bytes read per iteration when copying an upload to disk, so a large upload
# is never held in memory all at once.
_UPLOAD_CHUNK_SIZE = 1024 * 1024

# Batch size forwarded to the WhisperX pipeline's transcribe() call.
_BATCH_SIZE = 16


def _default_device() -> str:
    """Return WHISPERX_DEVICE, or ``cuda``/``cpu`` depending on CUDA availability."""
    return os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")


def _default_compute_type(device: str) -> str:
    """Return WHISPERX_COMPUTE_TYPE, or ``int8`` on CPU and ``float16`` elsewhere."""
    return os.getenv("WHISPERX_COMPUTE_TYPE", "int8" if device == "cpu" else "float16")


def load_transcription_model(
    model_name: str = "large-v2",
    device: Optional[str] = None,
    compute_type: Optional[str] = None,
):
    """Load a WhisperX model into the module-level ``model`` global.

    Args:
        model_name: Model identifier passed to ``whisperx.load_model``.
        device: Target device; when None it is resolved from the environment
            (see the module docstring).
        compute_type: Compute type; when None it is resolved from the
            environment, defaulting by device.
    """
    global model
    if device is None:
        device = _default_device()
    if compute_type is None:
        compute_type = _default_compute_type(device)
    print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
    model = whisperx.load_model(model_name, device, compute_type=compute_type)
    print("Model loaded and ready.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transcription model once before the app starts serving requests."""
    load_transcription_model(os.getenv("WHISPERX_MODEL", "large-v2"))
    yield


app = FastAPI(
    title="WhisperX API",
    description="OpenAI-compatible API for speech transcription using WhisperX",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/")
async def root():
    """Liveness probe."""
    return {"message": "WhisperX API is running"}


def _transcribe_file(audio_path: str, language: Optional[str]) -> str:
    """Decode the audio file at *audio_path*, transcribe it, and return the text.

    Blocking: intended to run in a worker thread, not on the event loop.

    Args:
        audio_path: Path of the audio file on disk.
        language: Language code forwarded to the pipeline, or None to let
            WhisperX decide.

    Returns:
        The stripped text of every non-empty segment, joined by single spaces.
    """
    # Decoding happens outside the lock; only the shared model is serialized.
    audio = whisperx.load_audio(audio_path)
    with _inference_lock:
        result = model.transcribe(audio, batch_size=_BATCH_SIZE, language=language)
    pieces = (segment["text"].strip() for segment in result["segments"])
    return " ".join(piece for piece in pieces if piece)


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model_name: str = Form("whisper-1"),  # OpenAI uses 'whisper-1', we ignore this
    language: Optional[str] = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0.0),  # We don't use temperature for now
    prompt: Optional[str] = Form(None),  # Not used
):
    """Transcribe an uploaded audio file (OpenAI ``audio/transcriptions`` shape).

    Returns plain text when ``response_format == "text"``; every other format
    currently yields ``{"text": ...}`` JSON.

    Raises:
        HTTPException: 500 if the model is not loaded or transcription fails;
            400 if the upload's extension is not in SUPPORTED_EXTENSIONS.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # UploadFile.filename may be None; treat that as an unsupported upload.
    filename = file.filename or ""
    if not filename.lower().endswith(SUPPORTED_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Unsupported audio format")

    # Keep the real extension on the temp file so it matches its contents.
    suffix = os.path.splitext(filename)[1].lower()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    audio_path = temp_file.name
    try:
        # Copy the upload in chunks; the finally below removes the file even
        # if this copy fails part-way.
        with temp_file:
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                temp_file.write(chunk)

        loop = asyncio.get_running_loop()
        try:
            # Empty-string language is treated as "not specified".
            text = await loop.run_in_executor(
                None, _transcribe_file, audio_path, language or None
            )
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Transcription failed: {str(e)}"
            ) from e
    finally:
        os.unlink(audio_path)

    if response_format == "text":
        return PlainTextResponse(text)
    return JSONResponse({"text": text})