diff --git a/comfy/sd.py b/comfy/sd.py
index e2af7078121..144cf68b8f2 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -348,11 +348,19 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)

     def decode(self, samples_in):
+        predicted_oom = False
+        samples = None
+        out = None
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
+            logging.debug(f"Free memory: {free_memory} bytes, predicted memory usage of one batch: {memory_used} bytes")
+            if free_memory < memory_used:
+                logging.warning("Warning: Out of memory is predicted for regular VAE decoding, switching directly to tiled VAE decoding.")
+                predicted_oom = True
+                raise model_management.OOM_EXCEPTION
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

@@ -363,7 +371,11 @@ def decode(self, samples_in):
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION as e:
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            samples = None
+            out = None
+            pixel_samples = None
+            if not predicted_oom:
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
             if dims == 1:
                 pixel_samples = self.decode_tiled_1d(samples_in)
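
The control flow the patch adds — predict the OOM from free memory, raise the same exception the runtime would, and let the existing handler route to the tiled path — can be seen in isolation below. This is a minimal stand-alone sketch, not ComfyUI code: OutOfMemory, free_bytes, decode_full, and decode_tiled are hypothetical stand-ins for model_management.OOM_EXCEPTION, model_management.get_free_memory, the regular batched decode, and the tiled decoder.

import logging

class OutOfMemory(Exception):
    """Hypothetical stand-in for model_management.OOM_EXCEPTION."""

def free_bytes():
    # Hypothetical stand-in for model_management.get_free_memory(device):
    # report the currently free device memory in bytes.
    return 4 * 1024**3

def decode_full(latents):
    # Hypothetical one-shot decode over the whole batch; on real hardware
    # this is the call that can raise an out-of-memory error mid-run.
    return [f"pixels({z})" for z in latents]

def decode_tiled(latents):
    # Hypothetical tiled fallback: slower, but with bounded memory use.
    return [f"pixels({z})" for z in latents]

def decode(latents, bytes_needed):
    predicted_oom = False
    try:
        free = free_bytes()
        logging.debug("free=%d bytes, predicted usage=%d bytes", free, bytes_needed)
        if free < bytes_needed:
            # Predicting the failure skips the doomed full-size attempt
            # entirely, instead of paying for an allocation known to fail.
            logging.warning("OOM predicted for regular decode, switching directly to tiled decode.")
            predicted_oom = True
            raise OutOfMemory
        return decode_full(latents)
    except OutOfMemory:
        if not predicted_oom:
            # A real OOM interrupted the full decode; only then is the
            # "retrying" message accurate.
            logging.warning("Ran out of memory during regular decode, retrying with tiled decode.")
        return decode_tiled(latents)

if __name__ == "__main__":
    # With 8 GiB predicted against 4 GiB free, the predicted-OOM path fires
    # and the tiled fallback runs without a real allocation failure.
    print(decode(["z0", "z1"], bytes_needed=8 * 1024**3))

Clearing samples, out, and pixel_samples at the top of the except block in the real patch serves the same end as the prediction: it drops references to any partially built tensors so their memory can be reclaimed before the tiled retry allocates its own buffers.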