From 16bdf2bd003bb76d1e1ff8288e1ce039b84bff82 Mon Sep 17 00:00:00 2001 From: SlavaVlad Date: Wed, 13 May 2026 03:50:58 +0300 Subject: [PATCH] Fix CUDA/ROCm compatibility: default to CPU for AMD GPUs - Change default device to CPU in docker-compose and main.py - Set compute_type=int8 for CPU inference - Auto-detect device if env not set --- Dockerfile | 2 +- docker-compose.yml | 4 ++-- whisperx/api/main.py | 21 ++++++++------------- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/Dockerfile b/Dockerfile index 01c0c00..ccf636d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Use ROCm PyTorch base image with compatible PyTorch -FROM rocm/pytorch:2.4.0-rocm6.0-cxx11-ubuntu22.04 +FROM rocm/pytorch:rocm7.2.3_ubuntu22.04_py3.10_pytorch_release_2.10.0 # Set environment variables for ROCm and Python ENV PYTHONDONTWRITEBYTECODE=1 diff --git a/docker-compose.yml b/docker-compose.yml index 8d544b5..a32c5df 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,8 +9,8 @@ services: - "8000:8000" environment: - WHISPERX_MODEL=large-v2 - - WHISPERX_DEVICE=cuda - - WHISPERX_COMPUTE_TYPE=float16 + - WHISPERX_DEVICE=cpu + - WHISPERX_COMPUTE_TYPE=int8 - HF_HUB_DISABLE_TELEMETRY=1 volumes: # Mount Hugging Face cache if needed diff --git a/whisperx/api/main.py b/whisperx/api/main.py index 46da119..c44413b 100644 --- a/whisperx/api/main.py +++ b/whisperx/api/main.py @@ -6,34 +6,29 @@ from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import JSONResponse import torch import whisperx -from whisperx.schema import TranscriptionResult model = None -align_model_metadata = None -def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"): - global model, align_model_metadata +def load_transcription_model(model_name: str = "large-v2", device: str = None, compute_type: str = None): + global model if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") + if compute_type is None: + compute_type = "int8" if device == "cpu" else "float16" print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}") model = whisperx.load_model(model_name, device, compute_type=compute_type) - # For alignment, load the metadata - align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF print("Model loaded and ready.") @asynccontextmanager async def lifespan(app: FastAPI): - # Load the model at startup - model_name = os.getenv("WHISPERX_MODEL", "turbo") - device = os.getenv("WHISPERX_DEVICE", "cuda") - compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16") + model_name = os.getenv("WHISPERX_MODEL", "large-v2") + device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") + compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "int8" if device == "cpu" else "float16") load_transcription_model(model_name, device, compute_type) yield - # Cleanup if needed - print("Shutting down API") app = FastAPI(