diff --git a/app.py b/app.py index b984f52..781a047 100644 --- a/app.py +++ b/app.py @@ -2,10 +2,10 @@ import logging import os import subprocess import time +from os import getenv from typing import Dict -from typing import Optional, Union, List, Tuple -import gigaam +import gigaam from fastapi import FastAPI, Depends, HTTPException, UploadFile, File from fastapi.security import APIKeyHeader @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) app = FastAPI() -model = gigaam.load_model("v2_ctc", device="cuda", download_root="./model") +model = gigaam.load_model("v2_ctc", device=getenv("ASR_DEVICE"), download_root=getenv("ASR_MODELS_ROOT")) # API key header api_key_header = APIKeyHeader(name="x-api-key")