Skip to content

Commit

Permalink
Memory check before inference to avoid VAE Decode using exceeded VRAM.
Browse files Browse the repository at this point in the history
Check if free memory is not less than expected before doing actual decoding,
and if it fails, try to free for required amount of memory,
and if it still fails, switch to tiled VAE decoding directly.

It seems PyTorch may continue occupying memory until the model is destroyed
after OOM occurs. This commit tries to avoid OOM from happening in the first
place for VAE Decode.

This is for VAE Decode ran with exceeded VRAM from #5737.
  • Loading branch information
wl2018 committed Nov 23, 2024
1 parent 839ed33 commit e85d80f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
if shift_model not in keep_loaded and shift_model.model not in keep_loaded:
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False

Expand Down
33 changes: 26 additions & 7 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,22 +348,41 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

def decode(self, samples_in):
predicted_oom = False
samples = None
out = None
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
logging.debug(f"Free memory: {free_memory} bytes, predicted memory useage of one batch: {memory_used} bytes")
if free_memory < memory_used:
logging.debug("Possible out of memory is detected, try to free memory.")
model_management.free_memory(memory_used, self.device, [self.patcher])
free_memory = model_management.get_free_memory(self.device)
logging.debug(f"Free memory: {free_memory} bytes")
if free_memory < memory_used:
logging.warning("Warning: Out of memory is predicted for regular VAE decoding, directly switch to tiled VAE decoding.")
predicted_oom = True
raise model_management.OOM_EXCEPTION
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)

for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
try:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
finally:
samples = None
out = None
pixel_samples = None
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
if not predicted_oom:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
pixel_samples = self.decode_tiled_1d(samples_in)
Expand Down

0 comments on commit e85d80f

Please sign in to comment.