Skip to content

Commit

Permalink
fix config format for transformers (#120)
Browse files (browse the repository at this point in the history)
  • Loading branch information
wejoncy authored Nov 18, 2024
1 parent 3ff720b commit f49341e
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ requires = [
build-backend = "setuptools.build_meta"

[tool.ruff]
- # Allow lines to be as long as 80.
+ # Allow lines to be as long as 120.
line-length = 120
exclude = [
# External file, leaving license intact
Expand Down
4 changes: 3 additions & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
# --------------------------------------------------------------------------

__version__ = "0.0.3"
- from vptq.layers import AutoModelForCausalLM as AutoModelForCausalLM
+ from vptq.layers import AutoModelForCausalLM, VQuantLinear
+
+ __all__ = ["AutoModelForCausalLM", "VQuantLinear"]
4 changes: 3 additions & 1 deletion vptq/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

- from vptq.layers.model_base import AutoModelForCausalLM as AutoModelForCausalLM
+ from vptq.layers.model_base import AutoModelForCausalLM, VQuantLinear
+
+ __all__ = ["AutoModelForCausalLM", "VQuantLinear"]
4 changes: 2 additions & 2 deletions vptq/layers/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = cls.from_config(auto_conf, *model_args, **cls_kwargs)

target_layer = VQuantLinear
- quant_config = auto_conf.quant_config
+ quantization_config = auto_conf.quantization_config

# replace linear layers with quantized linear layers
with transformers.utils.generic.ContextManagers([accelerate.init_empty_weights()]):
- make_quant_linear(model, quant_config, target_layer=target_layer)
+ make_quant_linear(model, quantization_config, target_layer=target_layer)

no_split_module_classes = [i[1].__class__.__name__ for i in model.named_modules() if i[0].endswith(".0")]

Expand Down

0 comments on commit f49341e

Please sign in to comment.