Add OpenAI-compatible API and Docker deployment
- Add FastAPI-based API in whisperx/api/ - Implement transcription endpoint compatible with OpenAI - Added Dockerfile and docker-compose.yml for easy deployment - Updated README with Docker instructions - Added new script whisperx-serve for running the API
This commit is contained in:
27
.dockerignore
Normal file
27
.dockerignore
Normal file
@@ -0,0 +1,27 @@
|
||||
.git
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.pytest_cache
|
||||
.coverage
|
||||
htmlcov
|
||||
.env
|
||||
.venv
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
|
||||
# Docker
|
||||
Dockerfile*
|
||||
docker-compose*.yml
|
||||
.dockerignore
|
||||
|
||||
# IDE
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
*.log
|
||||
36
Dockerfile
Normal file
36
Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
||||
# Use ROCm PyTorch base image
|
||||
FROM rocm/pytorch:latest
|
||||
|
||||
# Set environment variables for ROCm and Python
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONPATH=/app
|
||||
ENV HF_HOME=/app/.cache/huggingface
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy project files
|
||||
COPY . .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install -e .
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Set default environment variables
|
||||
ENV WHISPERX_MODEL=turbo
|
||||
ENV WHISPERX_DEVICE=cuda
|
||||
ENV WHISPERX_COMPUTE_TYPE=float16
|
||||
|
||||
# Start the server
|
||||
CMD ["python", "-m", "whisperx.api.serve"]
|
||||
40
README.md
40
README.md
@@ -111,6 +111,46 @@ uv sync --all-extras --dev
|
||||
|
||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||
|
||||
## Docker Deployment 🐳
|
||||
|
||||
For easy deployment with GPU support, use Docker Compose:
|
||||
|
||||
### Prerequisites
|
||||
- Docker and Docker Compose installed
|
||||
- ROCm compatible GPU (AMD) or NVIDIA GPU with CUDA
|
||||
- For AMD ROCm, ensure ROCm drivers are installed on host
|
||||
|
||||
### Steps
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/m-bain/whisperX.git
|
||||
cd whisperX
|
||||
```
|
||||
|
||||
2. Build and run the container:
|
||||
```bash
|
||||
docker-compose up --build
|
||||
```
|
||||
|
||||
The API will be available at `http://localhost:8000`
|
||||
|
||||
### Environment Variables
|
||||
- `WHISPERX_MODEL`: Model size (default: large-v2)
|
||||
- `WHISPERX_DEVICE`: cuda or cpu (default: cuda)
|
||||
- `WHISPERX_COMPUTE_TYPE`: float16 or float32 (default: float16)
|
||||
|
||||
### API Usage
|
||||
The API is compatible with OpenAI's transcription endpoint:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/audio/transcriptions \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F "file=@audio.wav" \
|
||||
-F "model=whisper-1" \
|
||||
-F "language=en"
|
||||
```
|
||||
|
||||
### Speaker Diarization
|
||||
|
||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||
|
||||
28
docker-compose.yml
Normal file
28
docker-compose.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
whisperx-api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- WHISPERX_MODEL=turbo
|
||||
- WHISPERX_DEVICE=cuda
|
||||
- WHISPERX_COMPUTE_TYPE=float16
|
||||
volumes:
|
||||
# Mount Hugging Face cache if needed
|
||||
- hf_cache:/app/.cache/huggingface
|
||||
devices:
|
||||
# Allow access to all GPUs
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
cap_add:
|
||||
- SYS_ADMIN
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
# For AMD ROCm GPUs, use device passthrough
|
||||
|
||||
volumes:
|
||||
hf_cache:
|
||||
@@ -22,12 +22,16 @@ dependencies = [
|
||||
"torch~=2.8.0",
|
||||
"torchaudio~=2.8.0",
|
||||
"transformers>=4.48.0",
|
||||
"triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'" # only install triton on x86_64 Linux
|
||||
"triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'", # only install triton on x86_64 Linux
|
||||
"fastapi>=0.104.0",
|
||||
"uvicorn[standard]>=0.24.0",
|
||||
"python-multipart>=0.0.6",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
whisperx = "whisperx.__main__:cli"
|
||||
whisperx-serve = "whisperx.api.serve:serve"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
|
||||
0
whisperx/api/__init__.py
Normal file
0
whisperx/api/__init__.py
Normal file
86
whisperx/api/main.py
Normal file
86
whisperx/api/main.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import tempfile
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import torch
|
||||
import whisperx
|
||||
from whisperx.schema import TranscriptionResult
|
||||
|
||||
|
||||
model = None
|
||||
align_model_metadata = None
|
||||
|
||||
|
||||
def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"):
|
||||
global model, align_model_metadata
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
|
||||
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
||||
# For alignment, load the metadata
|
||||
align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
|
||||
print("Model loaded and ready.")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Load the model at startup
|
||||
model_name = os.getenv("WHISPERX_MODEL", "turbo")
|
||||
device = os.getenv("WHISPERX_DEVICE", "cuda")
|
||||
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")
|
||||
load_transcription_model(model_name, device, compute_type)
|
||||
yield
|
||||
# Cleanup if needed
|
||||
print("Shutting down API")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="WhisperX API",
|
||||
description="OpenAI-compatible API for speech transcription using WhisperX",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "WhisperX API is running"}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
model_name: str = Form("whisper-1"), # OpenAI uses 'whisper-1', we ignore this
|
||||
language: str = Form(None),
|
||||
response_format: str = Form("json"),
|
||||
temperature: float = Form(0.0), # We don't use temperature for now
|
||||
prompt: str = Form(None) # Not used
|
||||
):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')):
|
||||
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||
|
||||
# Save uploaded file to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
||||
temp_file.write(await file.read())
|
||||
audio_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Load audio
|
||||
audio = whisperx.load_audio(audio_path)
|
||||
|
||||
# Transcribe
|
||||
result = model(audio, batch_size=16, language=language)
|
||||
text = " ".join([segment['text'] for segment in result["segments"]]).strip()
|
||||
|
||||
# If we have segments, might want to return more info, but for OpenAI compatibility, just text
|
||||
|
||||
return JSONResponse({"text": text})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||
finally:
|
||||
os.unlink(audio_path)
|
||||
16
whisperx/api/serve.py
Normal file
16
whisperx/api/serve.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import uvicorn
|
||||
|
||||
|
||||
def serve(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
|
||||
"""Run the WhisperX API server"""
|
||||
uvicorn.run(
|
||||
"whisperx.api.main:app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
reload=False # No reload for production
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
Reference in New Issue
Block a user