Files
rocm-whisper-webui/api/main.py
2026-02-25 23:48:01 +03:00

149 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Whisper API - FastAPI сервис для транскрибации аудио
Поддержка ROCm (AMD GPU)
Эндпоинт: POST /transcribe
"""
import os
import tempfile
import whisper
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
import torch
app = FastAPI(title="Whisper Transcription API")
# Загрузка модели при старте
MODEL_SIZE = os.getenv("MODEL_SIZE", "base")
DEVICE_ENV = os.getenv("DEVICE", "cuda")
# Определение устройства для инференса
# PyTorch на ROCm определяет AMD GPU как CUDA-совместимое устройство
def get_device():
"""Определение доступного устройства для инференса"""
if DEVICE_ENV == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(0)
print(f"Используем GPU: {device_name}")
return "cuda"
else:
print("Используем CPU")
return "cpu"
DEVICE = get_device()
print(f"Загрузка модели Whisper {MODEL_SIZE} на устройство: {DEVICE}")
# Информация о GPU
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"GPU: {gpu_name}, память: {gpu_memory:.1f} GB")
try:
model = whisper.load_model(MODEL_SIZE, device=DEVICE)
print("Модель успешно загружена")
except Exception as e:
print(f"Ошибка загрузки модели: {e}")
model = None
@app.get("/health")
async def health_check():
"""Проверка здоровья сервиса"""
gpu_info = None
if torch.cuda.is_available():
gpu_info = {
"name": torch.cuda.get_device_name(0),
"available": True
}
return {
"status": "healthy",
"model_loaded": model is not None,
"device": DEVICE,
"gpu": gpu_info
}
@app.get("/")
async def root():
"""Корневой эндпоинт с информацией об API"""
return {
"service": "Whisper API (ROCm/AMD GPU)",
"model": MODEL_SIZE,
"device": DEVICE,
"gpu_available": torch.cuda.is_available(),
"endpoints": {
"transcribe": "POST /transcribe",
"health": "GET /health"
}
}
@app.post("/transcribe")
async def transcribe(
file: UploadFile = File(...),
task: str = Form("transcribe"),
language: str = Form(None)
):
"""
Транскрибация или перевод аудио файла
Параметры:
- file: аудио/видео файл
- task: "transcribe" или "translate"
- language: код языка (опционально, например "ru", "en")
Пример curl:
curl -X POST -F "file=@test.m4a" -F "task=transcribe" http://localhost:8080/transcribe
"""
if model is None:
raise HTTPException(status_code=503, detail="Модель Whisper не загружена")
# Проверка типа задачи
if task not in ["transcribe", "translate"]:
raise HTTPException(status_code=400, detail="Недопустимое значение task. Используйте 'transcribe' или 'translate'")
# Сохранение загруженного файла во временный файл
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp_input:
content = await file.read()
tmp_input.write(content)
tmp_input_path = tmp_input.name
try:
# Параметры транскрибации
transcribe_options = {
"task": task,
"verbose": False,
}
# Добавляем язык если указан
if language:
transcribe_options["language"] = language
# Выполнение транскрибации
result = model.transcribe(tmp_input_path, **transcribe_options)
# Возвращаем результат
return JSONResponse(content={
"text": result["text"],
"segments": result.get("segments", []),
"language": result.get("language", language),
"task": task,
"filename": file.filename,
"device_used": DEVICE
})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Ошибка при транскрибации: {str(e)}")
finally:
# Удаление временного файла
try:
os.unlink(tmp_input_path)
except:
pass
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8080)