Fix CUDA/ROCm compatibility: default to CPU for AMD GPUs

- Change the default device to CPU in docker-compose; main.py no longer hard-codes "cuda" as its fallback
- Set compute_type=int8 for CPU inference (docker-compose and the main.py default); float16 remains the default for non-CPU devices
- Auto-detect the device (CUDA if available, otherwise CPU) when the WHISPERX_DEVICE environment variable is not set
This commit is contained in:
2026-05-13 03:50:58 +03:00
parent fc87497df9
commit 16bdf2bd00
3 changed files with 11 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
# Use ROCm PyTorch base image with compatible PyTorch # Use ROCm PyTorch base image with compatible PyTorch
FROM rocm/pytorch:2.4.0-rocm6.0-cxx11-ubuntu22.04 FROM rocm/pytorch:rocm7.2.3_ubuntu22.04_py3.10_pytorch_release_2.10.0
# Set environment variables for ROCm and Python # Set environment variables for ROCm and Python
ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONDONTWRITEBYTECODE=1

View File

@@ -9,8 +9,8 @@ services:
- "8000:8000" - "8000:8000"
environment: environment:
- WHISPERX_MODEL=large-v2 - WHISPERX_MODEL=large-v2
- WHISPERX_DEVICE=cuda - WHISPERX_DEVICE=cpu
- WHISPERX_COMPUTE_TYPE=float16 - WHISPERX_COMPUTE_TYPE=int8
- HF_HUB_DISABLE_TELEMETRY=1 - HF_HUB_DISABLE_TELEMETRY=1
volumes: volumes:
# Mount Hugging Face cache if needed # Mount Hugging Face cache if needed

View File

@@ -6,34 +6,29 @@ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import torch import torch
import whisperx import whisperx
from whisperx.schema import TranscriptionResult
model = None model = None
align_model_metadata = None
def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"): def load_transcription_model(model_name: str = "large-v2", device: str = None, compute_type: str = None):
global model, align_model_metadata global model
if device is None: if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
if compute_type is None:
compute_type = "int8" if device == "cpu" else "float16"
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}") print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
model = whisperx.load_model(model_name, device, compute_type=compute_type) model = whisperx.load_model(model_name, device, compute_type=compute_type)
# For alignment, load the metadata
align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
print("Model loaded and ready.") print("Model loaded and ready.")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Load the model at startup model_name = os.getenv("WHISPERX_MODEL", "large-v2")
model_name = os.getenv("WHISPERX_MODEL", "turbo") device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
device = os.getenv("WHISPERX_DEVICE", "cuda") compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "int8" if device == "cpu" else "float16")
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")
load_transcription_model(model_name, device, compute_type) load_transcription_model(model_name, device, compute_type)
yield yield
# Cleanup if needed
print("Shutting down API")
app = FastAPI( app = FastAPI(