This commit is contained in:
2025-07-11 19:17:33 +03:00
parent 303b7b7584
commit e5fd44e3c3
6 changed files with 221 additions and 253 deletions

28
Dockerfile Normal file
View File

@@ -0,0 +1,28 @@
# Используем образ ROCm с предустановленным PyTorch
FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
# Устанавливаем рабочую директорию в контейнере
WORKDIR /app
# Устанавливаем системные зависимости
RUN apt-get update && apt-get install -y \
ffmpeg \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
# Устанавливаем зависимости Python
COPY requirements.txt .
RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt
# Копируем остальные файлы приложения
COPY . .
# Открываем порт, на котором будет работать приложение
EXPOSE 9854
# Устанавливаем переменные окружения для ROCm
ENV HSA_OVERRIDE_GFX_VERSION=10.3.0
ENV PYTORCH_ROCM_ARCH=gfx1030
# Команда для запуска приложения
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "9854", "--log-level", "debug"]

192
app.py Normal file
View File

@@ -0,0 +1,192 @@
import logging
import os
import subprocess
import time
from typing import Dict
from typing import Optional, Union, List, Tuple
import whisper
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
from fastapi.security import APIKeyHeader
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
app = FastAPI()
# API key header
api_key_header = APIKeyHeader(name="x-api-key")
def get_keys(): # не бейте меня за это
keys_file = "keys.txt"
if not os.path.exists(keys_file):
# Create a new keys file with a default key
default_key = os.urandom(32).hex()
with open(keys_file, "w") as f:
f.write(default_key + "\n")
logger.info(f"Created new keys file with default key: {default_key}")
return [default_key]
else:
# Read keys from the existing file
with open(keys_file, "r") as f:
keys = [line.strip() for line in f if line.strip()]
logger.info(f"Loaded {len(keys)} keys from file")
logger.debug(f"Keys: {keys}")
if not keys:
raise ValueError("No keys found in keys.txt")
return keys
def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
"""
Convert audio to compatible format and speed up
"""
try:
command = [
'ffmpeg', '-i', input_path,
'-filter:a', f'atempo={speed}',
'-ar', '16000',
'-ac', '1',
'-c:a', 'pcm_s16le',
output_path,
'-y'
]
logger.debug(f"Running FFmpeg command: {' '.join(command)}")
subprocess.run(command, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
return False
class TranscriptionMetrics:
def __init__(self):
self.start_time = time.time()
self.end_time = None
self.text_length = 0
self.audio_duration = 0
def stop(self, text: str, audio_duration: float):
self.end_time = time.time()
self.text_length = len(text)
self.audio_duration = audio_duration
def get_metrics(self) -> Dict[str, float]:
processing_time = self.end_time - self.start_time
return {
"processing_time_seconds": round(processing_time, 2),
"characters_per_second": round(self.text_length / processing_time, 2),
"audio_realtime_ratio": round(self.audio_duration / processing_time, 2),
"audio_duration": round(self.audio_duration, 2),
"text_length": self.text_length
}
def get_audio_duration(file_path: str) -> float:
"""Get audio duration using ffprobe"""
cmd = [
'ffprobe',
'-v', 'quiet',
'-show_entries', 'format=duration',
'-of', 'default=noprint_wrappers=1:nokey=1',
file_path
]
try:
output = subprocess.check_output(cmd).decode().strip()
return float(output)
except:
return 0.0
@app.post("/transcribe")
async def transcribe_audio(
file: UploadFile = File(...),
token: str = Depends(api_key_header),
model_name: str = "medium",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'\"¿([{-",
append_punctuations: str = "\"\'.。,!?:\")]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None
):
# Token validation
if token not in get_keys():
logger.warning(f"Invalid token attempt: {token}")
raise HTTPException(status_code=403, detail="Forbidden")
logger.info(f"Processing file: {file.filename} with model: {model_name}")
metrics = TranscriptionMetrics()
# Save uploaded file
temp_input_path = f"/tmp/input_{file.filename}"
temp_output_path = f"/tmp/converted_{file.filename}.wav"
try:
with open(temp_input_path, "wb") as f:
f.write(await file.read())
# Convert audio if needed
logger.debug("Converting audio file")
if not convert_audio(temp_input_path, temp_output_path):
raise HTTPException(status_code=400, detail="Audio conversion failed")
# Get audio duration before speed up
original_duration = get_audio_duration(temp_input_path)
# Load model
logger.debug(f"Loading model: {model_name}")
model = whisper.load_model(model_name, device="cuda")
# Transcribe
logger.info("Starting transcription")
result = model.transcribe(
temp_output_path,
verbose=verbose,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
logprob_threshold=logprob_threshold,
no_speech_threshold=no_speech_threshold,
condition_on_previous_text=condition_on_previous_text,
initial_prompt=initial_prompt,
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold
)
# Calculate metrics
metrics.stop(result["text"], original_duration)
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
# Add metrics to result
result["metrics"] = metrics.get_metrics()
return result
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
# Cleanup temporary files
if os.path.exists(temp_input_path):
os.remove(temp_input_path)
if os.path.exists(temp_output_path):
os.remove(temp_output_path)
def main():
import uvicorn
get_keys()
uvicorn.run(app, host="0.0.0.0")
if __name__ == "__main__":
main()

View File

@@ -1,9 +0,0 @@
server:
host: "0.0.0.0"
port: 8000
ui: true
whisper:
model_name: "turbo"
device: "cuda"
compute_type: "int8"

View File

@@ -1,45 +0,0 @@
import ffmpeg
import os
import tempfile
import shutil
def is_valid_format(file_path: str) -> bool:
"""Проверяет, является ли аудиофайл 16kHz моно WAV."""
try:
probe = ffmpeg.probe(file_path)
audio_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
if audio_stream is None:
return False
return (
audio_stream.get('codec_name') == 'pcm_s16le' and
audio_stream.get('channels') == 1 and
audio_stream.get('sample_rate') == '16000'
)
except ffmpeg.Error:
return False
def convert_to_wav(input_file_path: str) -> tuple[str, bool]:
"""
Конвертирует аудиофайл в 16kHz моно WAV.
Возвращает путь к сконвертированному файлу и флаг, указывающий, была ли выполнена конвертация.
Если файл уже в нужном формате, возвращает исходный путь и False.
"""
if is_valid_format(input_file_path):
return input_file_path, False
output_file_path = tempfile.mktemp(suffix=".wav")
try:
ffmpeg.input(input_file_path).output(
output_file_path,
acodec='pcm_s16le',
ac=1,
ar='16k'
).run(capture_stdout=True, capture_stderr=True)
return output_file_path, True
except ffmpeg.Error as e:
if os.path.exists(output_file_path):
os.remove(output_file_path)
raise e

198
main.py
View File

@@ -1,198 +0,0 @@
import os
import tempfile
import sys
import yaml
from typing import Optional, List, Union, Tuple, Iterable
from fastapi import FastAPI, UploadFile, File, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse
from faster_whisper import WhisperModel
from converter import convert_to_wav
with open("config.yaml", 'r') as f:
config = yaml.safe_load(f)
app = FastAPI()
w_config = config['whisper']
class TranscriptionOptions(BaseModel):
language: Optional[str] = w_config.get('language')
task: str = w_config.get('task', 'transcribe')
beam_size: int = w_config.get('beam_size', 5)
best_of: int = w_config.get('best_of', 5)
patience: float = w_config.get('patience', 1.0)
length_penalty: float = w_config.get('length_penalty', 1.0)
repetition_penalty: float = w_config.get('repetition_penalty', 1.0)
no_repeat_ngram_size: int = w_config.get('no_repeat_ngram_size', 0)
temperature: Union[float, List[float], Tuple[float, ...]] = w_config.get('temperature', [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
log_progress: bool = w_config.get('log_progress', False)
compression_ratio_threshold: Optional[float] = w_config.get('compression_ratio_threshold', 2.4)
log_prob_threshold: Optional[float] = w_config.get('log_prob_threshold', -1.0)
no_speech_threshold: Optional[float] = w_config.get('no_speech_threshold', 0.6)
condition_on_previous_text: bool = w_config.get('condition_on_previous_text', True)
prompt_reset_on_temperature: float = w_config.get('prompt_reset_on_temperature', 0.5)
initial_prompt: Optional[Union[str, Iterable[int]]] = w_config.get('initial_prompt')
prefix: Optional[str] = w_config.get('prefix')
suppress_blank: bool = w_config.get('suppress_blank', True)
suppress_tokens: Optional[List[int]] = w_config.get('suppress_tokens', [-1])
without_timestamps: bool = w_config.get('without_timestamps', False)
max_initial_timestamp: float = w_config.get('max_initial_timestamp', 1.0)
word_timestamps: bool = w_config.get('word_timestamps', False)
prepend_punctuations: str = w_config.get('prepend_punctuations', '"\'“¿([{-')
append_punctuations: str = w_config.get('append_punctuations', '"\'.。,!?::”)]}、')
vad_filter: bool = w_config.get('vad_filter', False)
vad_parameters: Optional[dict] = w_config.get('vad_parameters')
max_new_tokens: Optional[int] = w_config.get('max_new_tokens')
chunk_length: Optional[int] = w_config.get('chunk_length')
clip_timestamps: Union[str, List[float]] = w_config.get('clip_timestamps', "0")
hallucination_silence_threshold: Optional[float] = w_config.get('hallucination_silence_threshold')
hotwords: Optional[str] = w_config.get('hotwords')
language_detection_threshold: Optional[float] = w_config.get('language_detection_threshold')
language_detection_segments: int = w_config.get('language_detection_segments', 1)
class WhisperTranscriber:
def __init__(self, model_name, device, compute_type):
self.model = WhisperModel(model_name, device=device, compute_type=compute_type)
def transcribe(self, audio_file_path: str, options: dict) -> str:
segments, _ = self.model.transcribe(audio_file_path, **options)
transcription = " ".join([segment.text for segment in segments])
return transcription
transcriber = WhisperTranscriber(
model_name=w_config['model_name'],
device=w_config['device'],
compute_type=w_config['compute_type']
)
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...), options: TranscriptionOptions = Depends()):
temp_audio_file_path = None
converted_file_path = None
was_converted = False
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio_file:
temp_audio_file.write(await file.read())
temp_audio_file_path = temp_audio_file.name
converted_file_path, was_converted = convert_to_wav(temp_audio_file_path)
transcription = transcriber.transcribe(converted_file_path, options.dict(exclude_none=True))
return {"transcription": transcription}
finally:
if temp_audio_file_path and os.path.exists(temp_audio_file_path):
os.remove(temp_audio_file_path)
if was_converted and converted_file_path and os.path.exists(converted_file_path):
os.remove(converted_file_path)
def create_ui():
return '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Whisper Transcription</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body {
background-color: #f8f9fa;
}
.container {
max-width: 700px;
}
#transcriptionOutput {
white-space: pre-wrap;
word-wrap: break-word;
}
</style>
</head>
<body>
<div class="container mt-5">
<div class="card">
<div class="card-body">
<h1 class="card-title text-center mb-4">Upload Audio for Transcription</h1>
<div class="mb-3">
<input class="form-control" type="file" id="audioFile" accept="audio/*">
</div>
<div class="d-grid">
<button class="btn btn-primary" onclick="transcribeAudio()">
<span class="spinner-border spinner-border-sm d-none" role="status" aria-hidden="true" id="spinner"></span>
Transcribe
</button>
</div>
<h2 class="mt-4">Transcription:</h2>
<div class="p-3 bg-light rounded">
<pre id="transcriptionOutput"></pre>
</div>
</div>
</div>
</div>
<script>
async function transcribeAudio() {
const fileInput = document.getElementById('audioFile');
const file = fileInput.files[0];
if (!file) {
alert("Please select a file first.");
return;
}
const formData = new FormData();
formData.append('file', file);
const outputElement = document.getElementById('transcriptionOutput');
const spinner = document.getElementById('spinner');
const transcribeButton = document.querySelector('button');
outputElement.innerText = '';
spinner.classList.remove('d-none');
transcribeButton.disabled = true;
try {
const response = await fetch('/transcribe', {
method: 'POST',
body: formData
});
if (response.ok) {
const result = await response.json();
if (result.transcription) {
outputElement.innerText = result.transcription;
} else if (result.error) {
outputElement.innerText = 'Error: ' + result.error;
}
} else {
const errorText = await response.text();
outputElement.innerText = 'Error: ' + response.statusText + ' - ' + errorText;
}
} catch (error) {
outputElement.innerText = 'An error occurred: ' + error;
} finally {
spinner.classList.add('d-none');
transcribeButton.disabled = false;
}
}
</script>
</body>
</html>
'''
if __name__ == "__main__":
import uvicorn
s_config = config['server']
if s_config['ui'] or "--ui" in sys.argv:
@app.get("/", response_class=HTMLResponse)
async def read_root():
return create_ui()
uvicorn.run(
app,
host=s_config['host'],
port=s_config['port']
)

View File

@@ -1,6 +1,6 @@
fastapi fastapi
uvicorn[standard] uvicorn[standard]
python-multipart python-multipart
faster-whisper openai-whisper
ffmpeg-python ffmpeg-python
PyYAML PyYAML