Implementing exaone3.5 #1480

Draft
wants to merge 17 commits into base: main
83 changes: 83 additions & 0 deletions unsloth/models/_utils.py
@@ -15,6 +15,8 @@
__version__ = "2024.12.12"

__all__ = [
"load_correct_config",
"load_correct_model",
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
@@ -491,8 +493,89 @@ def UnslothModuleList(*args, **kwargs):
return
pass

def load_correct_config(config):

if config.model_type == 'exaone':
if Version(transformers_version) <= Version('4.47.1'):
raise RuntimeError("To use EXAONE you need to install transformers from source: "
                   "pip install git+https://github.com/huggingface/transformers.git")

from transformers.models.llama.modeling_llama import LlamaConfig

new_config_args = {
'vocab_size': config.vocab_size,
'hidden_size': config.hidden_size,
'intermediate_size': config.intermediate_size,
'num_hidden_layers': config.num_hidden_layers,
'num_attention_heads': config.num_attention_heads,
'num_key_value_heads': config.num_key_value_heads,
'hidden_act': config.activation_function,
'max_position_embeddings': config.max_position_embeddings,
'initializer_range': config.initializer_range,
'rms_norm_eps': config.layer_norm_epsilon,
'use_cache': config.use_cache,
'pad_token_id': config.pad_token_id,
'bos_token_id': config.bos_token_id,
'eos_token_id': config.eos_token_id,
'tie_word_embeddings': config.tie_word_embeddings,
'rope_theta': config.rope_theta,
'rope_scaling': config.rope_scaling,
'attention_bias': False,
'attention_dropout': config.attention_dropout,
'mlp_bias': False,
'head_dim': config.head_dim,
'architectures': ['LlamaForCausalLM'],
'model_type': 'llama',
'torch_dtype': config.torch_dtype
}
config = LlamaConfig.from_dict(new_config_args)
return config
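
A rough usage sketch of load_correct_config, illustration only and not part of the diff (the checkpoint name is one of the EXAONE repos this PR targets; trust_remote_code is an assumption for the original EXAONE modeling code):

from transformers import AutoConfig

# Hypothetical example: translate an EXAONE config into an equivalent LlamaConfig.
exaone_config = AutoConfig.from_pretrained(
    "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct", trust_remote_code = True
)
llama_config = load_correct_config(exaone_config)
assert llama_config.model_type == "llama"
assert llama_config.rms_norm_eps == exaone_config.layer_norm_epsilon
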
# =============================================

def load_correct_model(model, **model_kwargs):
if model.config.model_type == 'exaone':
new_config = load_correct_config(model.config)

# We need to provide quantization_config to the config as well
# https://github.com/huggingface/transformers/issues/35427
new_config.quantization_config = model_kwargs.pop("quantization_config", None)

from transformers.models.llama.modeling_llama import LlamaForCausalLM

# map the old state_dict keys to new ones
mapping = {
re.compile(r"^transformer\.wte\.weight$"): "model.embed_tokens.weight",
re.compile(r"^transformer\.ln_f\.weight$"): "model.norm.weight",
re.compile(r"^lm_head\.weight$"): "lm_head.weight",
re.compile(r"^transformer\.h\.(\d+)\.ln_1\.weight$") : "model.layers.{}.input_layernorm.weight",
re.compile(r"^transformer\.h\.(\d+)\.ln_2\.weight$") : "model.layers.{}.post_attention_layernorm.weight",
re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_0.weight$") : "model.layers.{}.mlp.gate_proj.weight",
re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_0.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.gate_proj.weight.{}",
re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_1.weight$") : "model.layers.{}.mlp.up_proj.weight",
re.compile(r"^transformer\.h\.(\d+).mlp.c_fc_1.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.up_proj.weight.{}",
re.compile(r"^transformer\.h\.(\d+).mlp.c_proj.weight$") : "model.layers.{}.mlp.down_proj.weight",
re.compile(r"^transformer\.h\.(\d+).mlp.c_proj.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)$") : "model.layers.{}.mlp.down_proj.weight.{}",
re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.(k_proj|v_proj|q_proj)\.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)") : "model.layers.{}.self_attn.{}.weight.{}",
re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.(k_proj|v_proj|q_proj)\.weight$") : "model.layers.{}.self_attn.{}.weight",
re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.out_proj\.weight$") : "model.layers.{}.self_attn.o_proj.weight",
re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.out_proj\.weight\.(absmax|quant_map|nested_absmax|nested_quant_map|quant_state\.\w+)") : "model.layers.{}.self_attn.o_proj.weight.{}"
}

old_state_dict = model.state_dict()
new_state_dict = {}

for key in old_state_dict:
for pattern in mapping:
match = pattern.match(key)
if match:
new_key = mapping[pattern].format(*match.groups())
new_state_dict[new_key] = old_state_dict[key]

assert len(old_state_dict) == len(new_state_dict), f"Remapping the {model.__class__.__name__} state_dict into LlamaForCausalLM should preserve the number of keys"
model = LlamaForCausalLM.from_pretrained(None, config=new_config, state_dict=new_state_dict, **model_kwargs)
return model
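
To make the regex mapping above concrete, here is a minimal standalone sketch of the key rewriting it performs (the two input keys are hypothetical examples in the EXAONE naming scheme):

import re

# Hypothetical subset of the mapping above, applied to two example EXAONE keys.
mapping = {
    re.compile(r"^transformer\.h\.(\d+)\.ln_1\.weight$"): "model.layers.{}.input_layernorm.weight",
    re.compile(r"^transformer\.h\.(\d+)\.attn\.attention\.(k_proj|v_proj|q_proj)\.weight$"): "model.layers.{}.self_attn.{}.weight",
}
for key in ["transformer.h.0.ln_1.weight", "transformer.h.3.attn.attention.q_proj.weight"]:
    for pattern, template in mapping.items():
        match = pattern.match(key)
        if match:
            # The captured groups (layer index, projection name) fill the Llama-style template
            print(key, "->", template.format(*match.groups()))
# transformer.h.0.ln_1.weight -> model.layers.0.input_layernorm.weight
# transformer.h.3.attn.attention.q_proj.weight -> model.layers.3.self_attn.q_proj.weight
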

# =============================================
def prepare_model_for_kbit_training(
model : Any,
use_gradient_checkpointing : Optional = True,
30 changes: 20 additions & 10 deletions unsloth/models/llama.py
@@ -1610,8 +1610,11 @@ def from_pretrained(

assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)

model_config = AutoConfig.from_pretrained(model_name, token=token)
# Convert configs of architecturally-equivalent models (e.g. EXAONE) to their Llama counterparts
model_config = load_correct_config(model_config)

# RoPE Scaling
model_config = AutoConfig.from_pretrained(model_name, token = token)
model_max_seq_length = model_config.max_position_embeddings

# Check if RoPE Scaling is even allowed
@@ -1672,18 +1675,25 @@

# Cannot be None, since HF now checks for the config
if load_in_4bit: kwargs["quantization_config"] = bnb_config


model_kwargs = {
"device_map": device_map,
"torch_dtype": dtype,
"token": token,
"trust_remote_code": trust_remote_code,
"attn_implementation": "eager",
**kwargs,
}
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
# quantization_config = bnb_config,
token = token,
max_position_embeddings = max_position_embeddings,
trust_remote_code = trust_remote_code,
attn_implementation = "eager",
**kwargs,
**model_kwargs
)
# Convert architecturally-equivalent models (e.g. EXAONE) to their Llama counterparts
model = load_correct_model(
model,
**model_kwargs
)

# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
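
For context, this is roughly how the patched from_pretrained path would be exercised end to end; a sketch under assumptions (the checkpoint name comes from the commented-out mapper entries below, and the keyword arguments follow unsloth's usual FastLanguageModel API):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    max_seq_length = 4096,
    dtype = None,         # auto-detect
    load_in_4bit = True,  # exercises the quantization_config hand-off above
)
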
2 changes: 1 addition & 1 deletion unsloth/models/loader.py
@@ -178,7 +178,7 @@ def from_pretrained(

model_type = model_config.model_type

if model_type == "llama":
if model_type == "llama" or model_type == "exaone":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
scaling_type1 = model_config.rope_scaling.get("type", None)
12 changes: 12 additions & 0 deletions unsloth/models/mapper.py
@@ -520,6 +520,18 @@
"unsloth/Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct",
),
# "unsloth/EXAONE-3.5-2.4B-Instruct-bnb-4bit" : (
# "unsloth/EXAONE-3.5-2.4B-Instruct",
# "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct"
# ),
# "unsloth/EXAONE-3.5-7.8B-Instruct-bnb-4bit" : (
# "unsloth/EXAONE-3.5-7.8B-Instruct",
# "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct"
# ),
# "unsloth/EXAONE-3.5-32B-Instruct-bnb-4bit" : (
# "unsloth/EXAONE-3.5-32B-Instruct",
# "LGAI-EXAONE/EXAONE-3.5-32B-Instruct"
# )
}
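
Once pre-quantized EXAONE uploads exist and the entries above are uncommented, the lookup tables built from this dictionary (INT_TO_FLOAT_MAPPER just below, and its reverse) would resolve them roughly as follows; a sketch only, since the exact key normalization in mapper.py is assumed rather than shown in this diff:

# Hypothetical lookups once the EXAONE entries above are uncommented
# (key casing here is illustrative; mapper.py's own normalization applies):
INT_TO_FLOAT_MAPPER["unsloth/EXAONE-3.5-2.4B-Instruct-bnb-4bit"]
# -> "unsloth/EXAONE-3.5-2.4B-Instruct"
FLOAT_TO_INT_MAPPER["LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct"]
# -> "unsloth/EXAONE-3.5-2.4B-Instruct-bnb-4bit"
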

INT_TO_FLOAT_MAPPER = {}