Files
simple-asr-server/main.py

199 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import tempfile
import sys
import yaml
from typing import Optional, List, Union, Tuple, Iterable
from fastapi import FastAPI, UploadFile, File, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse
from faster_whisper import WhisperModel
from converter import convert_to_wav
# Load the service configuration once, at import time. The encoding is pinned
# to UTF-8 so that non-ASCII values in the file (for example punctuation sets)
# are decoded the same way on every platform instead of depending on the
# process locale.
with open("config.yaml", 'r', encoding="utf-8") as f:
    config = yaml.safe_load(f)

# The ASGI application that the route decorators in this module attach to.
app = FastAPI()

# The `whisper` section of the configuration: model selection plus the default
# values for the per-request transcription options.
w_config = config['whisper']
class TranscriptionOptions(BaseModel):
    """Per-request transcription options.

    Every field takes its default from the ``whisper`` section of the
    configuration (``w_config``); when the configuration omits a key, the
    literal fallback written next to the field is used instead. The endpoint
    converts an instance to a dict (dropping ``None`` values) and forwards it
    as keyword arguments to the model's ``transcribe`` call.
    """

    # --- language / task selection -------------------------------------
    language: Optional[str] = w_config.get('language')
    task: str = w_config.get('task', 'transcribe')

    # --- decoding / search ---------------------------------------------
    beam_size: int = w_config.get('beam_size', 5)
    best_of: int = w_config.get('best_of', 5)
    patience: float = w_config.get('patience', 1.0)
    length_penalty: float = w_config.get('length_penalty', 1.0)
    repetition_penalty: float = w_config.get('repetition_penalty', 1.0)
    no_repeat_ngram_size: int = w_config.get('no_repeat_ngram_size', 0)
    temperature: Union[float, List[float], Tuple[float, ...]] = w_config.get(
        'temperature', [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    )
    log_progress: bool = w_config.get('log_progress', False)

    # --- quality thresholds --------------------------------------------
    compression_ratio_threshold: Optional[float] = w_config.get(
        'compression_ratio_threshold', 2.4
    )
    log_prob_threshold: Optional[float] = w_config.get('log_prob_threshold', -1.0)
    no_speech_threshold: Optional[float] = w_config.get('no_speech_threshold', 0.6)

    # --- prompting / conditioning --------------------------------------
    condition_on_previous_text: bool = w_config.get('condition_on_previous_text', True)
    prompt_reset_on_temperature: float = w_config.get('prompt_reset_on_temperature', 0.5)
    initial_prompt: Optional[Union[str, Iterable[int]]] = w_config.get('initial_prompt')
    prefix: Optional[str] = w_config.get('prefix')

    # --- token suppression ---------------------------------------------
    suppress_blank: bool = w_config.get('suppress_blank', True)
    suppress_tokens: Optional[List[int]] = w_config.get('suppress_tokens', [-1])

    # --- timestamps ----------------------------------------------------
    without_timestamps: bool = w_config.get('without_timestamps', False)
    max_initial_timestamp: float = w_config.get('max_initial_timestamp', 1.0)
    word_timestamps: bool = w_config.get('word_timestamps', False)
    prepend_punctuations: str = w_config.get('prepend_punctuations', '"\'“¿([{-')
    append_punctuations: str = w_config.get('append_punctuations', '"\'.。,!?::”)]}、')

    # --- voice activity detection --------------------------------------
    vad_filter: bool = w_config.get('vad_filter', False)
    vad_parameters: Optional[dict] = w_config.get('vad_parameters')

    # --- chunking / clipping -------------------------------------------
    max_new_tokens: Optional[int] = w_config.get('max_new_tokens')
    chunk_length: Optional[int] = w_config.get('chunk_length')
    clip_timestamps: Union[str, List[float]] = w_config.get('clip_timestamps', "0")
    hallucination_silence_threshold: Optional[float] = w_config.get(
        'hallucination_silence_threshold'
    )

    # --- hotwords / language detection ---------------------------------
    hotwords: Optional[str] = w_config.get('hotwords')
    language_detection_threshold: Optional[float] = w_config.get(
        'language_detection_threshold'
    )
    language_detection_segments: int = w_config.get('language_detection_segments', 1)
class WhisperTranscriber:
    """Small wrapper that owns one ``WhisperModel`` and returns plain text."""

    def __init__(self, model_name, device, compute_type):
        """Load the model.

        Args:
            model_name: Passed to ``WhisperModel`` as its first argument.
            device: Forwarded as the ``device`` keyword.
            compute_type: Forwarded as the ``compute_type`` keyword.
        """
        self.model = WhisperModel(model_name, device=device, compute_type=compute_type)

    def transcribe(self, audio_file_path: str, options: dict) -> str:
        """Transcribe an audio file and return the text as a single string.

        Args:
            audio_file_path: Path of the audio file handed to the model.
            options: Keyword arguments forwarded unchanged to
                ``self.model.transcribe``.

        Returns:
            The segment texts joined by exactly one space, without leading or
            trailing whitespace. An empty string when there are no segments.
        """
        segments, _ = self.model.transcribe(audio_file_path, **options)
        # Segment texts may already carry their own leading/trailing
        # whitespace; joining them verbatim with " " would then produce a
        # leading space and doubled spaces between segments. Normalise each
        # piece first and drop pieces that are empty after stripping.
        pieces = (segment.text.strip() for segment in segments)
        return " ".join(piece for piece in pieces if piece)
# Module-wide transcriber, built once at import time from the three model
# settings in the whisper configuration. A missing key raises KeyError here.
transcriber = WhisperTranscriber(
    **{key: w_config[key] for key in ('model_name', 'device', 'compute_type')}
)
def _remove_quietly(path):
    """Delete *path* when it is set; a file that is already gone is not an error."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        # Best-effort cleanup: nothing left to remove.
        pass


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...), options: TranscriptionOptions = Depends()):
    """Transcribe an uploaded audio file.

    The upload is written to a temporary file (keeping the original file
    extension when one is available), passed through ``convert_to_wav`` and
    then handed to the module-level ``transcriber``.

    Args:
        file: The uploaded audio file.
        options: Transcription options; ``None`` values are dropped before
            the options are forwarded to the transcriber.

    Returns:
        A dict of the form ``{"transcription": <text>}``.

    Temporary files created for the request are removed whether or not
    transcription succeeds.
    """
    temp_audio_file_path = None
    converted_file_path = None
    was_converted = False
    # An upload may arrive without a filename; fall back to "no extension"
    # instead of failing with a TypeError inside os.path.splitext.
    extension = os.path.splitext(file.filename or "")[1]
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as temp_audio_file:
            # Record the path before writing so the finally-block can clean up
            # even when reading the upload or writing it to disk fails.
            temp_audio_file_path = temp_audio_file.name
            temp_audio_file.write(await file.read())
        # The temp file is closed (and therefore flushed) before conversion.
        converted_file_path, was_converted = convert_to_wav(temp_audio_file_path)
        transcription = transcriber.transcribe(converted_file_path, options.dict(exclude_none=True))
        return {"transcription": transcription}
    finally:
        _remove_quietly(temp_audio_file_path)
        # Only remove the converted file when the converter reports that it
        # actually produced one.
        if was_converted:
            _remove_quietly(converted_file_path)
def create_ui():
    """Return the HTML source of the single-page upload UI.

    The page contains a file input and a button; its script POSTs the chosen
    file to ``/transcribe`` as multipart form data and shows either the
    returned transcription or an error message.
    """
    page = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Whisper Transcription</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body {
background-color: #f8f9fa;
}
.container {
max-width: 700px;
}
#transcriptionOutput {
white-space: pre-wrap;
word-wrap: break-word;
}
</style>
</head>
<body>
<div class="container mt-5">
<div class="card">
<div class="card-body">
<h1 class="card-title text-center mb-4">Upload Audio for Transcription</h1>
<div class="mb-3">
<input class="form-control" type="file" id="audioFile" accept="audio/*">
</div>
<div class="d-grid">
<button class="btn btn-primary" onclick="transcribeAudio()">
<span class="spinner-border spinner-border-sm d-none" role="status" aria-hidden="true" id="spinner"></span>
Transcribe
</button>
</div>
<h2 class="mt-4">Transcription:</h2>
<div class="p-3 bg-light rounded">
<pre id="transcriptionOutput"></pre>
</div>
</div>
</div>
</div>
<script>
async function transcribeAudio() {
const fileInput = document.getElementById('audioFile');
const file = fileInput.files[0];
if (!file) {
alert("Please select a file first.");
return;
}
const formData = new FormData();
formData.append('file', file);
const outputElement = document.getElementById('transcriptionOutput');
const spinner = document.getElementById('spinner');
const transcribeButton = document.querySelector('button');
outputElement.innerText = '';
spinner.classList.remove('d-none');
transcribeButton.disabled = true;
try {
const response = await fetch('/transcribe', {
method: 'POST',
body: formData
});
if (response.ok) {
const result = await response.json();
if (result.transcription) {
outputElement.innerText = result.transcription;
} else if (result.error) {
outputElement.innerText = 'Error: ' + result.error;
}
} else {
const errorText = await response.text();
outputElement.innerText = 'Error: ' + response.statusText + ' - ' + errorText;
}
} catch (error) {
outputElement.innerText = 'An error occurred: ' + error;
} finally {
spinner.classList.add('d-none');
transcribeButton.disabled = false;
}
}
</script>
</body>
</html>
'''
    return page
# Script entry point: optionally expose the browser UI, then start the server.
if __name__ == "__main__":
    import uvicorn

    # The `server` section of the configuration supplies `ui`, `host` and `port`.
    s_config = config['server']

    # The UI route is registered only when enabled in the configuration or
    # requested with the `--ui` command-line flag.
    if s_config['ui'] or "--ui" in sys.argv:

        @app.get("/", response_class=HTMLResponse)
        async def read_root():
            return create_ui()

    uvicorn.run(app, host=s_config['host'], port=s_config['port'])