fix(asr): load VAD model on correct CUDA device (#835)
fix(asr): load VAD model on correct CUDA device
Previously, the VAD sub‐model was always initialized on the default CUDA device (cuda:0), even when a higher device_index was specified. This change sets `device_vad` to `cuda:{device_index}` whenever `device == 'cuda'`, while falling back to the original `device` string for non‐CUDA cases. This ensures the VAD model is loaded on the intended GPU.
Co-authored-by: dujing <dujing@xmov.ai>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
This commit is contained in:
@@ -401,7 +401,11 @@ def load_model(
|
||||
if vad_method == "silero":
|
||||
vad_model = Silero(**default_vad_options)
|
||||
elif vad_method == "pyannote":
|
||||
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
if device == 'cuda':
|
||||
device_vad = f'cuda:{device_index}'
|
||||
else:
|
||||
device_vad = device
|
||||
vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options)
|
||||
else:
|
||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user