From a3b9b3c1c3646b589362d29eec000dbdce07483e Mon Sep 17 00:00:00 2001 From: William <36047903+wl2018@users.noreply.github.com> Date: Sun, 24 Nov 2024 18:47:01 +0800 Subject: [PATCH] Memory check before inference to avoid VAE Decode using exceeded VRAM. Check if free memory is not less than expected before doing actual decoding, and if it fails, switch to tiled VAE decoding directly. It seems PyTorch may continue occupying memory until the model is destroyed after OOM occurs. This commit tries to avoid OOM from happening in the first place for VAE Decode. This is for VAE Decode ran with exceeded VRAM from #5737. --- comfy/sd.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index e2af7078121..144cf68b8f2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -348,11 +348,19 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048): return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) def decode(self, samples_in): + predicted_oom = False + samples = None + out = None pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) + logging.debug(f"Free memory: {free_memory} bytes, predicted memory useage of one batch: {memory_used} bytes") + if free_memory < memory_used: + logging.warning("Warning: Out of memory is predicted for regular VAE decoding, directly switch to tiled VAE decoding.") + predicted_oom = True + raise model_management.OOM_EXCEPTION batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -363,7 +371,11 @@ def decode(self, samples_in): pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out except model_management.OOM_EXCEPTION as e: - logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + samples = None + out = None + pixel_samples = None + if not predicted_oom: + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") dims = samples_in.ndim - 2 if dims == 1: pixel_samples = self.decode_tiled_1d(samples_in)