diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..0059473
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,28 @@
+# Use the ROCm image with PyTorch preinstalled
+FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
+
+# Set the working directory inside the container
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    ffmpeg \
+    python3-pip \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt
+
+# Copy the remaining application files
+COPY . .
+
+# Expose the port the application listens on
+EXPOSE 9854
+
+# Set environment variables for ROCm
+ENV HSA_OVERRIDE_GFX_VERSION=10.3.0
+ENV PYTORCH_ROCM_ARCH=gfx1030
+
+# Command that starts the application
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "9854", "--log-level", "debug"]
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..465541b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,236 @@
+import functools
+import logging
+import os
+import secrets
+import subprocess
+import tempfile
+import time
+from typing import Dict
+from typing import Optional, Union, List, Tuple
+
+import whisper
+from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
+from fastapi.security import APIKeyHeader
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+# API key header
+api_key_header = APIKeyHeader(name="x-api-key")
+
+def get_keys():
+    """
+    Return the API keys stored in keys.txt, one key per line.
+
+    When the file does not exist it is created with a single random key.
+    Raises ValueError when the file exists but holds no keys.
+    """
+    keys_file = "keys.txt"
+    if not os.path.exists(keys_file):
+        # Create a new keys file with a default key
+        default_key = os.urandom(32).hex()
+        with open(keys_file, "w") as f:
+            f.write(default_key + "\n")
+        # Key material must not reach the logs; the operator reads it from the file
+        logger.info(f"Created new keys file with a default key: {keys_file}")
+        return [default_key]
+    else:
+        # Read keys from the existing file
+        with open(keys_file, "r") as f:
+            keys = [line.strip() for line in f if line.strip()]
+        logger.info(f"Loaded {len(keys)} keys from file")
+        if not keys:
+            raise ValueError("No keys found in keys.txt")
+        return keys
+
+def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
+    """
+    Convert audio to 16 kHz mono 16-bit PCM and speed it up by `speed`.
+
+    Returns True on success and False when FFmpeg exits with an error.
+    """
+    try:
+        command = [
+            # -y is a global option, so it belongs before the input and output files
+            'ffmpeg', '-y', '-i', input_path,
+            '-filter:a', f'atempo={speed}',
+            '-ar', '16000',
+            '-ac', '1',
+            '-c:a', 'pcm_s16le',
+            output_path
+        ]
+        logger.debug(f"Running FFmpeg command: {' '.join(command)}")
+        subprocess.run(command, check=True, capture_output=True)
+        return True
+    except subprocess.CalledProcessError as e:
+        # FFmpeg output is not guaranteed to be valid UTF-8
+        logger.error(f"FFmpeg conversion failed: {e.stderr.decode(errors='replace')}")
+        return False
+
+class TranscriptionMetrics:
+    """Wall-clock timing and throughput figures for one transcription request."""
+
+    def __init__(self):
+        self.start_time = time.time()
+        self.end_time = None
+        self.text_length = 0
+        self.audio_duration = 0
+
+    def stop(self, text: str, audio_duration: float):
+        """Record the end time, the transcript length and the audio duration."""
+        self.end_time = time.time()
+        self.text_length = len(text)
+        self.audio_duration = audio_duration
+
+    def get_metrics(self) -> Dict[str, float]:
+        """Return the metrics; measures up to now when stop() was not called."""
+        end_time = self.end_time if self.end_time is not None else time.time()
+        processing_time = end_time - self.start_time
+        # A coarse clock can report a zero interval; report zero rates instead of dividing by it
+        if processing_time > 0:
+            characters_per_second = self.text_length / processing_time
+            audio_realtime_ratio = self.audio_duration / processing_time
+        else:
+            characters_per_second = audio_realtime_ratio = 0.0
+        return {
+            "processing_time_seconds": round(processing_time, 2),
+            "characters_per_second": round(characters_per_second, 2),
+            "audio_realtime_ratio": round(audio_realtime_ratio, 2),
+            "audio_duration": round(self.audio_duration, 2),
+            "text_length": self.text_length
+        }
+
+def get_audio_duration(file_path: str) -> float:
+    """Get audio duration in seconds using ffprobe; 0.0 when it cannot be determined"""
+    cmd = [
+        'ffprobe',
+        '-v', 'quiet',
+        '-show_entries', 'format=duration',
+        '-of', 'default=noprint_wrappers=1:nokey=1',
+        file_path
+    ]
+    try:
+        output = subprocess.check_output(cmd).decode().strip()
+        return float(output)
+    except (subprocess.CalledProcessError, OSError, ValueError) as e:
+        # ffprobe failed, is not installed, or printed something that is not a number
+        logger.warning(f"Could not determine audio duration: {e}")
+        return 0.0
+
+@functools.lru_cache(maxsize=1)
+def _load_model(model_name: str):
+    """Load a Whisper model, keeping only the most recently used one in memory."""
+    logger.debug(f"Loading model: {model_name}")
+    return whisper.load_model(model_name, device="cuda")
+
+@app.post("/transcribe")
+async def transcribe_audio(
+    file: UploadFile = File(...),
+    token: str = Depends(api_key_header),
+    model_name: str = "medium",
+    verbose: Optional[bool] = None,
+    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+    compression_ratio_threshold: Optional[float] = 2.4,
+    logprob_threshold: Optional[float] = -1.0,
+    no_speech_threshold: Optional[float] = 0.6,
+    condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
+    word_timestamps: bool = False,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    clip_timestamps: Union[str, List[float]] = "0",
+    hallucination_silence_threshold: Optional[float] = None
+):
+    """
+    Transcribe an uploaded audio file with Whisper and return the result plus timing metrics.
+
+    Responds with 403 for an unknown API key, 400 for an unknown model name or
+    a file FFmpeg cannot convert, and 500 for any other failure.
+    """
+    # Token validation: constant-time comparison, and the token is never logged
+    supplied = token.encode()
+    if not any(secrets.compare_digest(supplied, key.encode()) for key in get_keys()):
+        logger.warning("Invalid token attempt")
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # whisper.load_model() also accepts checkpoint paths, so only official names are allowed
+    if model_name not in whisper.available_models():
+        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
+
+    logger.info(f"Processing file: {file.filename} with model: {model_name}")
+    metrics = TranscriptionMetrics()
+
+    # The client-supplied name is untrusted: keep only its extension
+    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
+
+    try:
+        # A private per-request directory prevents path traversal and name
+        # collisions between concurrent uploads; it is removed on exit
+        with tempfile.TemporaryDirectory(prefix="transcribe_") as temp_dir:
+            temp_input_path = os.path.join(temp_dir, f"input{suffix}")
+            temp_output_path = os.path.join(temp_dir, "converted.wav")
+
+            # Save uploaded file
+            with open(temp_input_path, "wb") as f:
+                f.write(await file.read())
+
+            # Convert audio if needed
+            logger.debug("Converting audio file")
+            if not convert_audio(temp_input_path, temp_output_path):
+                raise HTTPException(status_code=400, detail="Audio conversion failed")
+
+            # Get audio duration before speed up
+            original_duration = get_audio_duration(temp_input_path)
+
+            # Load model (cached between requests)
+            model = _load_model(model_name)
+
+            # Transcribe
+            logger.info("Starting transcription")
+            result = model.transcribe(
+                temp_output_path,
+                verbose=verbose,
+                temperature=temperature,
+                compression_ratio_threshold=compression_ratio_threshold,
+                logprob_threshold=logprob_threshold,
+                no_speech_threshold=no_speech_threshold,
+                condition_on_previous_text=condition_on_previous_text,
+                initial_prompt=initial_prompt,
+                word_timestamps=word_timestamps,
+                prepend_punctuations=prepend_punctuations,
+                append_punctuations=append_punctuations,
+                clip_timestamps=clip_timestamps,
+                hallucination_silence_threshold=hallucination_silence_threshold
+            )
+
+            # Calculate metrics
+            metrics.stop(result["text"], original_duration)
+            transcription_metrics = metrics.get_metrics()
+            logger.info(f"Transcription metrics: {transcription_metrics}")
+
+            # Add metrics to result
+            result["metrics"] = transcription_metrics
+
+            return result
+
+    except HTTPException:
+        # Deliberate HTTP errors (such as the 400 above) must not be rewritten as 500
+        raise
+    except Exception as e:
+        logger.exception(f"Transcription failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+def main():
+    """Start the API under uvicorn on the port the Dockerfile exposes."""
+    import uvicorn
+    get_keys()
+    uvicorn.run(app, host="0.0.0.0", port=9854)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index 406175b..0000000
--- a/config.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-server:
-  host: "0.0.0.0"
-  port: 8000
-  ui: true
-
-whisper:
-  model_name: "turbo"
-  device: "cuda"
-  compute_type: "int8"
\ No newline at end of file
diff --git a/converter.py b/converter.py
deleted file mode 100644
index 7d05819..0000000
--- a/converter.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import ffmpeg
-import os
-import tempfile
-import shutil
-
-def is_valid_format(file_path: str) -> bool:
-    """Проверяет, является ли аудиофайл 16kHz моно WAV."""
-    try:
-        probe = ffmpeg.probe(file_path)
-        audio_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
-        if audio_stream is None:
-            return False
-
-        return (
-            audio_stream.get('codec_name') == 'pcm_s16le' and
-            audio_stream.get('channels') == 1 and
-            audio_stream.get('sample_rate') == '16000'
-        )
-    except ffmpeg.Error:
-        return False
-
-def convert_to_wav(input_file_path: str) -> tuple[str, bool]:
-    """
-    Конвертирует аудиофайл в 16kHz моно WAV.
-    Возвращает путь к сконвертированному файлу и флаг, указывающий, была ли выполнена конвертация.
-    Если файл уже в нужном формате, возвращает исходный путь и False.
-    """
-    if is_valid_format(input_file_path):
-        return input_file_path, False
-
-    output_file_path = tempfile.mktemp(suffix=".wav")
-
-    try:
-        ffmpeg.input(input_file_path).output(
-            output_file_path,
-            acodec='pcm_s16le',
-            ac=1,
-            ar='16k'
-        ).run(capture_stdout=True, capture_stderr=True)
-        return output_file_path, True
-    except ffmpeg.Error as e:
-        if os.path.exists(output_file_path):
-            os.remove(output_file_path)
-        raise e
-
diff --git a/main.py b/main.py
deleted file mode 100644
index d6512c9..0000000
--- a/main.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import os
-import tempfile
-import sys
-import yaml
-from typing import Optional, List, Union, Tuple, Iterable
-from fastapi import FastAPI, UploadFile, File, Depends
-from pydantic import BaseModel
-from fastapi.responses import HTMLResponse
-from faster_whisper import WhisperModel
-from converter import convert_to_wav
-
-with open("config.yaml", 'r') as f:
-    config = yaml.safe_load(f)
-
-app = FastAPI()
-
-w_config = config['whisper']
-
-class TranscriptionOptions(BaseModel):
-    language: Optional[str] = w_config.get('language')
-    task: str = w_config.get('task', 'transcribe')
-    beam_size: int = w_config.get('beam_size', 5)
-    best_of: int = w_config.get('best_of', 5)
-    patience: float = w_config.get('patience', 1.0)
-    length_penalty: float = w_config.get('length_penalty', 1.0)
-    repetition_penalty: float = w_config.get('repetition_penalty', 1.0)
-    no_repeat_ngram_size: int = w_config.get('no_repeat_ngram_size', 0)
-    temperature: Union[float, List[float], Tuple[float, ...]] = w_config.get('temperature', [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
-    log_progress: bool = w_config.get('log_progress', False)
-    compression_ratio_threshold: Optional[float] = w_config.get('compression_ratio_threshold', 2.4)
-    log_prob_threshold: Optional[float] = w_config.get('log_prob_threshold', -1.0)
-    no_speech_threshold: Optional[float] = w_config.get('no_speech_threshold', 0.6)
-    condition_on_previous_text: bool = w_config.get('condition_on_previous_text', True)
-    prompt_reset_on_temperature: float = w_config.get('prompt_reset_on_temperature', 0.5)
-    initial_prompt: Optional[Union[str, Iterable[int]]] = w_config.get('initial_prompt')
-    prefix: Optional[str] = w_config.get('prefix')
-    suppress_blank: bool = w_config.get('suppress_blank', True)
-    suppress_tokens: Optional[List[int]] = w_config.get('suppress_tokens', [-1])
-    without_timestamps: bool = w_config.get('without_timestamps', False)
-    max_initial_timestamp: float = w_config.get('max_initial_timestamp', 1.0)
-    word_timestamps: bool = w_config.get('word_timestamps', False)
-    prepend_punctuations: str = w_config.get('prepend_punctuations', '"\'“¿([{-')
-    append_punctuations: str = w_config.get('append_punctuations', '"\'.。,,!!??::”)]}、')
-    vad_filter: bool = w_config.get('vad_filter', False)
-    vad_parameters: Optional[dict] = w_config.get('vad_parameters')
-    max_new_tokens: Optional[int] = w_config.get('max_new_tokens')
-    chunk_length: Optional[int] = w_config.get('chunk_length')
-    clip_timestamps: Union[str, List[float]] = w_config.get('clip_timestamps', "0")
-    hallucination_silence_threshold: Optional[float] = w_config.get('hallucination_silence_threshold')
-    hotwords: Optional[str] = w_config.get('hotwords')
-    language_detection_threshold: Optional[float] = w_config.get('language_detection_threshold')
-    language_detection_segments: int = w_config.get('language_detection_segments', 1)
-
-class WhisperTranscriber:
-    def __init__(self, model_name, device, compute_type):
-        self.model = WhisperModel(model_name, device=device, compute_type=compute_type)
-
-    def transcribe(self, audio_file_path: str, options: dict) -> str:
-        segments, _ = self.model.transcribe(audio_file_path, **options)
-        transcription = " ".join([segment.text for segment in segments])
-        return transcription
-
-transcriber = WhisperTranscriber(
-    model_name=w_config['model_name'],
-    device=w_config['device'],
-    compute_type=w_config['compute_type']
-)
-
-@app.post("/transcribe")
-async def transcribe_audio(file: UploadFile = File(...), options: TranscriptionOptions = Depends()):
-    temp_audio_file_path = None
-    converted_file_path = None
-    was_converted = False
-    try:
-        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio_file:
-            temp_audio_file.write(await file.read())
-            temp_audio_file_path = temp_audio_file.name
-
-        converted_file_path, was_converted = convert_to_wav(temp_audio_file_path)
-
-        transcription = transcriber.transcribe(converted_file_path, options.dict(exclude_none=True))
-
-        return {"transcription": transcription}
-    finally:
-        if temp_audio_file_path and os.path.exists(temp_audio_file_path):
-            os.remove(temp_audio_file_path)
-        if was_converted and converted_file_path and os.path.exists(converted_file_path):
-            os.remove(converted_file_path)
-
-def create_ui():
-    return '''
-
-
-
- - -