allow models to run with a user-provided dtype map instead of a single dtype #10301

Open · wants to merge 6 commits into base: main
Changes from 2 commits
14 changes: 12 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -554,6 +554,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
            loaded_sub_model = passed_class_obj[name]

        else:
+           sub_model_dtype = (
+               torch_dtype.get(name, torch_dtype.get("_", torch.float32))
+               if isinstance(torch_dtype, dict)
+               else torch_dtype
+           )
Member:
I feel like `_` might be a bit unintuitive. Better to expose full dtype maps, or, in case partial ones are provided, default to `torch.float32` for the rest of the components.

Collaborator (Author):

Could it be `default`? Considering how it will work for integrations: without a fallback key you would have to spell out, say, `{'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16, 'text_encoder_3': torch.float16}` for SD3 and `{'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16}` for Flux. Not a big issue, though, because the components can be obtained from `cls._get_signature_types()`.

Member:
Yeah no strong opinions.

Collaborator (Author):
For now it's renamed to `default` to be clearer; we can remove it later if it's not needed.
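For reference, a minimal sketch of the resolution order being discussed. This revision of the diff still uses the `_` key; the `resolve_dtype` helper below is purely illustrative, not part of the PR:

```python
import torch

def resolve_dtype(torch_dtype, name):
    # Mirrors the expression in the diff: exact component name first,
    # then the "_" fallback key, then torch.float32.
    if isinstance(torch_dtype, dict):
        return torch_dtype.get(name, torch_dtype.get("_", torch.float32))
    return torch_dtype

dtype_map = {"transformer": torch.bfloat16, "_": torch.float16}
assert resolve_dtype(dtype_map, "transformer") is torch.bfloat16
assert resolve_dtype(dtype_map, "vae") is torch.float16                         # "_" fallback
assert resolve_dtype({"transformer": torch.bfloat16}, "vae") is torch.float32   # no default set
assert resolve_dtype(torch.float16, "vae") is torch.float16                     # plain dtype passes through
```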

            loaded_sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
@@ -562,7 +567,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
-               torch_dtype=torch_dtype,
+               torch_dtype=sub_model_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                proxies=kwargs.get("proxies", None),
@@ -578,7 +583,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
    # Obtain a sorted dictionary for mapping the model-level components
    # to their sizes.
    module_sizes = {
-       module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+       module_name: compute_module_sizes(
+           module,
+           dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32))
+           if isinstance(torch_dtype, dict)
+           else torch_dtype,
+       )[""]
        for module_name, module in init_empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
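A side note on the hunk above: `compute_module_sizes` comes from `accelerate`, so the per-component dtype now feeds directly into device-map planning. A rough sketch of the size difference involved, using an illustrative `Linear` stand-in rather than an actual pipeline component:

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import compute_module_sizes

# Illustrative stand-in for a pipeline component; weights stay on the
# meta device, so nothing is actually allocated.
with init_empty_weights():
    module = torch.nn.Linear(4096, 4096)

# The [""] key holds the total size of the module in bytes, as used above.
print(compute_module_sizes(module, dtype=torch.bfloat16)[""])  # ~33.6 MB
print(compute_module_sizes(module, dtype=torch.float32)[""])   # ~67.1 MB
```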
14 changes: 11 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -530,9 +530,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
saved using
[`~DiffusionPipeline.save_pretrained`].
-           torch_dtype (`str` or `torch.dtype`, *optional*):
+           torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-               dtype is automatically derived from the model's weights.
+               dtype is automatically derived from the model's weights. To load submodels with different dtypes, pass
+               a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype
+               for unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`).
+               If a component is not specified and no default is set, `torch.float32` is used.
custom_pipeline (`str`, *optional*):

<Tip warning={true}>
@@ -921,14 +924,19 @@ def load_module(name, value):
                loaded_sub_model = passed_class_obj[name]
            else:
                # load sub model
+               sub_model_dtype = (
+                   torch_dtype.get(name, torch_dtype.get("_", torch.float32))
+                   if isinstance(torch_dtype, dict)
+                   else torch_dtype
+               )
hlky marked this conversation as resolved.
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=importable_classes,
                    pipelines=pipelines,
                    is_pipeline_module=is_pipeline_module,
                    pipeline_class=pipeline_class,
-                   torch_dtype=torch_dtype,
+                   torch_dtype=sub_model_dtype,
                    provider=provider,
                    sess_options=sess_options,
                    device_map=current_device_map,
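Taken together, the change would allow a call like the following sketch. The checkpoint id is an assumption for illustration, and the `_` default key reflects this revision of the PR (a later commit renames it to `default`):

```python
import torch
from diffusers import DiffusionPipeline

# Transformer in bfloat16; every unspecified component (text encoders,
# VAE, ...) falls back to the "_" default of float16.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed model id, for illustration
    torch_dtype={"transformer": torch.bfloat16, "_": torch.float16},
)
print(pipe.transformer.dtype)   # torch.bfloat16
print(pipe.text_encoder.dtype)  # torch.float16 (from the "_" default)
```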