fix: resync Grace-Blackwell patches with current ComfyUI master
The mounted patches/model_management.py and patches/utils.py were authored against an older ComfyUI, but COMFYUI_REF=master clones the latest. Upstream added the DynamicVRAM/AIMDO system, and main.py now calls model_management.get_all_torch_devices() (13 functions were missing in total), causing comfyui to crash-loop on startup with AttributeError. Regenerated both patches from the current master files and re-applied the documented Sparky edits on top so they stay API-compatible: - model_management.py: unified-memory detection, NORMAL_VRAM retention, 95% weight ratio, intermediate_device()->cuda, soft_empty_cache skip - utils.py: copy=False tensor load on unified memory comfyui now starts cleanly with DynamicVRAM enabled and the Sparky unified-memory path active. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+271
-115
@@ -15,6 +15,7 @@
|
|||||||
You should have received a copy of the GNU General Public License
|
You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
@@ -27,12 +28,18 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
|
import comfy_aimdo.host_buffer
|
||||||
import comfy_aimdo.vram_buffer
|
import comfy_aimdo.vram_buffer
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@@ -203,6 +210,107 @@ def get_torch_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
|
def get_all_torch_devices(exclude_current=False):
|
||||||
|
global cpu_state
|
||||||
|
devices = []
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
|
||||||
|
# without the AMD arm, single-GPU ROCm users get an empty list
|
||||||
|
# which silently turns unload_all_models() into a no-op.
|
||||||
|
if is_nvidia() or is_amd():
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
devices.append(torch.device("cuda", i))
|
||||||
|
elif is_intel_xpu():
|
||||||
|
for i in range(torch.xpu.device_count()):
|
||||||
|
devices.append(torch.device("xpu", i))
|
||||||
|
elif is_ascend_npu():
|
||||||
|
for i in range(torch.npu.device_count()):
|
||||||
|
devices.append(torch.device("npu", i))
|
||||||
|
elif is_mlu():
|
||||||
|
for i in range(torch.mlu.device_count()):
|
||||||
|
devices.append(torch.device("mlu", i))
|
||||||
|
else:
|
||||||
|
# Fallback for unhandled GPU backends (e.g. DirectML): at least
|
||||||
|
# report the current device so callers like unload_all_models()
|
||||||
|
# do not silently no-op.
|
||||||
|
devices.append(get_torch_device())
|
||||||
|
else:
|
||||||
|
devices.append(get_torch_device())
|
||||||
|
if exclude_current:
|
||||||
|
current = get_torch_device()
|
||||||
|
if current in devices:
|
||||||
|
devices.remove(current)
|
||||||
|
return devices
|
||||||
|
|
||||||
|
def get_gpu_device_options():
|
||||||
|
"""Return list of device option strings for node widgets.
|
||||||
|
|
||||||
|
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||||
|
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||||
|
"""
|
||||||
|
options = ["default", "cpu"]
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if len(devices) > 1:
|
||||||
|
for i in range(len(devices)):
|
||||||
|
options.append(f"gpu:{i}")
|
||||||
|
return options
|
||||||
|
|
||||||
|
def get_gpu_device_options_no_cpu():
|
||||||
|
"""Variant of get_gpu_device_options that omits "cpu".
|
||||||
|
|
||||||
|
Intended for components like the VAE selector where running on CPU
|
||||||
|
is impractical and should not be offered as a choice.
|
||||||
|
"""
|
||||||
|
return [o for o in get_gpu_device_options() if o != "cpu"]
|
||||||
|
|
||||||
|
def resolve_gpu_device_option(option: str):
|
||||||
|
"""Resolve a device option string to a torch.device.
|
||||||
|
|
||||||
|
Returns None for "default" (let the caller use its normal default).
|
||||||
|
Returns torch.device("cpu") for "cpu".
|
||||||
|
For "gpu:N", returns the Nth torch device. Returns None if the
|
||||||
|
index is out of range, the option string is malformed, or
|
||||||
|
unrecognized (callers are expected to log their own context-rich
|
||||||
|
message before falling back to the default device).
|
||||||
|
"""
|
||||||
|
if option is None or option == "default":
|
||||||
|
return None
|
||||||
|
if option == "cpu":
|
||||||
|
return torch.device("cpu")
|
||||||
|
if option.startswith("gpu:"):
|
||||||
|
try:
|
||||||
|
idx = int(option[4:])
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
devices = get_all_torch_devices()
|
||||||
|
if 0 <= idx < len(devices):
|
||||||
|
return devices[idx]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cuda_device_context(device):
|
||||||
|
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||||
|
|
||||||
|
Used when running operations on a non-default CUDA device so that custom
|
||||||
|
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||||
|
device index. The previous device is restored on exit.
|
||||||
|
|
||||||
|
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||||
|
the current device.
|
||||||
|
"""
|
||||||
|
prev = None
|
||||||
|
if device.type == "cuda" and device.index is not None:
|
||||||
|
prev = torch.cuda.current_device()
|
||||||
|
if prev != device.index:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
else:
|
||||||
|
prev = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if prev is not None:
|
||||||
|
torch.cuda.set_device(prev)
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
@@ -460,68 +568,43 @@ if cpu_state == CPUState.MPS:
|
|||||||
vram_state = VRAMState.SHARED
|
vram_state = VRAMState.SHARED
|
||||||
|
|
||||||
# --- Grace-Blackwell Unified Memory Detection (Sparky) ---
|
# --- Grace-Blackwell Unified Memory Detection (Sparky) ---
|
||||||
# On unified memory systems (Grace-Blackwell, Apple Silicon), VRAM and RAM
|
# On unified memory systems (Grace-Blackwell), VRAM and RAM are the same
|
||||||
# are the same physical memory. ComfyUI's default behavior treats them as
|
# physical memory. Detect this so we can tune weight ratios and skip
|
||||||
# separate pools, causing pointless CPU offloading and cache thrashing.
|
# empty_cache() to avoid page faults.
|
||||||
# Detect this and optimize: set HIGH_VRAM (no offloading), higher weight
|
|
||||||
# ratio, and skip empty_cache to avoid page faults.
|
|
||||||
def _is_unified_memory():
|
def _is_unified_memory():
|
||||||
"""Detect if GPU and CPU share the same physical memory pool.
|
"""Detect if GPU and CPU share the same physical memory pool.
|
||||||
|
|
||||||
Grace-Blackwell (GB10/GB200) reports identical VRAM and RAM totals
|
Apple Silicon (MPS) is excluded — it has its own VRAMState.SHARED path.
|
||||||
because they share the same HBM/memory controller.
|
|
||||||
|
|
||||||
Note: Apple Silicon (MPS) is NOT included here — it already has
|
|
||||||
its own VRAMState.SHARED path with different semantics. Including
|
|
||||||
it would clobber SHARED with HIGH_VRAM, breaking MPS behavior.
|
|
||||||
"""
|
"""
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
return False # MPS handles unified memory via VRAMState.SHARED
|
return False
|
||||||
|
|
||||||
if cpu_state != CPUState.GPU:
|
if cpu_state != CPUState.GPU:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Grace-Blackwell detection: VRAM total ≈ RAM total (within 5%)
|
|
||||||
# Discrete GPUs always have VRAM < RAM (e.g., 24GB VRAM vs 64GB RAM)
|
|
||||||
try:
|
try:
|
||||||
vram_bytes = torch.cuda.get_device_properties(0).total_memory
|
vram_bytes = torch.cuda.get_device_properties(0).total_memory
|
||||||
ram_bytes = psutil.virtual_memory().total
|
ram_bytes = psutil.virtual_memory().total
|
||||||
ratio = vram_bytes / ram_bytes if ram_bytes > 0 else 0
|
ratio = vram_bytes / ram_bytes if ram_bytes > 0 else 0
|
||||||
|
|
||||||
# Also check device name for explicit GB detection
|
|
||||||
device_name = torch.cuda.get_device_properties(0).name.lower()
|
device_name = torch.cuda.get_device_properties(0).name.lower()
|
||||||
is_gb = 'gb10' in device_name or 'gb200' in device_name or 'grace' in device_name
|
is_gb = 'gb10' in device_name or 'gb200' in device_name or 'grace' in device_name
|
||||||
|
|
||||||
# If VRAM ≈ RAM (ratio > 0.95), it's unified memory
|
|
||||||
# Or if device name explicitly says Grace-Blackwell
|
|
||||||
if ratio > 0.95 or is_gb:
|
if ratio > 0.95 or is_gb:
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
UNIFIED_MEMORY = _is_unified_memory()
|
UNIFIED_MEMORY = _is_unified_memory()
|
||||||
|
|
||||||
if UNIFIED_MEMORY:
|
if UNIFIED_MEMORY:
|
||||||
# On unified memory, NORMAL_VRAM allows ComfyUI to offload unused model
|
# Keep NORMAL_VRAM so ComfyUI can still offload unused layers; since CPU
|
||||||
# layers to CPU when memory is tight. Since CPU and GPU share the same
|
# and GPU share the same physical RAM the offload stays in the same pool.
|
||||||
# physical RAM on GB10, offloaded layers stay in the same physical pool
|
|
||||||
# but through a different allocator. Per-layer partial loading (LowVramPatch)
|
|
||||||
# means only individual layers are copied on-demand, not whole models,
|
|
||||||
# keeping peak memory manageable.
|
|
||||||
# HIGH_VRAM is available via --highvram if everything fits in VRAM.
|
|
||||||
if not (args.highvram or args.gpu_only):
|
if not (args.highvram or args.gpu_only):
|
||||||
logging.info("[Sparky] Grace-Blackwell unified memory detected — "
|
logging.info("[Sparky] Grace-Blackwell unified memory detected — "
|
||||||
"keeping NORMAL_VRAM mode (allows layer offloading)")
|
"keeping NORMAL_VRAM mode (allows layer offloading)")
|
||||||
else:
|
else:
|
||||||
logging.info("[Sparky] Grace-Blackwell unified memory detected — "
|
logging.info("[Sparky] Grace-Blackwell unified memory detected — "
|
||||||
"HIGH_VRAM requested via --highvram")
|
"HIGH_VRAM requested via --highvram")
|
||||||
# Don't override vram_state — let ComfyUI's default NORMAL_VRAM handle
|
|
||||||
# offloading. User can force HIGH_VRAM with --highvram if models fit.
|
|
||||||
logging.info(f"[Sparky] Set vram state to: {vram_state.name} (unified memory)")
|
logging.info(f"[Sparky] Set vram state to: {vram_state.name} (unified memory)")
|
||||||
else:
|
else:
|
||||||
logging.info(f"Set vram state to: {vram_state.name}")
|
logging.info(f"Set vram state to: {vram_state.name}")
|
||||||
@@ -556,9 +639,21 @@ try:
|
|||||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
try:
|
||||||
|
for device in get_all_torch_devices(exclude_current=True):
|
||||||
|
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
current_loaded_models: list[LoadedModel] = []
|
||||||
|
|
||||||
current_loaded_models = []
|
DIRTY_MMAPS = set()
|
||||||
|
|
||||||
|
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
|
||||||
|
|
||||||
|
#Freeing registerables on pressure does imply a GPU sync, so go big on
|
||||||
|
#the hysteresis so each expensive sync gives us back a good chunk.
|
||||||
|
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
|
||||||
|
|
||||||
def module_size(module):
|
def module_size(module):
|
||||||
module_mem = 0
|
module_mem = 0
|
||||||
@@ -568,30 +663,61 @@ def module_size(module):
|
|||||||
module_mem += t.nbytes
|
module_mem += t.nbytes
|
||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
def module_mmap_residency(module, free=False):
|
def mark_mmap_dirty(storage):
|
||||||
mmap_touched_mem = 0
|
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
|
||||||
module_mem = 0
|
if mmap_refs is not None:
|
||||||
bounced_mmaps = set()
|
DIRTY_MMAPS.add(mmap_refs[0])
|
||||||
sd = module.state_dict()
|
|
||||||
for k in sd:
|
def free_pins(size, evict_active=False):
|
||||||
t = sd[k]
|
freed_total = 0
|
||||||
module_mem += t.nbytes
|
for loaded_model in reversed(current_loaded_models):
|
||||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
if size <= 0:
|
||||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
return freed_total
|
||||||
continue
|
model = loaded_model.model
|
||||||
mmap_touched_mem += t.nbytes
|
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||||
if not free:
|
freed = model.partially_unload_ram(size)
|
||||||
continue
|
freed_total += freed
|
||||||
storage._comfy_tensor_mmap_touched = False
|
size -= freed
|
||||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
return freed_total
|
||||||
if mmap_obj in bounced_mmaps:
|
|
||||||
continue
|
def ensure_pin_budget(size, evict_active=False):
|
||||||
mmap_obj.bounce()
|
if args.fast_disk:
|
||||||
bounced_mmaps.add(mmap_obj)
|
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||||
return mmap_touched_mem, module_mem
|
else:
|
||||||
|
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
|
||||||
|
if shortfall <= 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||||
|
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||||
|
|
||||||
|
def free_registrations(shortfall, evict_active=True):
|
||||||
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
|
return False
|
||||||
|
if shortfall <= 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||||
|
for loaded_model in reversed(current_loaded_models):
|
||||||
|
model = loaded_model.model
|
||||||
|
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
||||||
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||||
|
if shortfall <= 0:
|
||||||
|
return True
|
||||||
|
if evict_active:
|
||||||
|
for loaded_model in current_loaded_models:
|
||||||
|
model = loaded_model.model
|
||||||
|
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
||||||
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||||
|
if shortfall <= 0:
|
||||||
|
return True
|
||||||
|
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||||
|
|
||||||
|
def ensure_pin_registerable(size, evict_active=True):
|
||||||
|
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model: ModelPatcher):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
@@ -599,7 +725,7 @@ class LoadedModel:
|
|||||||
self.model_finalizer = None
|
self.model_finalizer = None
|
||||||
self._patcher_finalizer = None
|
self._patcher_finalizer = None
|
||||||
|
|
||||||
def _set_model(self, model):
|
def _set_model(self, model: ModelPatcher):
|
||||||
self._model = weakref.ref(model)
|
self._model = weakref.ref(model)
|
||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
@@ -610,6 +736,7 @@ class LoadedModel:
|
|||||||
model = self._parent_model()
|
model = self._parent_model()
|
||||||
if model is not None:
|
if model is not None:
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
|
self.device = model.load_device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
@@ -618,9 +745,6 @@ class LoadedModel:
|
|||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
|
|
||||||
def model_mmap_residency(self, free=False):
|
|
||||||
return self.model.model_mmap_residency(free=free)
|
|
||||||
|
|
||||||
def model_loaded_memory(self):
|
def model_loaded_memory(self):
|
||||||
return self.model.loaded_size()
|
return self.model.loaded_size()
|
||||||
|
|
||||||
@@ -700,15 +824,9 @@ WINDOWS = any(platform.win32_ver())
|
|||||||
|
|
||||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
import comfy.windows
|
|
||||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||||
def get_free_ram():
|
|
||||||
return comfy.windows.get_free_ram()
|
|
||||||
else:
|
|
||||||
def get_free_ram():
|
|
||||||
return psutil.virtual_memory().available
|
|
||||||
|
|
||||||
if args.reserve_vram is not None:
|
if args.reserve_vram is not None:
|
||||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
@@ -722,7 +840,6 @@ def minimum_inference_memory():
|
|||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||||
cleanup_models_gc()
|
cleanup_models_gc()
|
||||||
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
|
|
||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
unloaded_models = []
|
unloaded_models = []
|
||||||
@@ -738,10 +855,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
for x in can_unload_sorted:
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
pins_to_free = 1e32
|
|
||||||
if not DISABLE_SMART_MEMORY or device is None:
|
if not DISABLE_SMART_MEMORY or device is None:
|
||||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||||
pins_to_free = pins_required - get_free_ram()
|
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#as that works on-demand.
|
#as that works on-demand.
|
||||||
@@ -750,22 +865,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
if pins_to_free > 0:
|
|
||||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
|
||||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
|
||||||
|
|
||||||
for x in can_unload_sorted:
|
|
||||||
i = x[-1]
|
|
||||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
|
||||||
if ram_to_free <= 0 and i not in unloaded_model:
|
|
||||||
continue
|
|
||||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
|
||||||
if resident_memory > 0:
|
|
||||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
|
|
||||||
|
if not for_dynamic and pins_required > 0:
|
||||||
|
ensure_pin_budget(pins_required)
|
||||||
|
ensure_pin_registerable(pins_required)
|
||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
elif device is not None:
|
elif device is not None:
|
||||||
@@ -827,29 +934,20 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
model_to_unload.model.detach(unpatch_all=False)
|
model_to_unload.model.detach(unpatch_all=False)
|
||||||
model_to_unload.model_finalizer.detach()
|
model_to_unload.model_finalizer.detach()
|
||||||
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
total_pins_required = {}
|
total_pins_required = {}
|
||||||
total_ram_required = {}
|
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
device = loaded_model.device
|
device = loaded_model.device
|
||||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
if not loaded_model.model.is_dynamic():
|
||||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
||||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
|
||||||
#make this JIT to keep as much pinned as possible.
|
|
||||||
pins_required = model_memory - pinned_memory
|
|
||||||
ram_required = model_memory - resident_memory
|
|
||||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
|
||||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||||
device,
|
device,
|
||||||
for_dynamic=free_for_dynamic,
|
for_dynamic=free_for_dynamic,
|
||||||
pins_required=total_pins_required[device],
|
pins_required=total_pins_required.get(device, 0))
|
||||||
ram_required=total_ram_required[device])
|
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
@@ -979,9 +1077,7 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
|
|
||||||
def maximum_vram_for_weights(device=None):
|
def maximum_vram_for_weights(device=None):
|
||||||
if UNIFIED_MEMORY:
|
if UNIFIED_MEMORY:
|
||||||
# On unified memory, we don't need to reserve as much for "VRAM-only"
|
# GPU and CPU share one pool, so reserve less; 95% minus a 2GB buffer.
|
||||||
# operations since GPU and CPU share the same pool. Use 95% instead of 88%.
|
|
||||||
# Still reserve 2GB for inference buffers and OS overhead.
|
|
||||||
return (get_total_memory(device) * 0.95 - 2 * 1024 * 1024 * 1024)
|
return (get_total_memory(device) * 0.95 - 2 * 1024 * 1024 * 1024)
|
||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
@@ -1290,8 +1386,8 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
|||||||
if cast_buffer is None:
|
if cast_buffer is None:
|
||||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||||
|
|
||||||
return cast_buffer
|
return cast_buffer
|
||||||
|
|
||||||
def reset_cast_buffers():
|
def reset_cast_buffers():
|
||||||
global LARGEST_CASTED_WEIGHT
|
global LARGEST_CASTED_WEIGHT
|
||||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||||
@@ -1303,6 +1399,26 @@ def reset_cast_buffers():
|
|||||||
offload_stream.synchronize()
|
offload_stream.synchronize()
|
||||||
synchronize()
|
synchronize()
|
||||||
|
|
||||||
|
for mmap_obj in DIRTY_MMAPS:
|
||||||
|
mmap_obj.bounce()
|
||||||
|
DIRTY_MMAPS.clear()
|
||||||
|
|
||||||
|
for loaded_model in current_loaded_models:
|
||||||
|
model = loaded_model.model
|
||||||
|
if model is not None and model.is_dynamic():
|
||||||
|
pin_state = model.model.dynamic_pins[model.load_device]
|
||||||
|
|
||||||
|
if pin_state["active"]:
|
||||||
|
*_, buckets = pin_state["weights"]
|
||||||
|
for size, bucket in list(buckets.items()):
|
||||||
|
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
|
||||||
|
if not bucket:
|
||||||
|
del buckets[size]
|
||||||
|
|
||||||
|
pin_state["active"] = False
|
||||||
|
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||||
|
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
|
||||||
|
|
||||||
STREAM_CAST_BUFFERS.clear()
|
STREAM_CAST_BUFFERS.clear()
|
||||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
@@ -1350,25 +1466,29 @@ def sync_stream(device, stream):
|
|||||||
current_stream(device).wait_stream(stream)
|
current_stream(device).wait_stream(stream)
|
||||||
|
|
||||||
|
|
||||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||||
wf_context = nullcontext()
|
wf_context = nullcontext()
|
||||||
if stream is not None:
|
if stream is not None:
|
||||||
wf_context = stream
|
wf_context = stream
|
||||||
if hasattr(wf_context, "as_context"):
|
if hasattr(wf_context, "as_context"):
|
||||||
wf_context = wf_context.as_context(stream)
|
wf_context = wf_context.as_context(stream)
|
||||||
|
|
||||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
|
||||||
|
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||||
with wf_context:
|
with wf_context:
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
dest_view = dest_views.pop(0)
|
dest_view = dest_views.pop(0)
|
||||||
|
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
continue
|
continue
|
||||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
|
||||||
continue
|
continue
|
||||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
mark_mmap_dirty(storage)
|
||||||
storage._comfy_tensor_mmap_touched = True
|
if dest_view is not None:
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
if dest2_view is not None:
|
||||||
|
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
@@ -1409,14 +1529,18 @@ TOTAL_PINNED_MEMORY = 0
|
|||||||
MAX_PINNED_MEMORY = -1
|
MAX_PINNED_MEMORY = -1
|
||||||
if not args.disable_pinned_memory:
|
if not args.disable_pinned_memory:
|
||||||
if is_nvidia() or is_amd():
|
if is_nvidia() or is_amd():
|
||||||
|
ram = get_total_memory(torch.device("cpu"))
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
|
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
|
||||||
else:
|
else:
|
||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
|
MAX_PINNED_MEMORY = ram * 0.90
|
||||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
|
def pinned_hostbuf_size(size):
|
||||||
|
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||||
|
|
||||||
def discard_cuda_async_error():
|
def discard_cuda_async_error():
|
||||||
try:
|
try:
|
||||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
@@ -1448,8 +1572,8 @@ def pin_memory(tensor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
size = tensor.nbytes
|
size = tensor.nbytes
|
||||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||||
return False
|
ensure_pin_registerable(size)
|
||||||
|
|
||||||
ptr = tensor.data_ptr()
|
ptr = tensor.data_ptr()
|
||||||
if ptr == 0:
|
if ptr == 0:
|
||||||
@@ -1486,7 +1610,8 @@ def unpin_memory(tensor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
size = PINNED_MEMORY.pop(ptr)
|
||||||
|
TOTAL_PINNED_MEMORY -= size
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logging.warning("Unpin error.")
|
logging.warning("Unpin error.")
|
||||||
@@ -1636,6 +1761,13 @@ def is_device_xpu(device):
|
|||||||
def is_device_cuda(device):
|
def is_device_cuda(device):
|
||||||
return is_device_type(device, 'cuda')
|
return is_device_type(device, 'cuda')
|
||||||
|
|
||||||
|
def set_torch_device(device):
|
||||||
|
"""Set the current device for the given torch device. Supports CUDA and XPU."""
|
||||||
|
if is_device_cuda(device):
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
elif is_device_xpu(device):
|
||||||
|
torch.xpu.set_device(device)
|
||||||
|
|
||||||
def is_directml_enabled():
|
def is_directml_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
@@ -1855,18 +1987,15 @@ def synchronize():
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
return
|
return
|
||||||
# MPS must empty its cache regardless of unified memory detection
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
return
|
return
|
||||||
# On unified memory, empty_cache() returns cached allocations to the OS,
|
# [Sparky] On unified memory, empty_cache() returns cached allocations to
|
||||||
# which can cause page faults when PyTorch re-allocates them. Skip it
|
# the OS, causing page faults on re-allocation. Skip unless forced.
|
||||||
# unless forced — keeping the PyTorch memory pool warm is faster.
|
|
||||||
if UNIFIED_MEMORY and not force:
|
if UNIFIED_MEMORY and not force:
|
||||||
# Only synchronize, don't release cached memory back to OS
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return
|
return
|
||||||
@@ -1883,7 +2012,34 @@ def soft_empty_cache(force=False):
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device)
|
||||||
|
|
||||||
|
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||||
|
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||||
|
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||||
|
additional_models = []
|
||||||
|
if unload_additional_models:
|
||||||
|
additional_models = model.get_nested_additional_models()
|
||||||
|
keep_loaded = []
|
||||||
|
for loaded_model in initial_keep_loaded:
|
||||||
|
if loaded_model.model is not None:
|
||||||
|
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
continue
|
||||||
|
# check additional models if they are a match
|
||||||
|
skip = False
|
||||||
|
for add_model in additional_models:
|
||||||
|
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||||
|
skip = True
|
||||||
|
break
|
||||||
|
if skip:
|
||||||
|
continue
|
||||||
|
keep_loaded.append(loaded_model)
|
||||||
|
if not all_devices:
|
||||||
|
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||||
|
else:
|
||||||
|
for device in get_all_torch_devices():
|
||||||
|
free_memory(1e30, device, keep_loaded)
|
||||||
|
|
||||||
def debug_memory_summary():
|
def debug_memory_summary():
|
||||||
if is_amd() or is_nvidia():
|
if is_amd() or is_nvidia():
|
||||||
|
|||||||
+12
-5
@@ -85,8 +85,9 @@ _TYPES = {
|
|||||||
def load_safetensors(ckpt):
|
def load_safetensors(ckpt):
|
||||||
import comfy_aimdo.model_mmap
|
import comfy_aimdo.model_mmap
|
||||||
|
|
||||||
f = open(ckpt, "rb", buffering=0)
|
file_lock = threading.Lock()
|
||||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
|
f = model_mmap.get_file_handle()
|
||||||
file_size = os.path.getsize(ckpt)
|
file_size = os.path.getsize(ckpt)
|
||||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
|
|
||||||
@@ -111,9 +112,8 @@ def load_safetensors(ckpt):
|
|||||||
storage = tensor.untyped_storage()
|
storage = tensor.untyped_storage()
|
||||||
setattr(storage,
|
setattr(storage,
|
||||||
"_comfy_tensor_file_slice",
|
"_comfy_tensor_file_slice",
|
||||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
|
||||||
sd[name] = tensor
|
sd[name] = tensor
|
||||||
|
|
||||||
return sd, header.get("__metadata__", {}),
|
return sd, header.get("__metadata__", {}),
|
||||||
@@ -1020,10 +1020,11 @@ def bislerp(samples, width, height):
|
|||||||
|
|
||||||
def lanczos(samples, width, height):
|
def lanczos(samples, width, height):
|
||||||
#the below API is strict and expects grayscale to be squeezed
|
#the below API is strict and expects grayscale to be squeezed
|
||||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
if samples.ndim == 4:
|
||||||
|
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
|
||||||
result = torch.stack(images)
|
result = torch.stack(images)
|
||||||
return result.to(samples.device, samples.dtype)
|
return result.to(samples.device, samples.dtype)
|
||||||
|
|
||||||
@@ -1452,3 +1453,9 @@ def deepcopy_list_dict(obj, memo=None):
|
|||||||
memo[obj_id] = res
|
memo[obj_id] = res
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def bit_reverse_range(index, bits):
|
||||||
|
result = 0
|
||||||
|
for _ in range(bits):
|
||||||
|
result = (result << 1) | (index & 1)
|
||||||
|
index >>= 1
|
||||||
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user