
Commit

Merge branch 'main' into original-lora-hunyuan-video
sayakpaul authored Dec 28, 2024
2 parents 893b9c0 + 55ac1db commit 35d62aa
Showing 11 changed files with 521 additions and 135 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -359,6 +359,8 @@ jobs:
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
runs-on:
group: aws-g6e-xlarge-plus
container:
62 changes: 62 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use

Refer to the official torchao documentation for a better understanding of the available quantization methods and an exhaustive list of configuration options.
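
For instance, both spellings below select the same method (a minimal sketch based on the alias note above):

```python
from diffusers import TorchAoConfig

# "int8wo" is the shorthand alias for "int8_weight_only"; both resolve to the
# same torchao weight-only int8 quantization method.
config_short = TorchAoConfig("int8wo")
config_full = TorchAoConfig("int8_weight_only")
```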

## Serializing and Deserializing quantized models

To serialize a quantized model, load it with the desired quantization config and save it using the [`~ModelMixin.save_pretrained`] method.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```

To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

Some quantization methods, such as `uint4wo`, save as expected but may raise an `UnpicklingError` when loaded directly. To work around this, load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.

```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...

# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
config = FluxTransformer2DModel.load_config("/path/to/flux_uint4wo")
transformer = FluxTransformer2DModel.from_config(config)
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
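
Once materialized, the restored transformer can be used like any other; a minimal sketch, assuming the same `FluxPipeline` setup as the deserialization example above:

```python
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", num_inference_steps=30).images[0]
image.save("output.png")
```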

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
56 changes: 56 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2287,6 +2287,50 @@ def unload_lora_weights(self):
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None

if getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()

for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
original_module = torch.nn.Linear(
in_features,
out_features,
bias=bias,
dtype=module_weight.dtype,
)

tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)

del tmp_state_dict

if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(current_param_weight.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
cls,
@@ -2313,6 +2357,8 @@ def _maybe_expand_transformer_param_shape_or_error_(

# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}

is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
@@ -2387,6 +2433,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

# For `unload_lora_weights()`.
# TODO: this could lead to more memory overhead if the number of overwritten params
# is large. Should be revisited later and tackled through a `discard_original_layers` arg.
overwritten_params[f"{current_module_name}.weight"] = module_weight
if module_bias is not None:
overwritten_params[f"{current_module_name}.bias"] = module_bias

if len(overwritten_params) > 0:
transformer._overwritten_params = overwritten_params

return has_param_with_shape_update

@classmethod
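The restore path above leans on two PyTorch features: building the replacement `nn.Linear` on the meta device (no memory is allocated up front) and `load_state_dict(..., assign=True)` (the saved tensors are swapped in rather than copied into empty ones). A minimal standalone sketch of the technique, with illustrative module names and shapes that are not taken from the diff:

```python
import torch

# Hypothetical parent module holding an expanded projection layer.
parent = torch.nn.Module()
parent.proj = torch.nn.Linear(8, 4, bias=True)  # expanded: in_features=8

# The original (pre-expansion) parameters we want to restore: in_features=4.
saved = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Build the replacement on the meta device: shape/dtype only, no storage.
with torch.device("meta"):
    restored = torch.nn.Linear(
        saved["weight"].shape[1], saved["weight"].shape[0], bias=True
    )

# assign=True swaps the saved tensors in directly; a plain copy would fail
# because meta tensors have no data to copy into.
restored.load_state_dict(saved, assign=True, strict=True)
setattr(parent, "proj", restored)

print(parent.proj.weight.shape)  # torch.Size([4, 4])
```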
8 changes: 4 additions & 4 deletions src/diffusers/models/modeling_utils.py
@@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
hf_quantizer = None

if hf_quantizer is not None:
-is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-if is_bnb_quantization_method and device_map is not None:
+if device_map is not None:
raise NotImplementedError(
-"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
)

hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder or "",
)
-if hf_quantizer is not None and is_bnb_quantization_method:
+# TODO: https://github.com/huggingface/diffusers/issues/10013
+if hf_quantizer is not None:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
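With this change the `device_map` guard applies to every quantization backend rather than only bitsandbytes. A minimal sketch of the now-rejected combination, reusing the model id and config from the torchao docs above:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8wo"),
        torch_dtype=torch.bfloat16,
        device_map="auto",  # any explicit device_map is rejected for quantized models
    )
except NotImplementedError as err:
    print(err)
```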
12 changes: 11 additions & 1 deletion src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -21,11 +21,18 @@
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@@ -564,6 +571,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if XLA_AVAILABLE:
xm.mark_step()

if output_type == "latent":
image = latents
else:
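The same guarded XLA import plus `xm.mark_step()` pattern is added to several pipelines in this commit. A condensed, runnable sketch of the idea (the real denoising work is elided):

```python
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Stand-in for a pipeline's denoising loop (scheduler/transformer work elided).
for step in range(4):
    if XLA_AVAILABLE:
        # On XLA backends, mark_step() flushes the lazily built graph so each
        # denoising iteration executes instead of accumulating indefinitely.
        xm.mark_step()
```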
11 changes: 11 additions & 0 deletions src/diffusers/pipelines/sana/pipeline_sana.py
@@ -31,6 +31,7 @@
USE_PEFT_BACKEND,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -46,6 +47,13 @@
from .pipeline_output import SanaPipelineOutput


if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

if is_bs4_available():
@@ -864,6 +872,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if XLA_AVAILABLE:
xm.mark_step()

if output_type == "latent":
image = latents
else:
@@ -226,12 +226,21 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
-self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+self.vae_scale_factor = (
+    2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+)
+latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
self.image_processor = VaeImageProcessor(
-    vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+    vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
)
-self.tokenizer_max_length = self.tokenizer.model_max_length
-self.default_sample_size = self.transformer.config.sample_size
+self.tokenizer_max_length = (
+    self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+)
+self.default_sample_size = (
+    self.transformer.config.sample_size
+    if hasattr(self, "transformer") and self.transformer is not None
+    else 128
+)
self.patch_size = (
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
@@ -225,19 +225,28 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
-self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+self.vae_scale_factor = (
+    2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+)
+latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
self.image_processor = VaeImageProcessor(
-    vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+    vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
)
self.mask_processor = VaeImageProcessor(
    vae_scale_factor=self.vae_scale_factor,
-    vae_latent_channels=self.vae.config.latent_channels,
+    vae_latent_channels=latent_channels,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
-self.tokenizer_max_length = self.tokenizer.model_max_length
-self.default_sample_size = self.transformer.config.sample_size
+self.tokenizer_max_length = (
+    self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+)
+self.default_sample_size = (
+    self.transformer.config.sample_size
+    if hasattr(self, "transformer") and self.transformer is not None
+    else 128
+)
self.patch_size = (
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
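
Taken together, these guards let the two pipelines above be constructed with some components set to `None`, falling back to stock defaults (VAE scale factor 8, 16 latent channels, a 77-token max length, sample size 128). A minimal sketch of the pattern with a hypothetical class, not the actual pipeline:

```python
class PipelineInitSketch:
    def __init__(self, vae=None, tokenizer=None, transformer=None):
        self.vae = vae
        self.tokenizer = tokenizer
        self.transformer = transformer
        # Each component may be None, so every derived attribute needs a fallback.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if self.vae is not None else 8
        )
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if self.tokenizer is not None else 77
        )
        self.default_sample_size = (
            self.transformer.config.sample_size if self.transformer is not None else 128
        )

# With no components supplied, the defaults kick in instead of AttributeError:
sketch = PipelineInitSketch()
print(sketch.vae_scale_factor, sketch.tokenizer_max_length)  # 8 77
```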
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type

if quant_type.startswith("int"):
if quant_type.startswith("int") or quant_type.startswith("uint"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
66 changes: 66 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
)

# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
components["transformer"] = transformer
pipe = FluxPipeline(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs(with_generator=False)
control_image = inputs.pop("control_image")
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

control_pipe = self.pipeline_class(**components)
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
rank = 4

dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

control_pipe.unload_lora_weights()
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass