diff --git a/patches/model_management.py b/patches/model_management.py index 43b95be..ec9edb5 100644 --- a/patches/model_management.py +++ b/patches/model_management.py @@ -1106,7 +1106,7 @@ def text_encoder_dtype(device=None): def intermediate_device(): - if args.gpu_only: + if args.gpu_only or UNIFIED_MEMORY: return get_torch_device() else: return torch.device("cpu")