Files
whisperx-rocm-api/whisperx/api/main.py
SlavaVlad 16bdf2bd00 Fix CUDA/ROCm compatibility: default to CPU for AMD GPUs
- Change default device to CPU in docker-compose and main.py
- Set compute_type=int8 for CPU inference
- Auto-detect device if env not set
2026-05-13 03:50:58 +03:00

81 lines
2.8 KiB
Python

import asyncio
import os
import tempfile
import threading
from contextlib import asynccontextmanager
from typing import Optional

import torch
import whisperx
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
# Module-wide WhisperX pipeline; populated by load_transcription_model().
model = None


def detect_device() -> str:
    """Pick the inference device to use when WHISPERX_DEVICE is not set.

    Returns "cuda" only for a CUDA build of PyTorch that sees a GPU. ROCm
    builds of PyTorch also report ``torch.cuda.is_available() == True`` for
    AMD GPUs (and set ``torch.version.hip``), so they are routed to "cpu",
    following this project's choice to run AMD GPUs on the CPU path.
    NOTE(review): assumes the installed WhisperX backend cannot run on ROCm
    devices -- revisit if a ROCm-enabled backend is installed.
    """
    if torch.cuda.is_available() and getattr(torch.version, "hip", None) is None:
        return "cuda"
    return "cpu"


def load_transcription_model(
    model_name: str = "large-v2",
    device: Optional[str] = None,
    compute_type: Optional[str] = None,
):
    """Load a WhisperX pipeline into the module-level ``model`` global.

    Args:
        model_name: Model name passed to ``whisperx.load_model``.
        device: Inference device. ``None`` reads ``WHISPERX_DEVICE``; if that
            is unset or empty, falls back to ``detect_device()``.
        compute_type: Precision passed to ``whisperx.load_model``. ``None``
            selects ``"int8"`` on CPU and ``"float16"`` on any other device.
    """
    global model
    if device is None:
        # `or` also treats an empty WHISPERX_DEVICE as "not set".
        device = os.getenv("WHISPERX_DEVICE") or detect_device()
    if compute_type is None:
        compute_type = "int8" if device == "cpu" else "float16"
    print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
    model = whisperx.load_model(model_name, device, compute_type=compute_type)
    print("Model loaded and ready.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the WhisperX model at startup and drop it at shutdown.

    Configuration is read from the environment:
      WHISPERX_MODEL         model name (unset/empty -> "large-v2")
      WHISPERX_DEVICE        device override (unset/empty -> None, so
                             load_transcription_model() picks the device)
      WHISPERX_COMPUTE_TYPE  compute-type override (unset/empty -> None, so
                             load_transcription_model() derives it from the device)

    Passing None instead of re-deriving defaults here keeps a single
    device/compute-type selection path in load_transcription_model().
    """
    global model
    load_transcription_model(
        os.getenv("WHISPERX_MODEL") or "large-v2",
        os.getenv("WHISPERX_DEVICE") or None,
        os.getenv("WHISPERX_COMPUTE_TYPE") or None,
    )
    try:
        yield
    finally:
        # Drop the reference on shutdown so the pipeline can be garbage-collected.
        model = None
# ASGI application; `lifespan` runs model loading before the server starts
# handling requests.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="WhisperX API",
    description="OpenAI-compatible API for speech transcription using WhisperX",
)
@app.get("/")
async def root():
    """Return a static message confirming the server is up.

    This does not check whether the model has been loaded.
    """
    status = {"message": "WhisperX API is running"}
    return status
# Accepted upload extensions (matched case-insensitively against the filename).
_SUPPORTED_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')

# Serializes inference. Previously transcriptions ran one at a time on the
# event loop; now they run in worker threads, and this lock preserves the
# one-at-a-time behaviour for the shared pipeline.
# NOTE(review): thread-safety of concurrent transcribe() calls on a single
# pipeline is not established here -- only relax this after confirming it.
_transcription_lock = threading.Lock()


def _transcribe_file(pipeline, audio_path, language):
    """Decode *audio_path*, transcribe it with *pipeline*, and return the joined text.

    Blocking; intended to be run in a worker thread.
    """
    audio = whisperx.load_audio(audio_path)
    with _transcription_lock:
        result = pipeline.transcribe(audio, batch_size=16, language=language)
    # Strip each segment before joining so surrounding whitespace in segment
    # text does not produce doubled spaces; skip segments that are empty.
    pieces = (segment["text"].strip() for segment in result["segments"])
    return " ".join(piece for piece in pieces if piece)


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model_name: str = Form("whisper-1"),  # OpenAI uses 'whisper-1', we ignore this
    language: str = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0.0),  # We don't use temperature for now
    prompt: str = Form(None),  # Not used
):
    """Transcribe an uploaded audio file (OpenAI ``/v1/audio/transcriptions`` shape).

    Only ``file`` and ``language`` affect the result; ``model_name``,
    ``response_format``, ``temperature`` and ``prompt`` are accepted for client
    compatibility but ignored. Always responds with ``{"text": ...}``.

    Raises:
        HTTPException(500): the model is not loaded, or decoding/transcription failed.
        HTTPException(400): the filename has an unsupported extension.
    """
    # Snapshot the global once so the whole request uses the same pipeline.
    pipeline = model
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # UploadFile.filename may be None; treat that as unsupported instead of crashing.
    filename = file.filename or ""
    if not filename.lower().endswith(_SUPPORTED_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Unsupported audio format")

    payload = await file.read()
    # Keep the upload's real extension rather than always claiming ".wav".
    suffix = os.path.splitext(filename)[1] or ".wav"
    fd, audio_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(payload)
        # Decoding and inference are blocking calls; keep them off the event loop.
        text = await asyncio.to_thread(_transcribe_file, pipeline, audio_path, language)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e
    finally:
        # Runs on success and failure alike, so the temp file never leaks.
        os.unlink(audio_path)
    return JSONResponse({"text": text})