fix(asr): load VAD model on correct CUDA device (#835)
fix(asr): load VAD model on correct CUDA device
Previously, the VAD sub‐model was always initialized on the default CUDA device (cuda:0), even when a higher device_index was specified. This change sets `device_vad` to `cuda:{device_index}` whenever `device == 'cuda'`, while falling back to the original `device` string for non‐CUDA cases. This ensures the VAD model is loaded on the intended GPU.
Co-authored-by: dujing <dujing@xmov.ai>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
This commit is contained in:
@@ -401,7 +401,11 @@ def load_model(
|
|||||||
if vad_method == "silero":
|
if vad_method == "silero":
|
||||||
vad_model = Silero(**default_vad_options)
|
vad_model = Silero(**default_vad_options)
|
||||||
elif vad_method == "pyannote":
|
elif vad_method == "pyannote":
|
||||||
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
if device == 'cuda':
|
||||||
|
device_vad = f'cuda:{device_index}'
|
||||||
|
else:
|
||||||
|
device_vad = device
|
||||||
|
vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user