From d09911f0620cbd30987931852746534d1b080f6d Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 7 Nov 2022 17:43:57 +0100 Subject: [PATCH 1/2] Fix cpu offloading --- src/diffusers/pipeline_utils.py | 2 +- .../stable_diffusion/test_stable_diffusion.py | 5 +++-- .../test_stable_diffusion_img2img.py | 15 ++++----------- .../test_stable_diffusion_inpaint.py | 17 +++++++++-------- 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 628e6320129b..9359c275f974 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -231,7 +231,7 @@ def device(self) -> torch.device: module = getattr(self, name) if isinstance(module, torch.nn.Module): if module.device == torch.device("meta"): - return torch.device("cpu") + return torch.device("cuda" if torch.cuda.is_available() else "cpu") return module.device return torch.device("cpu") diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index a83299eaf9f3..679b5c214bb2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -839,7 +839,6 @@ def test_stable_diffusion_low_cpu_mem_usage(self): assert 2 * low_cpu_mem_usage_time < normal_load_time - @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -848,10 +847,12 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): prompt = "Andromeda galaxy in a bottle" pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) pipeline.enable_attention_slicing(1) pipeline.enable_sequential_cpu_offload() - _ = pipeline(prompt, num_inference_steps=5) + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) mem_bytes = torch.cuda.max_memory_allocated() # make sure that less than 1.5 GB is allocated diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 2d29e1b80644..6d5c6feab5bc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -603,25 +603,18 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() init_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/img2img/sketch-mountains-input.jpg" ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/fantasy_landscape_k_lms.png" - ) init_image = init_image.resize((768, 512)) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 model_id = "CompVis/stable-diffusion-v1-4" lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - model_id, - scheduler=lms, - safety_checker=None, - device_map="auto", + model_id, scheduler=lms, safety_checker=None, 
device_map="auto", revision="fp16", torch_dtype=torch.float16 ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -642,5 +635,5 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): ) mem_bytes = torch.cuda.max_memory_allocated() - # make sure that less than 1.5 GB is allocated - assert mem_bytes < 1.5 * 10**9 + # make sure that less than 2.2 GB is allocated + assert mem_bytes < 2.2 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index e8dcb43163da..5fcdd71dd6e4 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -384,6 +384,7 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self): def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() init_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" @@ -393,16 +394,16 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" ) - expected_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png" - ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 model_id = "runwayml/stable-diffusion-inpainting" pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained( - model_id, safety_checker=None, scheduler=pndm, device_map="auto" + model_id, + safety_checker=None, + scheduler=pndm, + device_map="auto", + revision="fp16", + torch_dtype=torch.float16, ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -422,5 +423,5 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): ) mem_bytes = torch.cuda.max_memory_allocated() - # make sure that less than 1.5 GB is allocated - assert mem_bytes < 1.5 * 10**9 + # make sure that less than 2.2 GB is allocated + assert mem_bytes < 2.2 * 10**9 From 497cb5b7af676ef0d43df0d09a46f5231bd68685 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 8 Nov 2022 23:03:35 +0100 Subject: [PATCH 2/2] get offloaded devices locally for SD pipelines --- src/diffusers/pipeline_utils.py | 2 - .../pipeline_stable_diffusion.py | 40 +++++++++++----- .../pipeline_stable_diffusion_img2img.py | 37 +++++++++++---- .../pipeline_stable_diffusion_inpaint.py | 47 +++++++++++++------ .../stable_diffusion/test_stable_diffusion.py | 5 +- 5 files changed, 91 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 9359c275f974..240f46533a27 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -230,8 +230,6 @@ def device(self) -> torch.device: for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): - if module.device == torch.device("meta"): - return torch.device("cuda" if torch.cuda.is_available() else "cpu") return module.device return torch.device("cpu") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 
9c7edabf69f1..8afb29359e5d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -181,6 +181,24 @@ def enable_sequential_cpu_offload(self): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + @torch.no_grad() def __call__( self, @@ -272,6 +290,8 @@ def __call__( f" {type(callback_steps)}." ) + device = self._execution_device + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -288,7 +308,7 @@ def __call__( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -328,7 +348,7 @@ def __call__( truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -348,24 +368,22 @@ def __call__( latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: - if self.device.type == "mps": + if device.type == "mps": # randn does not work reproducibly on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( - self.device - ) + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device) else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(device) # set timesteps self.scheduler.set_timesteps(num_inference_steps) # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) + timesteps_tensor = self.scheduler.timesteps.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma @@ -413,9 +431,7 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).float().numpy() if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - 
self.device - ) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f6f38ab1d376..d69e9db77343 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -169,6 +169,25 @@ def enable_sequential_cpu_offload(self): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + def enable_xformers_memory_efficient_attention(self): r""" Enable memory efficient attention as implemented in xformers. @@ -278,6 +297,8 @@ def __call__( f" {type(callback_steps)}." ) + device = self._execution_device + # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -300,7 +321,7 @@ def __call__( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] # duplicate text embeddings for each generation per prompt text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) @@ -334,7 +355,7 @@ def __call__( truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] # duplicate unconditional embeddings for each generation per prompt seq_len = uncond_embeddings.shape[1] @@ -348,7 +369,7 @@ def __call__( # encode the init image into latents and scale the latents latents_dtype = text_embeddings.dtype - init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_image = init_image.to(device=device, dtype=latents_dtype) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents @@ -379,10 +400,10 @@ def __call__( init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=device) # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) + noise = torch.randn(init_latents.shape, 
generator=generator, device=device, dtype=latents_dtype) init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -405,7 +426,7 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - timesteps = self.scheduler.timesteps[t_start:].to(self.device) + timesteps = self.scheduler.timesteps[t_start:].to(device) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance @@ -434,9 +455,7 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a7af1c9d3351..6a532db5cd1a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -169,6 +169,25 @@ def enable_sequential_cpu_offload(self): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + def enable_xformers_memory_efficient_attention(self): r""" Enable memory efficient attention as implemented in xformers. @@ -289,6 +308,8 @@ def __call__( f" {type(callback_steps)}." 
) + device = self._execution_device + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -305,7 +326,7 @@ def __call__( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -345,7 +366,7 @@ def __call__( truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -365,17 +386,15 @@ def __call__( latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: - if self.device.type == "mps": + if device.type == "mps": # randn does not exist on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( - self.device - ) + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device) else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(device) # prepare mask and masked_image mask, masked_image = prepare_mask_and_masked_image(image, mask_image) @@ -384,9 +403,9 @@ def __call__( # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) - mask = mask.to(device=self.device, dtype=text_embeddings.dtype) + mask = mask.to(device=device, dtype=text_embeddings.dtype) - masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) + masked_image = masked_image.to(device=device, dtype=text_embeddings.dtype) # encode the mask image into latents space so we can concatenate it to the latents masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) @@ -402,7 +421,7 @@ def __call__( ) # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype) + masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype) num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] @@ -421,7 +440,7 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) + timesteps_tensor = self.scheduler.timesteps.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma @@ -473,9 +492,7 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).float().numpy() if self.safety_checker is not None: - 
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) ) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 679b5c214bb2..252b02806ae0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -842,6 +842,7 @@ def test_stable_diffusion_low_cpu_mem_usage(self): def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() pipeline_id = "CompVis/stable-diffusion-v1-4" prompt = "Andromeda galaxy in a bottle" @@ -855,5 +856,5 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): _ = pipeline(prompt, generator=generator, num_inference_steps=5) mem_bytes = torch.cuda.max_memory_allocated() - # make sure that less than 1.5 GB is allocated - assert mem_bytes < 1.5 * 10**9 + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9
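
Note: below is a minimal usage sketch of the offloading path exercised by the updated tests above, assuming a CUDA device and the diffusers API at the time of this patch (fp16 weights via revision="fp16", enable_sequential_cpu_offload backed by Accelerate hooks). The ~2.8 GB figure mirrors the relaxed test threshold rather than a guaranteed bound.

import torch
from diffusers import StableDiffusionPipeline

# Load fp16 weights for the same checkpoint used by the tests above.
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipeline = pipeline.to("cuda")

# Slice attention and offload submodules to CPU; each weight is moved to the GPU
# only while its module executes, so peak VRAM stays low. This is also why the
# pipeline reports the "meta" device afterwards and __call__ now resolves the
# device through the _execution_device property instead of self.device.
pipeline.enable_attention_slicing(1)
pipeline.enable_sequential_cpu_offload()

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

generator = torch.Generator(device="cuda").manual_seed(0)
_ = pipeline("Andromeda galaxy in a bottle", generator=generator, num_inference_steps=5)

# Peak allocation should stay below the ~2.8 GB asserted by the updated test.
print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 10**9:.2f} GB")

After offloading, self.device reports "meta", so each Stable Diffusion pipeline reads the real execution device from the Accelerate _hf_hook attached to the unet's modules (the new _execution_device property) rather than special-casing the meta device in pipeline_utils.py, and text inputs, latents, and safety-checker inputs are placed on that device.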