149 lines
4.8 KiB
Python
149 lines
4.8 KiB
Python
"""
|
||
Whisper API - FastAPI сервис для транскрибации аудио
|
||
Поддержка ROCm (AMD GPU)
|
||
Эндпоинт: POST /transcribe
|
||
"""
|
||
import os
|
||
import tempfile
|
||
import whisper
|
||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
import torch
|
||
|
||
app = FastAPI(title="Whisper Transcription API")
|
||
|
||
# Загрузка модели при старте
|
||
MODEL_SIZE = os.getenv("MODEL_SIZE", "base")
|
||
DEVICE_ENV = os.getenv("DEVICE", "cuda")
|
||
|
||
# Определение устройства для инференса
|
||
# PyTorch на ROCm определяет AMD GPU как CUDA-совместимое устройство
|
||
def get_device():
|
||
"""Определение доступного устройства для инференса"""
|
||
if DEVICE_ENV == "cuda" and torch.cuda.is_available():
|
||
device_name = torch.cuda.get_device_name(0)
|
||
print(f"Используем GPU: {device_name}")
|
||
return "cuda"
|
||
else:
|
||
print("Используем CPU")
|
||
return "cpu"
|
||
|
||
DEVICE = get_device()
|
||
|
||
print(f"Загрузка модели Whisper {MODEL_SIZE} на устройство: {DEVICE}")
|
||
|
||
# Информация о GPU
|
||
if torch.cuda.is_available():
|
||
gpu_name = torch.cuda.get_device_name(0)
|
||
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||
print(f"GPU: {gpu_name}, память: {gpu_memory:.1f} GB")
|
||
|
||
try:
|
||
model = whisper.load_model(MODEL_SIZE, device=DEVICE)
|
||
print("Модель успешно загружена")
|
||
except Exception as e:
|
||
print(f"Ошибка загрузки модели: {e}")
|
||
model = None
|
||
|
||
|
||
@app.get("/health")
|
||
async def health_check():
|
||
"""Проверка здоровья сервиса"""
|
||
gpu_info = None
|
||
if torch.cuda.is_available():
|
||
gpu_info = {
|
||
"name": torch.cuda.get_device_name(0),
|
||
"available": True
|
||
}
|
||
return {
|
||
"status": "healthy",
|
||
"model_loaded": model is not None,
|
||
"device": DEVICE,
|
||
"gpu": gpu_info
|
||
}
|
||
|
||
|
||
@app.get("/")
|
||
async def root():
|
||
"""Корневой эндпоинт с информацией об API"""
|
||
return {
|
||
"service": "Whisper API (ROCm/AMD GPU)",
|
||
"model": MODEL_SIZE,
|
||
"device": DEVICE,
|
||
"gpu_available": torch.cuda.is_available(),
|
||
"endpoints": {
|
||
"transcribe": "POST /transcribe",
|
||
"health": "GET /health"
|
||
}
|
||
}
|
||
|
||
|
||
@app.post("/transcribe")
|
||
async def transcribe(
|
||
file: UploadFile = File(...),
|
||
task: str = Form("transcribe"),
|
||
language: str = Form(None)
|
||
):
|
||
"""
|
||
Транскрибация или перевод аудио файла
|
||
|
||
Параметры:
|
||
- file: аудио/видео файл
|
||
- task: "transcribe" или "translate"
|
||
- language: код языка (опционально, например "ru", "en")
|
||
|
||
Пример curl:
|
||
curl -X POST -F "file=@test.m4a" -F "task=transcribe" http://localhost:8080/transcribe
|
||
"""
|
||
if model is None:
|
||
raise HTTPException(status_code=503, detail="Модель Whisper не загружена")
|
||
|
||
# Проверка типа задачи
|
||
if task not in ["transcribe", "translate"]:
|
||
raise HTTPException(status_code=400, detail="Недопустимое значение task. Используйте 'transcribe' или 'translate'")
|
||
|
||
# Сохранение загруженного файла во временный файл
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp_input:
|
||
content = await file.read()
|
||
tmp_input.write(content)
|
||
tmp_input_path = tmp_input.name
|
||
|
||
try:
|
||
# Параметры транскрибации
|
||
transcribe_options = {
|
||
"task": task,
|
||
"verbose": False,
|
||
}
|
||
|
||
# Добавляем язык если указан
|
||
if language:
|
||
transcribe_options["language"] = language
|
||
|
||
# Выполнение транскрибации
|
||
result = model.transcribe(tmp_input_path, **transcribe_options)
|
||
|
||
# Возвращаем результат
|
||
return JSONResponse(content={
|
||
"text": result["text"],
|
||
"segments": result.get("segments", []),
|
||
"language": result.get("language", language),
|
||
"task": task,
|
||
"filename": file.filename,
|
||
"device_used": DEVICE
|
||
})
|
||
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"Ошибка при транскрибации: {str(e)}")
|
||
|
||
finally:
|
||
# Удаление временного файла
|
||
try:
|
||
os.unlink(tmp_input_path)
|
||
except:
|
||
pass
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=8080)
|