Add OpenAI-compatible API and Docker deployment
- Add FastAPI-based API in whisperx/api/ - Implement transcription endpoint compatible with OpenAI - Added Dockerfile and docker-compose.yml for easy deployment - Updated README with Docker instructions - Added new script whisperx-serve for running the API
This commit is contained in:
27
.dockerignore
Normal file
27
.dockerignore
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.git
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
.pytest_cache
|
||||||
|
.coverage
|
||||||
|
htmlcov
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
Dockerfile*
|
||||||
|
docker-compose*.yml
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode
|
||||||
|
.idea
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
*.log
|
||||||
36
Dockerfile
Normal file
36
Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Use ROCm PyTorch base image
|
||||||
|
FROM rocm/pytorch:latest
|
||||||
|
|
||||||
|
# Set environment variables for ROCm and Python
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PYTHONPATH=/app
|
||||||
|
ENV HF_HOME=/app/.cache/huggingface
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
ffmpeg \
|
||||||
|
libsndfile1 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy project files
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
RUN pip install --upgrade pip
|
||||||
|
RUN pip install -e .
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Set default environment variables
|
||||||
|
ENV WHISPERX_MODEL=turbo
|
||||||
|
ENV WHISPERX_DEVICE=cuda
|
||||||
|
ENV WHISPERX_COMPUTE_TYPE=float16
|
||||||
|
|
||||||
|
# Start the server
|
||||||
|
CMD ["python", "-m", "whisperx.api.serve"]
|
||||||
40
README.md
40
README.md
@@ -111,6 +111,46 @@ uv sync --all-extras --dev
|
|||||||
|
|
||||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||||
|
|
||||||
|
## Docker Deployment 🐳
|
||||||
|
|
||||||
|
For easy deployment with GPU support, use Docker Compose:
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- Docker and Docker Compose installed
|
||||||
|
- ROCm compatible GPU (AMD) or NVIDIA GPU with CUDA
|
||||||
|
- For AMD ROCm, ensure ROCm drivers are installed on host
|
||||||
|
|
||||||
|
### Steps
|
||||||
|
|
||||||
|
1. Clone the repository:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/m-bain/whisperX.git
|
||||||
|
cd whisperX
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Build and run the container:
|
||||||
|
```bash
|
||||||
|
docker-compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
The API will be available at `http://localhost:8000`
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
- `WHISPERX_MODEL`: Model size (default: large-v2)
|
||||||
|
- `WHISPERX_DEVICE`: cuda or cpu (default: cuda)
|
||||||
|
- `WHISPERX_COMPUTE_TYPE`: float16 or float32 (default: float16)
|
||||||
|
|
||||||
|
### API Usage
|
||||||
|
The API is compatible with OpenAI's transcription endpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/v1/audio/transcriptions \
|
||||||
|
-H "Content-Type: multipart/form-data" \
|
||||||
|
-F "file=@audio.wav" \
|
||||||
|
-F "model=whisper-1" \
|
||||||
|
-F "language=en"
|
||||||
|
```
|
||||||
|
|
||||||
### Speaker Diarization
|
### Speaker Diarization
|
||||||
|
|
||||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||||
|
|||||||
28
docker-compose.yml
Normal file
28
docker-compose.yml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
whisperx-api:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
environment:
|
||||||
|
- WHISPERX_MODEL=turbo
|
||||||
|
- WHISPERX_DEVICE=cuda
|
||||||
|
- WHISPERX_COMPUTE_TYPE=float16
|
||||||
|
volumes:
|
||||||
|
# Mount Hugging Face cache if needed
|
||||||
|
- hf_cache:/app/.cache/huggingface
|
||||||
|
devices:
|
||||||
|
# Allow access to all GPUs
|
||||||
|
- /dev/kfd:/dev/kfd
|
||||||
|
- /dev/dri:/dev/dri
|
||||||
|
cap_add:
|
||||||
|
- SYS_ADMIN
|
||||||
|
security_opt:
|
||||||
|
- seccomp:unconfined
|
||||||
|
# For AMD ROCm GPUs, use device passthrough
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
hf_cache:
|
||||||
@@ -22,12 +22,16 @@ dependencies = [
|
|||||||
"torch~=2.8.0",
|
"torch~=2.8.0",
|
||||||
"torchaudio~=2.8.0",
|
"torchaudio~=2.8.0",
|
||||||
"transformers>=4.48.0",
|
"transformers>=4.48.0",
|
||||||
"triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'" # only install triton on x86_64 Linux
|
"triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'", # only install triton on x86_64 Linux
|
||||||
|
"fastapi>=0.104.0",
|
||||||
|
"uvicorn[standard]>=0.24.0",
|
||||||
|
"python-multipart>=0.0.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
whisperx = "whisperx.__main__:cli"
|
whisperx = "whisperx.__main__:cli"
|
||||||
|
whisperx-serve = "whisperx.api.serve:serve"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools"]
|
requires = ["setuptools"]
|
||||||
|
|||||||
0
whisperx/api/__init__.py
Normal file
0
whisperx/api/__init__.py
Normal file
86
whisperx/api/main.py
Normal file
86
whisperx/api/main.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import torch
|
||||||
|
import whisperx
|
||||||
|
from whisperx.schema import TranscriptionResult
|
||||||
|
|
||||||
|
|
||||||
|
model = None
|
||||||
|
align_model_metadata = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"):
|
||||||
|
global model, align_model_metadata
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
|
||||||
|
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
||||||
|
# For alignment, load the metadata
|
||||||
|
align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
|
||||||
|
print("Model loaded and ready.")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Load the model at startup
|
||||||
|
model_name = os.getenv("WHISPERX_MODEL", "turbo")
|
||||||
|
device = os.getenv("WHISPERX_DEVICE", "cuda")
|
||||||
|
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")
|
||||||
|
load_transcription_model(model_name, device, compute_type)
|
||||||
|
yield
|
||||||
|
# Cleanup if needed
|
||||||
|
print("Shutting down API")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="WhisperX API",
|
||||||
|
description="OpenAI-compatible API for speech transcription using WhisperX",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "WhisperX API is running"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/audio/transcriptions")
|
||||||
|
async def transcribe_audio(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
model_name: str = Form("whisper-1"), # OpenAI uses 'whisper-1', we ignore this
|
||||||
|
language: str = Form(None),
|
||||||
|
response_format: str = Form("json"),
|
||||||
|
temperature: float = Form(0.0), # We don't use temperature for now
|
||||||
|
prompt: str = Form(None) # Not used
|
||||||
|
):
|
||||||
|
if model is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||||
|
|
||||||
|
if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')):
|
||||||
|
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||||
|
|
||||||
|
# Save uploaded file to temp file
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
||||||
|
temp_file.write(await file.read())
|
||||||
|
audio_path = temp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load audio
|
||||||
|
audio = whisperx.load_audio(audio_path)
|
||||||
|
|
||||||
|
# Transcribe
|
||||||
|
result = model(audio, batch_size=16, language=language)
|
||||||
|
text = " ".join([segment['text'] for segment in result["segments"]]).strip()
|
||||||
|
|
||||||
|
# If we have segments, might want to return more info, but for OpenAI compatibility, just text
|
||||||
|
|
||||||
|
return JSONResponse({"text": text})
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||||
|
finally:
|
||||||
|
os.unlink(audio_path)
|
||||||
16
whisperx/api/serve.py
Normal file
16
whisperx/api/serve.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
|
def serve(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
|
||||||
|
"""Run the WhisperX API server"""
|
||||||
|
uvicorn.run(
|
||||||
|
"whisperx.api.main:app",
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
workers=workers,
|
||||||
|
reload=False # No reload for production
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
serve()
|
||||||
Reference in New Issue
Block a user