From bc6be6c11e48114889a368e8c3597df8aac64ae3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 16:40:04 -0500 Subject: [PATCH] Some fixes to the lowvram system. --- comfy/model_patcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 22de7eea9c2..fc2329543a9 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -374,9 +374,14 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False loading = [] for n, m in self.model.named_modules(): params = [] + skip = False for name, param in m.named_parameters(recurse=False): params.append(name) - if hasattr(m, "comfy_cast_weights") or len(params) > 0: + for name, param in m.named_parameters(recurse=True): + if name not in params: + skip = True # skip random weights in non leaf modules + break + if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): loading.append((comfy.model_management.module_size(m), n, m, params)) load_completely = [] @@ -420,8 +425,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False if m.comfy_cast_weights: wipe_lowvram_weight(m) - mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) + if full_load or mem_counter + module_mem < lowvram_model_memory: + mem_counter += module_mem + load_completely.append((module_mem, n, m, params)) load_completely.sort(reverse=True) for x in load_completely: