Замена модели ASR на GigaAM (CTC v2)
This commit is contained in:
54
app.py
54
app.py
@@ -4,8 +4,8 @@ import subprocess
|
|||||||
import time
|
import time
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional, Union, List, Tuple
|
from typing import Optional, Union, List, Tuple
|
||||||
|
import gigaam
|
||||||
|
|
||||||
import whisper
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
|
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
model = gigaam.load_model("v2_ctc", device="cuda", download_root="./model")
|
||||||
|
|
||||||
# API key header
|
# API key header
|
||||||
api_key_header = APIKeyHeader(name="x-api-key")
|
api_key_header = APIKeyHeader(name="x-api-key")
|
||||||
@@ -108,19 +109,7 @@ def get_audio_duration(file_path: str) -> float:
|
|||||||
async def transcribe_audio(
|
async def transcribe_audio(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
token: str = Depends(api_key_header),
|
token: str = Depends(api_key_header),
|
||||||
model_name: str = "turbo",
|
model_name: str = "turbo"
|
||||||
verbose: Optional[bool] = None,
|
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
|
||||||
logprob_threshold: Optional[float] = -1.0,
|
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
|
||||||
condition_on_previous_text: bool = True,
|
|
||||||
initial_prompt: Optional[str] = None,
|
|
||||||
word_timestamps: bool = False,
|
|
||||||
prepend_punctuations: str = "\"'\"¿([{-",
|
|
||||||
append_punctuations: str = "\"\'.。,,!!??::\")]}、",
|
|
||||||
clip_timestamps: Union[str, List[float]] = "0",
|
|
||||||
hallucination_silence_threshold: Optional[float] = None
|
|
||||||
):
|
):
|
||||||
# Token validation
|
# Token validation
|
||||||
if token not in get_keys():
|
if token not in get_keys():
|
||||||
@@ -148,24 +137,29 @@ async def transcribe_audio(
|
|||||||
|
|
||||||
# Transcribe
|
# Transcribe
|
||||||
logger.info("Starting transcription")
|
logger.info("Starting transcription")
|
||||||
result = model.transcribe(
|
if original_duration > 30:
|
||||||
temp_output_path,
|
logger.info("Audio duration > 30 seconds, using transcribe_longform")
|
||||||
verbose=verbose,
|
transcription_result = model.transcribe_longform(
|
||||||
temperature=temperature,
|
temp_output_path
|
||||||
compression_ratio_threshold=compression_ratio_threshold,
|
)
|
||||||
logprob_threshold=logprob_threshold,
|
else:
|
||||||
no_speech_threshold=no_speech_threshold,
|
logger.info("Audio duration <= 30 seconds, using transcribe")
|
||||||
condition_on_previous_text=condition_on_previous_text,
|
transcription_result = model.transcribe(
|
||||||
initial_prompt=initial_prompt,
|
temp_output_path
|
||||||
word_timestamps=word_timestamps,
|
)
|
||||||
prepend_punctuations=prepend_punctuations,
|
|
||||||
append_punctuations=append_punctuations,
|
full_text = ""
|
||||||
clip_timestamps=clip_timestamps,
|
for part in transcription_result:
|
||||||
hallucination_silence_threshold=hallucination_silence_threshold
|
if part["transcription"].strip() != "":
|
||||||
)
|
full_text += part["transcription"].strip() + " "
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"transcription": transcription_result,
|
||||||
|
"text": full_text
|
||||||
|
}
|
||||||
|
|
||||||
# Calculate metrics
|
# Calculate metrics
|
||||||
metrics.stop(result["text"], original_duration)
|
metrics.stop(full_text, original_duration)
|
||||||
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
||||||
|
|
||||||
# Add metrics to result
|
# Add metrics to result
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
whisper-app:
|
whisper-app:
|
||||||
build: .
|
build: .
|
||||||
ports:
|
ports:
|
||||||
- "9854:9854"
|
- "9854:9854"
|
||||||
devices:
|
volumes:
|
||||||
- "/dev/kfd:/dev/kfd"
|
- ./keys.txt:/app/keys.txt
|
||||||
- "/dev/dri:/dev/dri"
|
- /tmp:/tmp
|
||||||
group_add:
|
command: ["python", "app.py"]
|
||||||
- video
|
|
||||||
|
|||||||
@@ -4,3 +4,4 @@ python-multipart
|
|||||||
openai-whisper
|
openai-whisper
|
||||||
ffmpeg-python
|
ffmpeg-python
|
||||||
PyYAML
|
PyYAML
|
||||||
|
numpy<2.0.0
|
||||||
Reference in New Issue
Block a user