fix(asr): load VAD model on correct CUDA device (#835)

fix(asr): load VAD model on correct CUDA device

Previously, the VAD sub‐model was always initialized on the default CUDA device (cuda:0), even when a higher device_index was specified. This change sets `device_vad` to `cuda:{device_index}` whenever `device == 'cuda'`, while falling back to the original `device` string for non‐CUDA cases. This ensures the VAD model is loaded on the intended GPU.


Co-authored-by: dujing <dujing@xmov.ai>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
This commit is contained in:
Jean Du
2025-07-02 14:07:59 +08:00
committed by GitHub
parent f4261f34e9
commit 2d9ce44329

View File

@@ -401,7 +401,11 @@ def load_model(
if vad_method == "silero": if vad_method == "silero":
vad_model = Silero(**default_vad_options) vad_model = Silero(**default_vad_options)
elif vad_method == "pyannote": elif vad_method == "pyannote":
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) if device == 'cuda':
device_vad = f'cuda:{device_index}'
else:
device_vad = device
vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options)
else: else:
raise ValueError(f"Invalid vad_method: {vad_method}") raise ValueError(f"Invalid vad_method: {vad_method}")