fix: can't proceed to launch
This commit is contained in:
@@ -1,23 +1,17 @@
|
|||||||
# Используем образ ROCm с предустановленным PyTorch
|
|
||||||
FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
FROM rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
||||||
|
|
||||||
# Устанавливаем рабочую директорию в контейнере
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Устанавливаем системные зависимости
|
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Устанавливаем зависимости Python
|
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt
|
RUN pip install --no-cache-dir --default-timeout=100 -r requirements.txt
|
||||||
|
|
||||||
# Копируем остальные файлы приложения
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Открываем порт, на котором будет работать приложение
|
|
||||||
EXPOSE 9854
|
EXPOSE 9854
|
||||||
|
|
||||||
# Устанавливаем переменные окружения для ROCm
|
# Устанавливаем переменные окружения для ROCm
|
||||||
|
|||||||
59
app.py
59
app.py
@@ -18,10 +18,23 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Load model on startup
|
||||||
|
model = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def load_model():
|
||||||
|
global model
|
||||||
|
logger.info("Loading whisper model...")
|
||||||
|
model = whisper.load_model("medium", device="cuda")
|
||||||
|
logger.info("Whisper model loaded.")
|
||||||
|
|
||||||
|
|
||||||
# API key header
|
# API key header
|
||||||
api_key_header = APIKeyHeader(name="x-api-key")
|
api_key_header = APIKeyHeader(name="x-api-key")
|
||||||
|
|
||||||
def get_keys(): # не бейте меня за это
|
|
||||||
|
def get_keys(): # не бейте меня за это
|
||||||
keys_file = "keys.txt"
|
keys_file = "keys.txt"
|
||||||
if not os.path.exists(keys_file):
|
if not os.path.exists(keys_file):
|
||||||
# Create a new keys file with a default key
|
# Create a new keys file with a default key
|
||||||
@@ -40,6 +53,7 @@ def get_keys(): # не бейте меня за это
|
|||||||
raise ValueError("No keys found in keys.txt")
|
raise ValueError("No keys found in keys.txt")
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
|
def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
|
||||||
"""
|
"""
|
||||||
Convert audio to compatible format and speed up
|
Convert audio to compatible format and speed up
|
||||||
@@ -61,6 +75,7 @@ def convert_audio(input_path: str, output_path: str, speed: float = 1.25):
|
|||||||
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
|
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionMetrics:
|
class TranscriptionMetrics:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
@@ -83,6 +98,7 @@ class TranscriptionMetrics:
|
|||||||
"text_length": self.text_length
|
"text_length": self.text_length
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_audio_duration(file_path: str) -> float:
|
def get_audio_duration(file_path: str) -> float:
|
||||||
"""Get audio duration using ffprobe"""
|
"""Get audio duration using ffprobe"""
|
||||||
cmd = [
|
cmd = [
|
||||||
@@ -98,23 +114,24 @@ def get_audio_duration(file_path: str) -> float:
|
|||||||
except:
|
except:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
@app.post("/transcribe")
|
@app.post("/transcribe")
|
||||||
async def transcribe_audio(
|
async def transcribe_audio(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
token: str = Depends(api_key_header),
|
token: str = Depends(api_key_header),
|
||||||
model_name: str = "medium",
|
model_name: str = "medium",
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
logprob_threshold: Optional[float] = -1.0,
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'\"¿([{-",
|
prepend_punctuations: str = "\"'\"¿([{-",
|
||||||
append_punctuations: str = "\"\'.。,,!!??::\")]}、",
|
append_punctuations: str = "\"\'.。,,!!??::\")]}、",
|
||||||
clip_timestamps: Union[str, List[float]] = "0",
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
hallucination_silence_threshold: Optional[float] = None
|
hallucination_silence_threshold: Optional[float] = None
|
||||||
):
|
):
|
||||||
# Token validation
|
# Token validation
|
||||||
if token not in get_keys():
|
if token not in get_keys():
|
||||||
@@ -140,10 +157,6 @@ async def transcribe_audio(
|
|||||||
# Get audio duration before speed up
|
# Get audio duration before speed up
|
||||||
original_duration = get_audio_duration(temp_input_path)
|
original_duration = get_audio_duration(temp_input_path)
|
||||||
|
|
||||||
# Load model
|
|
||||||
logger.debug(f"Loading model: {model_name}")
|
|
||||||
model = whisper.load_model(model_name, device="cuda")
|
|
||||||
|
|
||||||
# Transcribe
|
# Transcribe
|
||||||
logger.info("Starting transcription")
|
logger.info("Starting transcription")
|
||||||
result = model.transcribe(
|
result = model.transcribe(
|
||||||
@@ -162,7 +175,6 @@ async def transcribe_audio(
|
|||||||
hallucination_silence_threshold=hallucination_silence_threshold
|
hallucination_silence_threshold=hallucination_silence_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Calculate metrics
|
# Calculate metrics
|
||||||
metrics.stop(result["text"], original_duration)
|
metrics.stop(result["text"], original_duration)
|
||||||
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
||||||
@@ -183,10 +195,11 @@ async def transcribe_audio(
|
|||||||
if os.path.exists(temp_output_path):
|
if os.path.exists(temp_output_path):
|
||||||
os.remove(temp_output_path)
|
os.remove(temp_output_path)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import uvicorn
|
import uvicorn
|
||||||
get_keys()
|
get_keys()
|
||||||
uvicorn.run(app, host="0.0.0.0")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ services:
|
|||||||
build: .
|
build: .
|
||||||
ports:
|
ports:
|
||||||
- "9854:9854"
|
- "9854:9854"
|
||||||
volumes:
|
|
||||||
- ./keys.txt:/app/keys.txt
|
|
||||||
devices:
|
devices:
|
||||||
- "/dev/kfd:/dev/kfd"
|
- "/dev/kfd:/dev/kfd"
|
||||||
- "/dev/dri:/dev/dri"
|
- "/dev/dri:/dev/dri"
|
||||||
|
|||||||
Reference in New Issue
Block a user