diff --git a/Dockerfile b/Dockerfile index 0059473..07b5f4e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 +1,17 @@ -# Используем образ ROCm с предустановленным PyTorch FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0 -# Устанавливаем рабочую директорию в контейнере WORKDIR /app -# Устанавливаем системные зависимости RUN apt-get update && apt-get install -y \ ffmpeg \ python3-pip \ && rm -rf /var/lib/apt/lists/* -# Устанавливаем зависимости Python COPY requirements.txt . RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt -# Копируем остальные файлы приложения COPY . . -# Открываем порт, на котором будет работать приложение EXPOSE 9854 # Устанавливаем переменные окружения для ROCm diff --git a/app.py b/app.py index 465541b..b0820f5 100644 --- a/app.py +++ b/app.py @@ -18,10 +18,23 @@ logger = logging.getLogger(__name__) app = FastAPI() +# Load model on startup +model = None + + +@app.on_event("startup") +def load_model(): + global model + logger.info("Loading whisper model...") + model = whisper.load_model("medium", device="cuda") + logger.info("Whisper model loaded.") + + # API key header api_key_header = APIKeyHeader(name="x-api-key") -def get_keys(): # не бейте меня за это + +def get_keys(): # не бейте меня за это keys_file = "keys.txt" if not os.path.exists(keys_file): # Create a new keys file with a default key @@ -40,6 +53,7 @@ def get_keys(): # не бейте меня за это raise ValueError("No keys found in keys.txt") return keys + def convert_audio(input_path: str, output_path: str, speed: float = 1.25): """ Convert audio to compatible format and speed up @@ -61,6 +75,7 @@ def convert_audio(input_path: str, output_path: str, speed: float = 1.25): logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}") return False + class TranscriptionMetrics: def __init__(self): self.start_time = time.time() @@ -83,6 +98,7 @@ class TranscriptionMetrics: "text_length": self.text_length } + def get_audio_duration(file_path: str) -> float: """Get audio duration using ffprobe""" cmd = [ @@ -98,23 +114,24 @@ def get_audio_duration(file_path: str) -> float: except: return 0.0 + @app.post("/transcribe") async def transcribe_audio( - file: UploadFile = File(...), - token: str = Depends(api_key_header), - model_name: str = "medium", - verbose: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'\"¿([{-", - append_punctuations: str = "\"\'.。,,!!??::\")]}、", - clip_timestamps: Union[str, List[float]] = "0", - hallucination_silence_threshold: Optional[float] = None + file: UploadFile = File(...), + token: str = Depends(api_key_header), + model_name: str = "medium", + verbose: Optional[bool] = None, + temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'\"¿([{-", + append_punctuations: str = "\"\'.。,,!!??::\")]}、", + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None ): # Token validation if token not in get_keys(): @@ -140,10 +157,6 @@ async def transcribe_audio( # Get audio duration before speed up original_duration = get_audio_duration(temp_input_path) - # Load model - logger.debug(f"Loading model: {model_name}") - model = whisper.load_model(model_name, device="cuda") - # Transcribe logger.info("Starting transcription") result = model.transcribe( @@ -162,7 +175,6 @@ async def transcribe_audio( hallucination_silence_threshold=hallucination_silence_threshold ) - # Calculate metrics metrics.stop(result["text"], original_duration) logger.info(f"Transcription metrics: {metrics.get_metrics()}") @@ -183,10 +195,11 @@ async def transcribe_audio( if os.path.exists(temp_output_path): os.remove(temp_output_path) + def main(): import uvicorn get_keys() - uvicorn.run(app, host="0.0.0.0") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docker-compose.yml b/docker-compose.yml index a88dc49..9841f63 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,8 +3,6 @@ services: build: . ports: - "9854:9854" - volumes: - - ./keys.txt:/app/keys.txt devices: - "/dev/kfd:/dev/kfd" - "/dev/dri:/dev/dri"