format
YangWang92 committed Sep 20, 2024
1 parent f7ef768 commit 57f7073
Showing 7 changed files with 86 additions and 119 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"yapf.args":["--style={based_on_s'tyle: google, column_limit: 120, indent_width: 4}"]
}
1 change: 1 addition & 0 deletions format.sh
@@ -0,0 +1 @@
yapf --recursive . --style='{based_on_style: google, column_limit: 120, indent_width: 4}' -i
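
For reference, the same style options can also be applied programmatically through yapf's Python API. A minimal sketch, not part of this commit, assuming a recent yapf release:

    from yapf.yapflib.yapf_api import FormatCode

    source = "def f( a,b ):\n    return a+b\n"
    formatted, changed = FormatCode(
        source,
        style_config="{based_on_style: google, column_limit: 120, indent_width: 4}",
    )
    print(changed)    # True if yapf rewrote the snippet
    print(formatted)  # roughly: def f(a, b):\n    return a + b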
3 changes: 1 addition & 2 deletions setup.py
@@ -40,8 +40,7 @@ def build_cuda_extensions():
"--expt-extended-lambda",
"--use_fast_math",
"-lineinfo",
]
+ arch_flags,
] + arch_flags,
"cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
}
extensions = CUDAExtension(
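
For context, a hedged sketch of how a flag dictionary like the one above typically feeds into a CUDAExtension; the extension name, source files, and arch_flags values below are placeholders, not the repository's actual build configuration:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    arch_flags = ["-gencode", "arch=compute_80,code=sm_80"]  # placeholder; the real list is computed elsewhere
    extra_compile_args = {
        "nvcc": ["--expt-extended-lambda", "--use_fast_math", "-lineinfo"] + arch_flags,
        "cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
    }

    ext = CUDAExtension(
        name="example_ops",                       # placeholder extension name
        sources=["csrc/ops.cpp", "csrc/ops.cu"],  # placeholder sources
        extra_compile_args=extra_compile_args,
    )
    setup(name="example", ext_modules=[ext], cmdclass={"build_ext": BuildExtension})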
36 changes: 16 additions & 20 deletions tools/convert_finetune_weights_to_hf_packed_weight.py
@@ -31,9 +31,8 @@ def pack_index(

# upcast the indice to uint64 to avoid overflow on signed bit
if res_indice is not None:
merged_indice = (res_indice.view(index_dtype).to(torch.uint64).view(torch.int64) << index_bits) | indice.view(
index_dtype
).to(torch.uint64).view(torch.int64)
merged_indice = (res_indice.view(index_dtype).to(torch.uint64).view(torch.int64) <<
index_bits) | indice.view(index_dtype).to(torch.uint64).view(torch.int64)
else:
merged_indice = indice.view(index_dtype).to(torch.uint64).view(torch.int64)

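A small worked example of the packing step above: the residual index is shifted left by index_bits and OR-ed with the main index. This simplified sketch skips the view/uint64 reinterpretation done in the real code and just uses int64 arithmetic on toy values:

    import torch

    index_bits = 12                                        # assumed width of the main index
    indice = torch.tensor([5, 1023], dtype=torch.int64)    # main codebook indices (toy values)
    res_indice = torch.tensor([3, 7], dtype=torch.int64)   # residual codebook indices (toy values)

    merged_indice = (res_indice << index_bits) | indice
    print(merged_indice)  # tensor([12293, 29695]), i.e. 3*4096+5 and 7*4096+1023
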
@@ -136,7 +135,8 @@ def dtype_convert(data, from_dtype, to_dtype, as_type):


def convert_idx_dtype(model, from_dtype, to_dtype, as_type):
print(f"converting model indices from {from_dtype} " f"to {to_dtype} as {as_type}")
print(f"converting model indices from {from_dtype} "
f"to {to_dtype} as {as_type}")

quant_config = {}
for mod_name, sub_mod in model.named_modules():
@@ -148,29 +148,25 @@ def convert_idx_dtype(model, from_dtype, to_dtype, as_type):
# f'dtype: {sub_mod.indices.dtype}')

if sub_mod.indices.dtype == torch.int64:
sub_mod.indices.data = dtype_convert(
sub_mod.indices.data, sub_mod.indices.data.dtype, to_dtype, as_type
)
sub_mod.indices.data = dtype_convert(sub_mod.indices.data, sub_mod.indices.data.dtype, to_dtype,
as_type)
else:
sub_mod.indices.data = dtype_convert(sub_mod.indices.data, from_dtype, to_dtype, as_type)

if hasattr(sub_mod, "res_indices") and sub_mod.res_indices is not None:
if sub_mod.res_indices.dtype == torch.int64:
sub_mod.res_indices.data = dtype_convert(
sub_mod.res_indices.data, sub_mod.res_indices.data.dtype, to_dtype, as_type
)
sub_mod.res_indices.data = dtype_convert(sub_mod.res_indices.data, sub_mod.res_indices.data.dtype,
to_dtype, as_type)
else:
sub_mod.res_indices.data = dtype_convert(sub_mod.res_indices.data, from_dtype, to_dtype, as_type)

if hasattr(sub_mod, "outlier_indices") and sub_mod.outlier_indices is not None:
if sub_mod.outlier_indices.dtype == torch.int64:
sub_mod.outlier_indices.data = dtype_convert(
sub_mod.outlier_indices.data, sub_mod.outlier_indices.data.dtype, to_dtype, as_type
)
sub_mod.outlier_indices.data = dtype_convert(sub_mod.outlier_indices.data,
sub_mod.outlier_indices.data.dtype, to_dtype, as_type)
else:
sub_mod.outlier_indices.data = dtype_convert(
sub_mod.outlier_indices.data, from_dtype, to_dtype, as_type
)
sub_mod.outlier_indices.data = dtype_convert(sub_mod.outlier_indices.data, from_dtype, to_dtype,
as_type)

if sub_mod.perm.dtype == torch.int64:
sub_mod.perm.data = dtype_convert(sub_mod.perm.data, sub_mod.perm.data.dtype, to_dtype, as_type)
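
dtype_convert itself is not shown in this diff; the following is a hedged guess at its shape, purely for illustration (the real helper may differ): cast the tensor to the target dtype, then bit-reinterpret it as a storage type of the same element size:

    import torch

    def dtype_convert_sketch(data, from_dtype, to_dtype, as_type):
        # Illustrative only: cast to `to_dtype`, then reinterpret the bits as `as_type`
        # (element sizes of `to_dtype` and `as_type` must match for .view to be legal).
        assert data.dtype == from_dtype
        return data.to(to_dtype).view(as_type)

    idx = torch.tensor([1, 2, 40000], dtype=torch.int64)
    print(dtype_convert_sketch(idx, torch.int64, torch.int16, torch.int16))
    # tensor([     1,      2, -25536], dtype=torch.int16) -- 40000 wraps in signed 16-bit storage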
@@ -220,7 +216,8 @@ def eval_ppl(qmodel, config):
assert True, "opt is not supported"
print(f"ppl_{dataset}: {ppl}")

print(f'end time: {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())},' f'duration: {time.time()-tick} seconds')
print(f'end time: {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())},'
f'duration: {time.time()-tick} seconds')


if __name__ == "__main__":
@@ -286,9 +283,8 @@ def load_model_shards(state_path):

print(f"model seqlen {qmodel.seqlen}")

if qmodel.seqlen != 4096 and (
"llama2" in config.model_args.model_name.lower() or "llama-2" in config.model_args.model_name.lower()
):
if qmodel.seqlen != 4096 and ("llama2" in config.model_args.model_name.lower() or
"llama-2" in config.model_args.model_name.lower()):
print("WARNING! LLama-2 model should set seqlen=4096")
qmodel.eval()

9 changes: 5 additions & 4 deletions vptq/__main__.py
@@ -18,9 +18,10 @@ def define_basic_args():
""",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--model", type=str, required=True, help="float/float16 model to load, such as [mosaicml/mpt-7b]"
)
parser.add_argument("--model",
type=str,
required=True,
help="float/float16 model to load, such as [mosaicml/mpt-7b]")
parser.add_argument("--tokenizer", type=str, default="", help="default same as [model]")
parser.add_argument("--prompt", type=str, default="once upon a time, there ", help="prompt to start generation")
parser.add_argument("--chat", action="store_true", help="chat with the model")
@@ -39,7 +40,7 @@ def chat_loop(model, tokenizer):
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to("cuda")
generated_ids = model.generate(model_inputs, pad_token_id=2, max_new_tokens=500, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1] :], skip_special_tokens=True)
decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True)
messages.append({"role": "assistant", "content": decoded[0]})
print("assistant:", decoded[0])

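The slice generated_ids[:, model_inputs.shape[-1]:] keeps only the newly generated tokens, dropping the echoed prompt. A standalone sketch with dummy tensors (in the real loop they come from the tokenizer and model.generate):

    import torch

    model_inputs = torch.tensor([[1, 2, 3, 4]])            # stands in for the tokenized prompt
    generated_ids = torch.tensor([[1, 2, 3, 4, 10, 11]])   # prompt followed by two new tokens

    new_token_ids = generated_ids[:, model_inputs.shape[-1]:]
    print(new_token_ids)  # tensor([[10, 11]])
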
13 changes: 7 additions & 6 deletions vptq/ist/model_base.py
@@ -31,9 +31,9 @@ def set_op_by_name(layer, name, new_module):


def make_quant_linear(module, quant_conf, name="", target_layer=None):
for module_name, sub_module in tqdm(
module.named_modules(), total=len(list(module.named_modules())), desc="Replacing linear layers..."
):
for module_name, sub_module in tqdm(module.named_modules(),
total=len(list(module.named_modules())),
desc="Replacing linear layers..."):
if module_name in quant_conf:
layer_conf = quant_conf[module_name]
new_module = target_layer(**layer_conf, enable_proxy_error=False, dtype=sub_module.weight.dtype)
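
For orientation, a hedged sketch of the replace-by-name pattern that make_quant_linear relies on (set_op_by_name is defined earlier in this file; the helper below re-implements the idea generically and is not the repository's code):

    import torch.nn as nn

    def replace_listed_modules(model: nn.Module, quant_conf: dict, make_layer) -> None:
        # Walk every submodule; when its dotted name appears in the config,
        # build a replacement layer and attach it to the parent module.
        for module_name, _ in list(model.named_modules()):
            if module_name in quant_conf:
                new_module = make_layer(**quant_conf[module_name])
                parent_name, _, child_name = module_name.rpartition(".")
                parent = model.get_submodule(parent_name) if parent_name else model
                setattr(parent, child_name, new_module)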
@@ -44,6 +44,7 @@ def make_quant_linear(module, quant_conf, name="", target_layer=None):


class AutoModelForCausalLM(transformers.AutoModelForCausalLM):

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
init_contexts = [
@@ -74,9 +75,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
checkpoint = pretrained_model_name_or_path
else: # remote
token_arg = {"token": kwargs.get("token", None)}
checkpoint = huggingface_hub.snapshot_download(
repo_id=pretrained_model_name_or_path, ignore_patterns=["*.bin"], **token_arg
)
checkpoint = huggingface_hub.snapshot_download(repo_id=pretrained_model_name_or_path,
ignore_patterns=["*.bin"],
**token_arg)
weight_bins = glob.glob(str(Path(checkpoint).absolute() / "*.safetensors"))
index_json = glob.glob(str(Path(checkpoint).absolute() / "*.index.json"))
pytorch_model_bin = glob.glob(str(Path(checkpoint).absolute() / "pytorch_model.bin"))
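
The remote branch above resolves the checkpoint with huggingface_hub.snapshot_download and then globs for safetensors shards; a minimal standalone sketch (the repo id below is a placeholder):

    import glob
    from pathlib import Path

    import huggingface_hub

    checkpoint = huggingface_hub.snapshot_download(
        repo_id="some-org/some-vptq-model",   # placeholder repo id
        ignore_patterns=["*.bin"],            # skip legacy .bin weights, keep safetensors
    )
    weight_bins = glob.glob(str(Path(checkpoint).absolute() / "*.safetensors"))
    print(f"{len(weight_bins)} safetensors shard(s) under {checkpoint}")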