Describe the bug
When I add the VAE decoder from SD 2.1 to my training pipeline, I run into an OOM error. After closer inspection, I found that the decoder consumes a very large amount of memory: with a latent of shape [1, 4, 96, 96], peak memory is already around 15 GB, and it grows further as the batch size increases.
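A quick way to confirm the peak allocation of a single decode call is the standard torch.cuda memory statistics; a minimal sketch of such a measurement (separate from the reproduction below):

```python
import torch
from diffusers import AutoencoderKL

# Sketch: report the peak GPU memory used by one fp16 decode of a [1, 4, 96, 96] latent.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
).to("cuda", dtype=torch.float16)

latent = torch.randn((1, 4, 96, 96), device="cuda", dtype=torch.float16)
torch.cuda.reset_peak_memory_stats()
image = vae.decode(latent / vae.config.scaling_factor, return_dict=False)[0]
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")
```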
Reproduction
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    'stabilityai/stable-diffusion-2-1', subfolder="vae"
)
# vae.encoder.requires_grad_(False)
vae.to('cuda', dtype=torch.float16)

# batch=4, channel=4, h=96, w=96 -- this is the shape of the latent
x0 = torch.randn((4, 4, 96, 96), device='cuda', dtype=torch.float16)

while True:
    # latents = vae.encode(x0).latent_dist.sample()
    # print(latents.shape)
    x1 = vae.decode(x0 / vae.config.scaling_factor, return_dict=False)[0]
```
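For what it's worth, the footprint can be reduced somewhat in cases where gradients through the decoder are not needed, by decoding under torch.no_grad() and enabling the slicing/tiling options that AutoencoderKL exposes. A minimal sketch of that workaround (it is only a mitigation, not a fix for the decoder's memory use itself, and it does not apply if the training setup needs to backpropagate through the decoder):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    'stabilityai/stable-diffusion-2-1', subfolder="vae"
).to('cuda', dtype=torch.float16)

# Decode the batch one sample at a time and in tiles, without autograd bookkeeping.
vae.enable_slicing()
vae.enable_tiling()

x0 = torch.randn((4, 4, 96, 96), device='cuda', dtype=torch.float16)
with torch.no_grad():
    x1 = vae.decode(x0 / vae.config.scaling_factor, return_dict=False)[0]
```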
Logs
No response
System Info
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
NVIDIA RTX 5880 Ada Generation, 49140 MiB
Who can help?
@yiyixuxu @DN6