Can't load model from state_dict + config when quantized #35427

Open
1 of 4 tasks
KareemMusleh opened this issue Dec 27, 2024 · 5 comments

@KareemMusleh

System Info

I used a Colab notebook.

  • transformers version: 4.47.1
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.27.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Tensorflow version (GPU?): 2.17.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Using distributed or parallel set-up in script?: NO
  • Using GPU in script?: YES
  • GPU type: Tesla T4

Who can help?

@SunMarc
@MekkCyber

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Open Google Colab
  2. Then:
!pip install bitsandbytes
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

quantization_config = BitsAndBytesConfig(
  load_in_4bit              = True,
  bnb_4bit_use_double_quant = True,
  bnb_4bit_quant_type       = "nf4",
  bnb_4bit_compute_dtype    = torch.float32,
)
model_name = 'LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct'
exaone = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, quantization_config=quantization_config)
new_exaone = AutoModelForCausalLM.from_pretrained(None, config=exaone.config, state_dict=exaone.state_dict(), quantization_config=quantization_config)

I receive the following error:

low_cpu_mem_usage was None, now default to True since model is quantized.


AttributeError Traceback (most recent call last)

in <cell line: 1>()
----> 1 new_exaone = AutoModelForCausalLM.from_pretrained(None, config=exaone.config, state_dict=exaone.state_dict())

3 frames

/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in load_state_dict(checkpoint_file, is_quantized, map_location, weights_only)
500 Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
501 """
--> 502 if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
503 # Check format of the archive
504 with safe_open(checkpoint_file, framework="pt") as f:

AttributeError: 'NoneType' object has no attribute 'endswith'

Expected behavior

For the model to load with the quantized weights.

Also, I am trying to implement EXAONE in unsloth, and I would like your thoughts on the best way to load a model that has its own class defined in the Hugging Face repo using another class. More specifically, the EXAONE model has the same architecture as Llama, but some config options and layers are renamed, which results in a different state_dict. My current approach, which seems to work when the model isn't quantized, is to load the model as ExaoneForCausalLM using AutoModelForCausalLM and then create a LlamaConfig from its config. We also rename the keys of the state_dict (without copying the tensors) and pass both to the from_pretrained function. What are your thoughts on how best to solve this problem? Your help is greatly appreciated!
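The config half of this approach amounts to a field-by-field translation. A minimal sketch, assuming the EXAONE config is available as a plain dict (for example via `model.config.to_dict()`); the helper `translate_config` and the field lists are hypothetical and only cover the renames that appear later in this thread (`activation_function` to `hidden_act`, `layer_norm_epsilon` to `rms_norm_eps`).

```python
# Hypothetical mapping of EXAONE config field names to their Llama equivalents;
# only the renames used later in this thread are listed.
EXAONE_TO_LLAMA_FIELDS = {
    "activation_function": "hidden_act",
    "layer_norm_epsilon": "rms_norm_eps",
}

# Fields assumed to share the same name in both configs (a subset, for illustration).
SHARED_FIELDS = (
    "vocab_size",
    "hidden_size",
    "intermediate_size",
    "num_hidden_layers",
    "num_attention_heads",
    "num_key_value_heads",
    "max_position_embeddings",
    "rope_theta",
    "rope_scaling",
    "head_dim",
)


def translate_config(exaone_fields: dict) -> dict:
    """Build LlamaConfig keyword arguments from a dict of EXAONE config fields."""
    # Copy fields whose name is the same in both configs.
    llama_kwargs = {name: exaone_fields[name] for name in SHARED_FIELDS if name in exaone_fields}
    # Rename the fields whose name differs.
    for old_name, new_name in EXAONE_TO_LLAMA_FIELDS.items():
        if old_name in exaone_fields:
            llama_kwargs[new_name] = exaone_fields[old_name]
    # Llama options with no EXAONE counterpart, fixed to False as in the thread.
    llama_kwargs["attention_bias"] = False
    llama_kwargs["mlp_bias"] = False
    return llama_kwargs
```

The resulting dict could then be passed to `LlamaConfig.from_dict`, which is what the later snippet in this thread does with its hand-written dict. Fields not listed are simply dropped, so anything else the Llama config needs has to be added explicitly.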

@MekkCyber
Contributor

Hey @KareemMusleh, I managed to reproduce the issue. To solve it, you can install transformers from source:
pip install git+https://github.com/huggingface/transformers.git
As for your question, I'm not familiar with how things are implemented in unsloth. Are you asking whether we can load an EXAONE model as if it were a Llama model?

@KareemMusleh
Author

Thanks for the quick response!
Yes, EXAONE has the same architecture as Llama, but it is implemented in modeling_exaone with different names for the layers and for the config options. I am asking how someone would do this in transformers, since unsloth uses transformers.

@KareemMusleh
Copy link
Author

Also, can you please point me to the commit that solves the issue? I couldn't find it. I'm asking in case someone wants this to be backwards compatible.

@KareemMusleh
Copy link
Author

Just checked it, the solution works!

@KareemMusleh
Copy link
Author

KareemMusleh commented Dec 27, 2024

@MekkCyber, I checked it again today and it doesn't work. Again, I am using Colab:

!pip install -U bitsandbytes
!pip install git+https://github.com/huggingface/transformers.git
from transformers import __version__ as transformers_version
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaConfig
import torch

def load_correct_model(model, **model_kwargs):
    if model.config.model_type == 'exaone':
        import re
        new_model = AutoModelForCausalLM
        # pretrained_model_name_or_path = model.config._name_or_path

        # get the correct config
        new_config_args =  {
            'vocab_size': model.config.vocab_size,
            'hidden_size': model.config.hidden_size,
            'intermediate_size': model.config.intermediate_size,
            'num_hidden_layers': model.config.num_hidden_layers,
            'num_attention_heads': model.config.num_attention_heads,
            'num_key_value_heads': model.config.num_key_value_heads,
            'hidden_act': model.config.activation_function,
            'max_position_embeddings': model.config.max_position_embeddings,
            'initializer_range': model.config.initializer_range,
            'rms_norm_eps': model.config.layer_norm_epsilon,
            'use_cache': model.config.use_cache,
            'pad_token_id': model.config.pad_token_id,
            'bos_token_id': model.config.bos_token_id,
            'eos_token_id': model.config.eos_token_id,
            'tie_word_embeddings': model.config.tie_word_embeddings,
            'rope_theta': model.config.rope_theta,
            'rope_scaling': model.config.rope_scaling,
            'attention_bias': False,
            'attention_dropout': model.config.attention_dropout,
            'mlp_bias': False,
            'head_dim': model.config.head_dim,
            'architectures': ['LlamaForCausalLM'],
            'model_type': 'llama',
            'torch_dtype': model.config.torch_dtype
        }
        new_config = LlamaConfig.from_dict(new_config_args)

        mapping = {
            re.compile(r"^transformer\.wte\.weight$"): "model.embed_tokens.weight",
            re.compile(r"^transformer\.ln_f\.weight$"): "model.norm.weight",
            re.compile(r"^lm_head\.weight$"): "lm_head.weight",
            re.compile(r"^transformer\.h\.(\d+)\.ln_1\.weight$") : "model.layers.{}.input_layernorm.weight",
            re.compile(r"^transformer\.h\.(\d+)\.ln_2\.weight$") : "model.layers.{}.post_attention_layernorm.weight",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_0.weight$") : "model.layers.{}.mlp.gate_proj.weight",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_0.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.gate_proj.weight.{}",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_1.weight$") : "model.layers.{}.mlp.up_proj.weight",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_1.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.up_proj.weight.{}",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_proj.weight$") : "model.layers.{}.mlp.down_proj.weight",
            re.compile(r"^transformer\.h\.(\d+).mlp.c_proj.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.down_proj.weight.{}",
            re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.(k_proj|v_proj|q_proj)\.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)") : "model.layers.{}.self_attn.{}.weight.{}",
            re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.(k_proj|v_proj|q_proj)\.weight") : "model.layers.{}.self_attn.{}.weight",
            re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.out_proj\.weight") : "model.layers.{}.self_attn.o_proj.weight",
            re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.out_proj\.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)") : "model.layers.{}.self_attn.o_proj.weight.{}"
        }

        old_state_dict = model.state_dict()
        new_state_dict = {}

        for key in old_state_dict:
            for pattern in mapping:
                match = pattern.match(key)
                if match:
                    new_key = mapping[pattern].format(*match.groups())
                    new_state_dict[new_key] = old_state_dict[key]
        # Every original key should map to exactly one new key
        if len(old_state_dict) != len(new_state_dict):
            raise RuntimeError(f"The mapping of {model.__class__.__name__} into {new_model.__name__} should preserve the number of keys")
        model = new_model.from_pretrained(None, config=new_config, state_dict=new_state_dict, **model_kwargs)
    return model


quantization_config = BitsAndBytesConfig(
  load_in_4bit              = True,
  bnb_4bit_use_double_quant = True,
  bnb_4bit_quant_type       = "nf4",
  bnb_4bit_compute_dtype    = torch.float32,
)
model_kwargs = {
    "quantization_config": quantization_config,
    "device_map": "sequential",
    "trust_remote_code": True,
}
model_name = 'LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct'
exaone = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
model = load_correct_model(exaone, **model_kwargs)
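One thing worth checking in the renaming loop above: several attention patterns have no trailing `$`, and the inner loop does not stop after the first match. The unanchored `...k_proj.weight` pattern therefore also matches `...k_proj.weight.absmax` and `...k_proj.weight.quant_state.*`, so a quantization-state tensor can overwrite `model.layers.N.self_attn.k_proj.weight` while the equal-length check still passes. That could explain a small tensor (like the `torch.Size([170])` below) landing in a `weight` slot, though this is an inference, not a confirmed diagnosis. A minimal sketch of a stricter renamer, assuming the same suffix list as the thread; `RULES` and `rename_keys` are hypothetical names, and only two rules are shown.

```python
import re

# bitsandbytes 4-bit weights carry extra tensors in state_dict(); these suffixes
# are the ones listed in the patterns above (treat the list as an assumption).
QUANT_SUFFIX = r"(?:\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+))?"

# Hypothetical rule table: each pattern is used with fullmatch (anchored at both
# ends) and ends with an optional quantization suffix, so one rule handles both
# "...weight" and "...weight.absmax". The remaining renames would follow this shape.
RULES = [
    (re.compile(r"transformer\.h\.(\d+)\.attn\.attention\.(q_proj|k_proj|v_proj)\.weight" + QUANT_SUFFIX),
     "model.layers.{0}.self_attn.{1}.weight"),
    (re.compile(r"transformer\.h\.(\d+)\.attn\.attention\.out_proj\.weight" + QUANT_SUFFIX),
     "model.layers.{0}.self_attn.o_proj.weight"),
]


def rename_keys(state_dict, rules):
    """Rename every key with the first rule that fully matches it.

    Values are reused as-is (no tensor copies). Raises KeyError if a key matches
    no rule or if two source keys would land on the same target key, which is
    the silent overwrite that an equal-length check cannot detect.
    """
    renamed = {}
    for key, value in state_dict.items():
        for pattern, template in rules:
            match = pattern.fullmatch(key)
            if match is None:
                continue
            # Last group is the optional quantization suffix (None for plain weights).
            *parts, suffix = match.groups()
            new_key = template.format(*parts)
            if suffix is not None:
                new_key += "." + suffix
            if new_key in renamed:
                raise KeyError(f"{key!r} maps onto {new_key!r}, which is already taken")
            renamed[new_key] = value
            break  # first matching rule wins
        else:
            raise KeyError(f"no rule matches {key!r}")
    return renamed
```

With the full mapping rewritten in this shape, `rename_keys(model.state_dict(), RULES)` either produces a one-to-one rename or fails loudly on the first ambiguous or unmapped key.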

Providing quantization_config raises ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8. Not providing it raises ValueError: Trying to set a tensor of shape torch.Size([170]) in "weight" (which has shape torch.Size([640, 2560])), this looks incorrect.

I've checked the state_dict and compared it against a Llama model, and it seems correct to me. I think the problem might be a parameter that the config is expected to have.
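One way to extend a key-level comparison is to check the bitsandbytes companion tensors: every `...weight.absmax`, `...weight.quant_map`, `...weight.quant_state.*` and so on should sit next to the base `...weight` it belongs to, under the same renamed prefix. A minimal sketch that works on key strings only (e.g. `list(new_state_dict)`); the helper names are hypothetical and the suffix list is taken from the patterns in this thread, not from a bitsandbytes spec.

```python
from collections import defaultdict

# Companion-tensor suffixes assumed from the key patterns used in this thread.
QUANT_PARTS = ("absmax", "quant_map", "nested_absmax", "nested_quant_map")


def split_quant_key(key):
    """Split 'x.weight.absmax' into ('x.weight', 'absmax'); plain keys get None."""
    if ".weight.quant_state." in key:
        base, _, state = key.partition(".weight.quant_state.")
        return base + ".weight", "quant_state." + state
    for part in QUANT_PARTS:
        if key.endswith(".weight." + part):
            # Drop ".<part>" to recover the base weight key.
            return key[: -len(part) - 1], part
    return key, None


def group_by_weight(keys):
    """Map each base weight key to the companion suffixes found for it."""
    groups = defaultdict(list)
    for key in keys:
        base, part = split_quant_key(key)
        if part is not None:
            groups[base].append(part)
    return dict(groups)


def missing_weights(keys):
    """Return base weight keys that have companion tensors but no weight entry."""
    present = set(keys)
    return sorted(base for base in group_by_weight(keys) if base not in present)
```

A key-only check like this cannot tell whether a value was overwritten under the right name, so comparing each tensor's `shape` and `dtype` against the source state_dict would still be needed.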

Edit 1: Changing new_model to LlamaForCausalLM throws the same error.
Edit 2: Adding quantization_config to the config seems to solve the problem, but there are still issues with inference:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Choose your prompt
prompt = "Explain how wonderful you are"  # English example

messages = [
    {"role": "system", 
     "content": "You are EXAONE model from LG AI Research, a helpful assistant."},
    {"role": "user", "content": prompt}
]
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
)

output = model.generate(
    input_ids.to("cuda"),
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=128,
    do_sample=False,
)

I get the following error: RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)
