diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 6730f1551607..e1829bc409eb 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers: - [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers) - [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers) - [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers) -- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers) + +This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). > [!TIP] > Click on the SkyReels-V2 models in the right sidebar for more examples of video generation. diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index d6cd7d7feceb..1b1c8ee097c5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -545,22 +545,24 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 089f92632d38..4bc0d0aaea83 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -887,25 +887,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 2951a9447386..3e2004533258 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -966,25 +966,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 6fedfc795a40..234ec531b862 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -974,25 +974,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index d61b687eadc3..d1df7f5f34cb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -678,24 +678,26 @@ def __call__( latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1