Fix CUDA/ROCm compatibility: default to CPU for AMD GPUs
- Change default device to CPU in docker-compose and main.py - Set compute_type=int8 for CPU inference - Auto-detect device if env not set
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
# Use ROCm PyTorch base image with compatible PyTorch
|
# Use ROCm PyTorch base image with compatible PyTorch
|
||||||
FROM rocm/pytorch:2.4.0-rocm6.0-cxx11-ubuntu22.04
|
FROM rocm/pytorch:rocm7.2.3_ubuntu22.04_py3.10_pytorch_release_2.10.0
|
||||||
|
|
||||||
# Set environment variables for ROCm and Python
|
# Set environment variables for ROCm and Python
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ services:
|
|||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
environment:
|
environment:
|
||||||
- WHISPERX_MODEL=large-v2
|
- WHISPERX_MODEL=large-v2
|
||||||
- WHISPERX_DEVICE=cuda
|
- WHISPERX_DEVICE=cpu
|
||||||
- WHISPERX_COMPUTE_TYPE=float16
|
- WHISPERX_COMPUTE_TYPE=int8
|
||||||
- HF_HUB_DISABLE_TELEMETRY=1
|
- HF_HUB_DISABLE_TELEMETRY=1
|
||||||
volumes:
|
volumes:
|
||||||
# Mount Hugging Face cache if needed
|
# Mount Hugging Face cache if needed
|
||||||
|
|||||||
@@ -6,34 +6,29 @@ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
import torch
|
import torch
|
||||||
import whisperx
|
import whisperx
|
||||||
from whisperx.schema import TranscriptionResult
|
|
||||||
|
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
align_model_metadata = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_transcription_model(model_name: str = "turbo", device: str = None, compute_type: str = "float16"):
|
def load_transcription_model(model_name: str = "large-v2", device: str = None, compute_type: str = None):
|
||||||
global model, align_model_metadata
|
global model
|
||||||
if device is None:
|
if device is None:
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
if compute_type is None:
|
||||||
|
compute_type = "int8" if device == "cpu" else "float16"
|
||||||
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
|
print(f"Loading WhisperX model: {model_name} on {device} with {compute_type}")
|
||||||
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
model = whisperx.load_model(model_name, device, compute_type=compute_type)
|
||||||
# For alignment, load the metadata
|
|
||||||
align_model_metadata = whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
|
|
||||||
print("Model loaded and ready.")
|
print("Model loaded and ready.")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Load the model at startup
|
model_name = os.getenv("WHISPERX_MODEL", "large-v2")
|
||||||
model_name = os.getenv("WHISPERX_MODEL", "turbo")
|
device = os.getenv("WHISPERX_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
device = os.getenv("WHISPERX_DEVICE", "cuda")
|
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "int8" if device == "cpu" else "float16")
|
||||||
compute_type = os.getenv("WHISPERX_COMPUTE_TYPE", "float16")
|
|
||||||
load_transcription_model(model_name, device, compute_type)
|
load_transcription_model(model_name, device, compute_type)
|
||||||
yield
|
yield
|
||||||
# Cleanup if needed
|
|
||||||
print("Shutting down API")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
|||||||
Reference in New Issue
Block a user