Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pipeline_stable_diffusion_instruct_pix2pix.py to use pytorch_xla and from single safetensor. #10355

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -165,6 +166,7 @@
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -633,6 +635,12 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"

elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
):
model_type = "instruct-pix2pix"

else:
model_type = "v1"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,23 @@

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker


if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
Expand Down Expand Up @@ -457,6 +465,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,30 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
is_torch_xla_available,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker


if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s


class TextToVideoZeroPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
Expand Down Expand Up @@ -440,6 +458,10 @@ def backward_loop(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

return latents.clone().detach()

# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
Expand Down
Loading