Замена модели ASR на GigaAM (CTC v2)
This commit is contained in:
54
app.py
@@ -4,8 +4,8 @@ import subprocess
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Optional, Union, List, Tuple
|
||||
import gigaam
|
||||
|
||||
import whisper
|
||||
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
model = gigaam.load_model("v2_ctc", device="cuda", download_root="./model")
|
||||
|
||||
# API key header
|
||||
api_key_header = APIKeyHeader(name="x-api-key")
|
||||
@@ -108,19 +109,7 @@ def get_audio_duration(file_path: str) -> float:
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
token: str = Depends(api_key_header),
|
||||
model_name: str = "turbo",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'\"¿([{-",
|
||||
append_punctuations: str = "\"\'.。,,!!??::\")]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None
|
||||
model_name: str = "turbo"
|
||||
):
|
||||
# Token validation
|
||||
if token not in get_keys():
|
||||
@@ -148,24 +137,29 @@ async def transcribe_audio(
|
||||
|
||||
# Transcribe
|
||||
logger.info("Starting transcription")
|
||||
result = model.transcribe(
|
||||
temp_output_path,
|
||||
verbose=verbose,
|
||||
temperature=temperature,
|
||||
compression_ratio_threshold=compression_ratio_threshold,
|
||||
logprob_threshold=logprob_threshold,
|
||||
no_speech_threshold=no_speech_threshold,
|
||||
condition_on_previous_text=condition_on_previous_text,
|
||||
initial_prompt=initial_prompt,
|
||||
word_timestamps=word_timestamps,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
clip_timestamps=clip_timestamps,
|
||||
hallucination_silence_threshold=hallucination_silence_threshold
|
||||
)
|
||||
if original_duration > 30:
|
||||
logger.info("Audio duration > 30 seconds, using transcribe_longform")
|
||||
transcription_result = model.transcribe_longform(
|
||||
temp_output_path
|
||||
)
|
||||
else:
|
||||
logger.info("Audio duration <= 30 seconds, using transcribe")
|
||||
transcription_result = model.transcribe(
|
||||
temp_output_path
|
||||
)
|
||||
|
||||
full_text = ""
|
||||
for part in transcription_result:
|
||||
if part["transcription"].strip() != "":
|
||||
full_text += part["transcription"].strip() + " "
|
||||
|
||||
result = {
|
||||
"transcription": transcription_result,
|
||||
"text": full_text
|
||||
}
|
||||
|
||||
# Calculate metrics
|
||||
metrics.stop(result["text"], original_duration)
|
||||
metrics.stop(full_text, original_duration)
|
||||
logger.info(f"Transcription metrics: {metrics.get_metrics()}")
|
||||
|
||||
# Add metrics to result
|
||||
|
||||
Reference in New Issue
Block a user