diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8801910 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,27 @@ +.git +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +.pytest_cache +.coverage +htmlcov +.env +.venv +venv/ +ENV/ +env/ + +# Docker +Dockerfile* +docker-compose*.yml +.dockerignore + +# IDE +.vscode +.idea + +# OS +.DS_Store +*.log \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6c41053 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,36 @@ +# Use ROCm PyTorch base image +FROM rocm/pytorch:latest + +# Set environment variables for ROCm and Python +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH=/app +ENV HF_HOME=/app/.cache/huggingface + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + ffmpeg \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /app + +# Copy project files +COPY . . + +# Install Python dependencies +RUN pip install --upgrade pip +RUN pip install -e . + +# Expose port +EXPOSE 8000 + +# Set default environment variables +ENV WHISPERX_MODEL=turbo +ENV WHISPERX_DEVICE=cuda +ENV WHISPERX_COMPUTE_TYPE=float16 + +# Start the server +CMD ["python", "-m", "whisperx.api.serve"] \ No newline at end of file diff --git a/README.md b/README.md index fae1e07..966ebbd 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,46 @@ uv sync --all-extras --dev You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup. +## Docker Deployment 🐳 + +For easy deployment with GPU support, use Docker Compose: + +### Prerequisites +- Docker and Docker Compose installed +- ROCm compatible GPU (AMD) or NVIDIA GPU with CUDA +- For AMD ROCm, ensure ROCm drivers are installed on host + +### Steps + +1. Clone the repository: +```bash +git clone https://github.com/m-bain/whisperX.git +cd whisperX +``` + +2. Build and run the container: +```bash +docker-compose up --build +``` + +The API will be available at `http://localhost:8000` + +### Environment Variables +- `WHISPERX_MODEL`: Model size (default: large-v2) +- `WHISPERX_DEVICE`: cuda or cpu (default: cuda) +- `WHISPERX_COMPUTE_TYPE`: float16 or float32 (default: float16) + +### API Usage +The API is compatible with OpenAI's transcription endpoint: + +```bash +curl -X POST http://localhost:8000/v1/audio/transcriptions \ + -H "Content-Type: multipart/form-data" \ + -F "file=@audio.wav" \ + -F "model=whisper-1" \ + -F "language=en" +``` + ### Speaker Diarization To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..a2aabc4 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +version: '3.8' + +services: + whisperx-api: + build: + context: . + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + - WHISPERX_MODEL=turbo + - WHISPERX_DEVICE=cuda + - WHISPERX_COMPUTE_TYPE=float16 + volumes: + # Mount Hugging Face cache if needed + - hf_cache:/app/.cache/huggingface + devices: + # Allow access to all GPUs + - /dev/kfd:/dev/kfd + - /dev/dri:/dev/dri + cap_add: + - SYS_ADMIN + security_opt: + - seccomp:unconfined + # For AMD ROCm GPUs, use device passthrough + +volumes: + hf_cache: \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 926d697..324f4f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,16 @@ dependencies = [ "torch~=2.8.0", "torchaudio~=2.8.0", "transformers>=4.48.0", - "triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'" # only install triton on x86_64 Linux + "triton>=3.3.0; sys_platform == 'linux' and platform_machine == 'x86_64'", # only install triton on x86_64 Linux + "fastapi>=0.104.0", + "uvicorn[standard]>=0.24.0", + "python-multipart>=0.0.6", ] [project.scripts] whisperx = "whisperx.__main__:cli" +whisperx-serve = "whisperx.api.serve:serve" [build-system] requires = ["setuptools"] diff --git a/whisperx/api/__init__.py b/whisperx/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/whisperx/api/main.py b/whisperx/api/main.py new file mode 100644 index 0000000..46da119 --- /dev/null +++ b/whisperx/api/main.py @@ -0,0 +1,86 @@ +import os +import tempfile +import asyncio +from contextlib import asynccontextmanager +from fastapi import FastAPI, UploadFile, File, Form, HTTPException +from fastapi.responses import JSONResponse +import torch +import whisperx +from whisperx.schema import TranscriptionResult + + +model = None +align_model_metadata = None + + +def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"): + global model, align_model_metadata + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}") + model = whisperx.load_model(model_name, device, compute_type=compute_type) + # For alignment, load the metadata + align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF + print("Model loaded and ready.") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Load the model at startup + model_name = os.getenv("WHISPERX_MODEL", "turbo") + device = os.getenv("WHISPERX_DEVICE", "cuda") + compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16") + load_transcription_model(model_name, device, compute_type) + yield + # Cleanup if needed + print("Shutting down API") + + +app = FastAPI( + title="WhisperX API", + description="OpenAI-compatible API for speech transcription using WhisperX", + version="1.0.0", + lifespan=lifespan +) + + +@app.get("/") +async def root(): + return {"message": "WhisperX API is running"} + + +@app.post("/v1/audio/transcriptions") +async def transcribe_audio( + file: UploadFile = File(...), + model_name: str = Form("whisper-1"), # OpenAI uses 'whisper-1', we ignore this + language: str = Form(None), + response_format: str = Form("json"), + temperature: float = Form(0.0), # We don't use temperature for now + prompt: str = Form(None) # Not used +): + if model is None: + raise HTTPException(status_code=500, detail="Model not loaded") + + if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.webm', '.mp4', '.mpga', '.ogg', '.opus')): + raise HTTPException(status_code=400, detail="Unsupported audio format") + + # Save uploaded file to temp file + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file: + temp_file.write(await file.read()) + audio_path = temp_file.name + + try: + # Load audio + audio = whisperx.load_audio(audio_path) + + # Transcribe + result = model(audio, batch_size=16, language=language) + text = " ".join([segment['text'] for segment in result["segments"]]).strip() + + # If we have segments, might want to return more info, but for OpenAI compatibility, just text + + return JSONResponse({"text": text}) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") + finally: + os.unlink(audio_path) \ No newline at end of file diff --git a/whisperx/api/serve.py b/whisperx/api/serve.py new file mode 100644 index 0000000..86ae029 --- /dev/null +++ b/whisperx/api/serve.py @@ -0,0 +1,16 @@ +import uvicorn + + +def serve(host: str = "0.0.0.0", port: int = 8000, workers: int = 1): + """Run the WhisperX API server""" + uvicorn.run( + "whisperx.api.main:app", + host=host, + port=port, + workers=workers, + reload=False # No reload for production + ) + + +if __name__ == "__main__": + serve() \ No newline at end of file