Description
Where the code checks for CUDA, please also check for MPS (the Mac GPU backend).
Current code:

```python
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
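Example fix:

```python
import torch

# Prefer the Mac GPU (MPS), then an NVIDIA GPU (CUDA), and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
```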
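The same priority order could also be packaged as a small helper. The sketch below is not code from this repository: the names `pick_device_name` and `get_device` are hypothetical. It assumes that `torch.backends.mps` may be missing on PyTorch versions older than 1.12, where the MPS backend was introduced, and uses `getattr` to guard against that case. The selection logic is kept in a pure function so that it can be unit-tested without a GPU or even without PyTorch installed.

```python
def pick_device_name(mps_available: bool, cuda_available: bool) -> str:
    """Return the preferred device string: "mps", then "cuda", then "cpu"."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


def get_device():
    """Hypothetical helper: query PyTorch and return a torch.device.

    torch is imported lazily so pick_device_name can be used (and tested)
    without PyTorch installed.
    """
    import torch

    # torch.backends.mps only exists in PyTorch >= 1.12; getattr avoids an
    # AttributeError on older versions and treats MPS as unavailable there.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device_name(mps_ok, torch.cuda.is_available()))
```

A call site would then use something like `model.to(get_device())` in place of the single-line CUDA check.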