created
This commit is contained in:
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
# Используем образ ROCm с предустановленным PyTorch
|
||||
FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
||||
|
||||
# Устанавливаем рабочую директорию в контейнере
|
||||
WORKDIR /app
|
||||
|
||||
# Устанавливаем системные зависимости
|
||||
RUN apt-get update && apt-get install -y \
|
||||
ffmpeg \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Устанавливаем зависимости Python
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt
|
||||
|
||||
# Копируем остальные файлы приложения
|
||||
COPY . .
|
||||
|
||||
# Открываем порт, на котором будет работать приложение
|
||||
EXPOSE 9854
|
||||
|
||||
# Устанавливаем переменные окружения для ROCm
|
||||
ENV HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
ENV PYTORCH_ROCM_ARCH=gfx1030
|
||||
|
||||
# Команда для запуска приложения
|
||||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "9854", "--log-level", "debug"]
|
||||
192
app.py
Normal file
192
app.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Optional, Union, List, Tuple
|
||||
|
||||
import whisper
|
||||
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# API key header
|
||||
api_key_header = APIKeyHeader(name="x-api-key")
|
||||
|
||||
def get_keys(): # не бейте меня за это
|
||||
keys_file = "keys.txt"
|
||||
if not os.path.exists(keys_file):
|
||||
# Create a new keys file with a default key
|
||||
default_key = os.urandom(32).hex()
|
||||
with open(keys_file, "w") as f:
|
||||
f.write(default_key + "\n")
|
||||
logger.info(f"Created new keys file with default key: {default_key}")
|
||||
return [default_key]
|
||||
else:
|
||||
# Read keys from the existing file
|
||||
with open(keys_file, "r") as f:
|
||||
keys = [line.strip() for line in f if line.strip()]
|
||||
logger.info(f"Loaded {len(keys)} keys from file")
|
||||
logger.debug(f"Keys: {keys}")
|
||||
if not keys:
|
||||
raise ValueError("No keys found in keys.txt")
|
||||
return keys
|
||||
|
||||
def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
|
||||
"""
|
||||
Convert audio to compatible format and speed up
|
||||
"""
|
||||
try:
|
||||
command = [
|
||||
'ffmpeg', '-i', input_path,
|
||||
'-filter:a', f'atempo={speed}',
|
||||
'-ar', '16000',
|
||||
'-ac', '1',
|
||||
'-c:a', 'pcm_s16le',
|
||||
output_path,
|
||||
'-y'
|
||||
]
|
||||
logger.debug(f"Running FFmpeg command: {' '.join(command)}")
|
||||
subprocess.run(command, check=True, capture_output=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
|
||||
return False
|
||||
|
||||
class TranscriptionMetrics:
|
||||
def __init__(self):
|
||||
self.start_time = time.time()
|
||||
self.end_time = None
|
||||
self.text_length = 0
|
||||
self.audio_duration = 0
|
||||
|
||||
def stop(self, text: str, audio_duration: float):
|
||||
self.end_time = time.time()
|
||||
self.text_length = len(text)
|
||||
self.audio_duration = audio_duration
|
||||
|
||||
def get_metrics(self) -> Dict[str, float]:
|
||||
processing_time = self.end_time - self.start_time
|
||||
return {
|
||||
"processing_time_seconds": round(processing_time, 2),
|
||||
"characters_per_second": round(self.text_length / processing_time, 2),
|
||||
"audio_realtime_ratio": round(self.audio_duration / processing_time, 2),
|
||||
"audio_duration": round(self.audio_duration, 2),
|
||||
"text_length": self.text_length
|
||||
}
|
||||
|
||||
def get_audio_duration(file_path: str) -> float:
|
||||
"""Get audio duration using ffprobe"""
|
||||
cmd = [
|
||||
'ffprobe',
|
||||
'-v', 'quiet',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
file_path
|
||||
]
|
||||
try:
|
||||
output = subprocess.check_output(cmd).decode().strip()
|
||||
return float(output)
|
||||
except:
|
||||
return 0.0
|
||||
|
||||
@app.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
token: str = Depends(api_key_header),
|
||||
model_name: str = "medium",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'\"¿([{-",
|
||||
append_punctuations: str = "\"\'.。,,!!??::\")]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None
|
||||
):
|
||||
# Token validation
|
||||
if token not in get_keys():
|
||||
logger.warning(f"Invalid token attempt: {token}")
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
logger.info(f"Processing file: {file.filename} with model: {model_name}")
|
||||
metrics = TranscriptionMetrics()
|
||||
|
||||
# Save uploaded file
|
||||
temp_input_path = f"/tmp/input_{file.filename}"
|
||||
temp_output_path = f"/tmp/converted_{file.filename}.wav"
|
||||
|
||||
try:
|
||||
with open(temp_input_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
# Convert audio if needed
|
||||
logger.debug("Converting audio file")
|
||||
if not convert_audio(temp_input_path, temp_output_path):
|
||||
raise HTTPException(status_code=400, detail="Audio conversion failed")
|
||||
|
||||
# Get audio duration before speed up
|
||||
original_duration = get_audio_duration(temp_input_path)
|
||||
|
||||
# Load model
|
||||
logger.debug(f"Loading model: {model_name}")
|
||||
model = whisper.load_model(model_name, device="cuda")
|
||||
|
||||
# Transcribe
|
||||
logger.info("Starting transcription")
|
||||
result = model.transcribe(
|
||||
temp_output_path,
|
||||
verbose=verbose,
|
||||
temperature=temperature,
|
||||
compression_ratio_threshold=compression_ratio_threshold,
|
||||
logprob_threshold=logprob_threshold,
|
||||
no_speech_threshold=no_speech_threshold,
|
||||
condition_on_previous_text=condition_on_previous_text,
|
||||
initial_prompt=initial_prompt,
|
||||
word_timestamps=word_timestamps,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
clip_timestamps=clip_timestamps,
|
||||
hallucination_silence_threshold=hallucination_silence_threshold
|
||||
)
|
||||
|
||||
|
||||
# Calculate metrics
|
||||
metrics.stop(result["text"], original_duration)
|
||||
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
||||
|
||||
# Add metrics to result
|
||||
result["metrics"] = metrics.get_metrics()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
finally:
|
||||
# Cleanup temporary files
|
||||
if os.path.exists(temp_input_path):
|
||||
os.remove(temp_input_path)
|
||||
if os.path.exists(temp_output_path):
|
||||
os.remove(temp_output_path)
|
||||
|
||||
def main():
|
||||
import uvicorn
|
||||
get_keys()
|
||||
uvicorn.run(app, host="0.0.0.0")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +0,0 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8000
|
||||
ui: true
|
||||
|
||||
whisper:
|
||||
model_name: "turbo"
|
||||
device: "cuda"
|
||||
compute_type: "int8"
|
||||
45
converter.py
45
converter.py
@@ -1,45 +0,0 @@
|
||||
import ffmpeg
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
def is_valid_format(file_path: str) -> bool:
|
||||
"""Проверяет, является ли аудиофайл 16kHz моно WAV."""
|
||||
try:
|
||||
probe = ffmpeg.probe(file_path)
|
||||
audio_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
|
||||
if audio_stream is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
audio_stream.get('codec_name') == 'pcm_s16le' and
|
||||
audio_stream.get('channels') == 1 and
|
||||
audio_stream.get('sample_rate') == '16000'
|
||||
)
|
||||
except ffmpeg.Error:
|
||||
return False
|
||||
|
||||
def convert_to_wav(input_file_path: str) -> tuple[str, bool]:
|
||||
"""
|
||||
Конвертирует аудиофайл в 16kHz моно WAV.
|
||||
Возвращает путь к сконвертированному файлу и флаг, указывающий, была ли выполнена конвертация.
|
||||
Если файл уже в нужном формате, возвращает исходный путь и False.
|
||||
"""
|
||||
if is_valid_format(input_file_path):
|
||||
return input_file_path, False
|
||||
|
||||
output_file_path = tempfile.mktemp(suffix=".wav")
|
||||
|
||||
try:
|
||||
ffmpeg.input(input_file_path).output(
|
||||
output_file_path,
|
||||
acodec='pcm_s16le',
|
||||
ac=1,
|
||||
ar='16k'
|
||||
).run(capture_stdout=True, capture_stderr=True)
|
||||
return output_file_path, True
|
||||
except ffmpeg.Error as e:
|
||||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
raise e
|
||||
|
||||
198
main.py
198
main.py
@@ -1,198 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import sys
|
||||
import yaml
|
||||
from typing import Optional, List, Union, Tuple, Iterable
|
||||
from fastapi import FastAPI, UploadFile, File, Depends
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import HTMLResponse
|
||||
from faster_whisper import WhisperModel
|
||||
from converter import convert_to_wav
|
||||
|
||||
with open("config.yaml", 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
w_config = config['whisper']
|
||||
|
||||
class TranscriptionOptions(BaseModel):
|
||||
language: Optional[str] = w_config.get('language')
|
||||
task: str = w_config.get('task', 'transcribe')
|
||||
beam_size: int = w_config.get('beam_size', 5)
|
||||
best_of: int = w_config.get('best_of', 5)
|
||||
patience: float = w_config.get('patience', 1.0)
|
||||
length_penalty: float = w_config.get('length_penalty', 1.0)
|
||||
repetition_penalty: float = w_config.get('repetition_penalty', 1.0)
|
||||
no_repeat_ngram_size: int = w_config.get('no_repeat_ngram_size', 0)
|
||||
temperature: Union[float, List[float], Tuple[float, ...]] = w_config.get('temperature', [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
|
||||
log_progress: bool = w_config.get('log_progress', False)
|
||||
compression_ratio_threshold: Optional[float] = w_config.get('compression_ratio_threshold', 2.4)
|
||||
log_prob_threshold: Optional[float] = w_config.get('log_prob_threshold', -1.0)
|
||||
no_speech_threshold: Optional[float] = w_config.get('no_speech_threshold', 0.6)
|
||||
condition_on_previous_text: bool = w_config.get('condition_on_previous_text', True)
|
||||
prompt_reset_on_temperature: float = w_config.get('prompt_reset_on_temperature', 0.5)
|
||||
initial_prompt: Optional[Union[str, Iterable[int]]] = w_config.get('initial_prompt')
|
||||
prefix: Optional[str] = w_config.get('prefix')
|
||||
suppress_blank: bool = w_config.get('suppress_blank', True)
|
||||
suppress_tokens: Optional[List[int]] = w_config.get('suppress_tokens', [-1])
|
||||
without_timestamps: bool = w_config.get('without_timestamps', False)
|
||||
max_initial_timestamp: float = w_config.get('max_initial_timestamp', 1.0)
|
||||
word_timestamps: bool = w_config.get('word_timestamps', False)
|
||||
prepend_punctuations: str = w_config.get('prepend_punctuations', '"\'“¿([{-')
|
||||
append_punctuations: str = w_config.get('append_punctuations', '"\'.。,,!!??::”)]}、')
|
||||
vad_filter: bool = w_config.get('vad_filter', False)
|
||||
vad_parameters: Optional[dict] = w_config.get('vad_parameters')
|
||||
max_new_tokens: Optional[int] = w_config.get('max_new_tokens')
|
||||
chunk_length: Optional[int] = w_config.get('chunk_length')
|
||||
clip_timestamps: Union[str, List[float]] = w_config.get('clip_timestamps', "0")
|
||||
hallucination_silence_threshold: Optional[float] = w_config.get('hallucination_silence_threshold')
|
||||
hotwords: Optional[str] = w_config.get('hotwords')
|
||||
language_detection_threshold: Optional[float] = w_config.get('language_detection_threshold')
|
||||
language_detection_segments: int = w_config.get('language_detection_segments', 1)
|
||||
|
||||
class WhisperTranscriber:
|
||||
def __init__(self, model_name, device, compute_type):
|
||||
self.model = WhisperModel(model_name, device=device, compute_type=compute_type)
|
||||
|
||||
def transcribe(self, audio_file_path: str, options: dict) -> str:
|
||||
segments, _ = self.model.transcribe(audio_file_path, **options)
|
||||
transcription = " ".join([segment.text for segment in segments])
|
||||
return transcription
|
||||
|
||||
transcriber = WhisperTranscriber(
|
||||
model_name=w_config['model_name'],
|
||||
device=w_config['device'],
|
||||
compute_type=w_config['compute_type']
|
||||
)
|
||||
|
||||
@app.post("/transcribe")
|
||||
async def transcribe_audio(file: UploadFile = File(...), options: TranscriptionOptions = Depends()):
|
||||
temp_audio_file_path = None
|
||||
converted_file_path = None
|
||||
was_converted = False
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio_file:
|
||||
temp_audio_file.write(await file.read())
|
||||
temp_audio_file_path = temp_audio_file.name
|
||||
|
||||
converted_file_path, was_converted = convert_to_wav(temp_audio_file_path)
|
||||
|
||||
transcription = transcriber.transcribe(converted_file_path, options.dict(exclude_none=True))
|
||||
|
||||
return {"transcription": transcription}
|
||||
finally:
|
||||
if temp_audio_file_path and os.path.exists(temp_audio_file_path):
|
||||
os.remove(temp_audio_file_path)
|
||||
if was_converted and converted_file_path and os.path.exists(converted_file_path):
|
||||
os.remove(converted_file_path)
|
||||
|
||||
def create_ui():
|
||||
return '''
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Whisper Transcription</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||
<style>
|
||||
body {
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
.container {
|
||||
max-width: 700px;
|
||||
}
|
||||
#transcriptionOutput {
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mt-5">
|
||||
<div class="card">
|
||||
<div class="card-body">
|
||||
<h1 class="card-title text-center mb-4">Upload Audio for Transcription</h1>
|
||||
<div class="mb-3">
|
||||
<input class="form-control" type="file" id="audioFile" accept="audio/*">
|
||||
</div>
|
||||
<div class="d-grid">
|
||||
<button class="btn btn-primary" onclick="transcribeAudio()">
|
||||
<span class="spinner-border spinner-border-sm d-none" role="status" aria-hidden="true" id="spinner"></span>
|
||||
Transcribe
|
||||
</button>
|
||||
</div>
|
||||
<h2 class="mt-4">Transcription:</h2>
|
||||
<div class="p-3 bg-light rounded">
|
||||
<pre id="transcriptionOutput"></pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
async function transcribeAudio() {
|
||||
const fileInput = document.getElementById('audioFile');
|
||||
const file = fileInput.files[0];
|
||||
if (!file) {
|
||||
alert("Please select a file first.");
|
||||
return;
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
const outputElement = document.getElementById('transcriptionOutput');
|
||||
const spinner = document.getElementById('spinner');
|
||||
const transcribeButton = document.querySelector('button');
|
||||
|
||||
outputElement.innerText = '';
|
||||
spinner.classList.remove('d-none');
|
||||
transcribeButton.disabled = true;
|
||||
|
||||
|
||||
try {
|
||||
const response = await fetch('/transcribe', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const result = await response.json();
|
||||
if (result.transcription) {
|
||||
outputElement.innerText = result.transcription;
|
||||
} else if (result.error) {
|
||||
outputElement.innerText = 'Error: ' + result.error;
|
||||
}
|
||||
} else {
|
||||
const errorText = await response.text();
|
||||
outputElement.innerText = 'Error: ' + response.statusText + ' - ' + errorText;
|
||||
}
|
||||
} catch (error) {
|
||||
outputElement.innerText = 'An error occurred: ' + error;
|
||||
} finally {
|
||||
spinner.classList.add('d-none');
|
||||
transcribeButton.disabled = false;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
s_config = config['server']
|
||||
|
||||
if s_config['ui'] or "--ui" in sys.argv:
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def read_root():
|
||||
return create_ui()
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=s_config['host'],
|
||||
port=s_config['port']
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
python-multipart
|
||||
faster-whisper
|
||||
openai-whisper
|
||||
ffmpeg-python
|
||||
PyYAML
|
||||
|
||||
Reference in New Issue
Block a user