Some improvements to the lowvram unloading.
comfyanonymous committed Nov 23, 2024
1 parent 6e8cdcd commit 839ed33
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions comfy/model_patcher.py
@@ -367,10 +367,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

-    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
-        mem_counter = 0
-        patch_counter = 0
-        lowvram_counter = 0
+    def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
             params = []
@@ -383,6 +380,13 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 loading.append((comfy.model_management.module_size(m), n, m, params))
+        return loading
+
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        mem_counter = 0
+        patch_counter = 0
+        lowvram_counter = 0
+        loading = self._load_list()

         load_completely = []
         loading.sort(reverse=True)
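
Editor's note: the two hunks above extract the module scan from load() into a shared _load_list() helper that returns (module_size, name, module, params) records. Below is a minimal standalone sketch of that pattern; module_size, load_list, and the example model are hypothetical names, not ComfyUI's API, and the real helper also keeps parameter-less modules that support weight casting. The point of sharing one list: load() walks it largest-first, while partially_unload() (next hunk) walks it smallest-first.

import torch.nn as nn

def module_size(m: nn.Module) -> int:
    # Approximate byte size of the parameters that live directly on m.
    return sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))

def load_list(model: nn.Module):
    # Rough analogue of _load_list(): one (size, name, module, param_names)
    # record per module that owns parameters of its own.
    records = []
    for name, m in model.named_modules():
        params = [p_name for p_name, _ in m.named_parameters(recurse=False)]
        if len(params) > 0:
            records.append((module_size(m), name, m, params))
    return records

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))

records = load_list(model)
records.sort(reverse=True)  # load(): place the biggest modules first
print([(size, name) for size, name, _, _ in records])

records.sort()              # partially_unload(): free the smallest modules first
print([(size, name) for size, name, _, _ in records])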
@@ -514,47 +518,50 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
     def partially_unload(self, device_to, memory_to_free=0):
         memory_freed = 0
         patch_counter = 0
-        unload_list = []
-
-        for n, m in self.model.named_modules():
-            shift_lowvram = False
-            if hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
-                unload_list.append((module_mem, n, m))
-
+        unload_list = self._load_list()
         unload_list.sort()
         for unload in unload_list:
             if memory_to_free < memory_freed:
                 break
             module_mem = unload[0]
             n = unload[1]
             m = unload[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
+            params = unload[3]

+            lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
-                for key in [weight_key, bias_key]:
+                move_weight = True
+                for param in params:
+                    key = "{}.{}".format(n, param)
                     bk = self.backup.get(key, None)
                     if bk is not None:
+                        if not lowvram_possible:
+                            move_weight = False
+                            break
+
                         if bk.inplace_update:
                             comfy.utils.copy_to_param(self.model, key, bk.weight)
                         else:
                             comfy.utils.set_attr_param(self.model, key, bk.weight)
                         self.backup.pop(key)

-                m.to(device_to)
-                if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self.patches)
-                    patch_counter += 1
-                if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self.patches)
-                    patch_counter += 1
-
-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
-                m.comfy_patched_weights = False
-                memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+                if move_weight:
+                    m.to(device_to)
+                    if lowvram_possible:
+                        if weight_key in self.patches:
+                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            patch_counter += 1
+                        if bias_key in self.patches:
+                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            patch_counter += 1
+
+                        m.prev_comfy_cast_weights = m.comfy_cast_weights
+                        m.comfy_cast_weights = True
+                    m.comfy_patched_weights = False
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))

         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
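
Editor's note: this hunk reuses _load_list() and makes the restore step per-parameter. A module is only offloaded when move_weight stays True, i.e. when every backed-up parameter could be restored, which in turn requires the module to support on-the-fly weight casting. The sketch below illustrates that guard under stated assumptions: partially_unload_module is a hypothetical standalone helper, and the restore is simplified to an in-place copy, whereas the real code distinguishes in-place vs. attribute restores and re-installs LowVramPatch functions.

import torch
import torch.nn as nn

def partially_unload_module(m, name, params, backup, supports_cast, device_to):
    # Restore any backed-up originals for this module's parameters. If the
    # module cannot cast its weights on the fly, leave it loaded untouched.
    move_weight = True
    for param in params:
        key = "{}.{}".format(name, param)
        bk = backup.get(key, None)
        if bk is not None:
            if not supports_cast:
                move_weight = False
                break
            getattr(m, param).data.copy_(bk)  # put the original weight back
            backup.pop(key)
    if move_weight:
        m.to(device_to)  # offload the whole module to free its memory
    return move_weight

lin = nn.Linear(4, 4)
backup = {"lin.weight": lin.weight.detach().clone()}
print(partially_unload_module(lin, "lin", ["weight", "bias"], backup,
                              supports_cast=True, device_to=torch.device("cpu")))
# -> True, and backup no longer holds "lin.weight"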