
Commit

fix: a small bug fix for the initialization of the residual index tensor. (#147)

* Fixed a small bug in the initialization of the residual index tensor.

* Modified the README so that long lines of code no longer overflow the display width.
lcy-seso authored Dec 26, 2024
1 parent c951bf5 commit 170770c
Showing 6 changed files with 168 additions and 147 deletions.
31 changes: 20 additions & 11 deletions README.md
@@ -9,7 +9,6 @@

**Efficient and flexible LLM compression in less than 2 bits**


[Get Started](#installation) | [Technical Report](https://arxiv.org/pdf/2409.17066)

</div>
@@ -39,7 +38,7 @@

## TL;DR

**Vector Post-Training Quantization (VPTQ)** is a novel post-training quantization method that leverages **Vector Quantization** to achieve high accuracy on LLMs at an extremely low bit-width (<2-bit).
VPTQ can compress a 70B, or even a 405B, model to 1-2 bits without retraining while maintaining high accuracy.
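
To make the idea concrete, here is a minimal, illustrative sketch of how a vector-quantized weight row can be reconstructed from an index tensor plus a centroid codebook (and an optional residual codebook). The names, shapes, and the scalar scale/bias are simplifying assumptions for illustration only, not the actual VPTQ storage layout:

```python
import torch

# Illustrative sketch only -- names and shapes are assumptions, not VPTQ's real layout.
def dequantize_vq(indices, centroids, res_indices=None, res_centroids=None,
                  scale=1.0, bias=0.0):
    # indices: [num_groups] integer ids into centroids of shape [k, vector_len]
    vectors = centroids[indices]                        # [num_groups, vector_len]
    if res_indices is not None:
        vectors = vectors + res_centroids[res_indices]  # add residual codebook entry
    weight_row = vectors.reshape(-1)                    # concatenate groups into one row
    return weight_row * scale + bias                    # undo the stored normalization

k, vector_len, num_groups = 65536, 8, 16
centroids = torch.randn(k, vector_len)
indices = torch.randint(0, k, (num_groups,))
print(dequantize_vq(indices, centroids).shape)          # torch.Size([128])
```

Only the integer indices and the small codebooks need to be stored, which is how the average bit-width drops below 2 bits.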

* Better accuracy at 1-2 bits (405B @ <2-bit, 70B @ 2-bit)
@@ -180,37 +179,47 @@ python -m vptq --model=VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-wo
```
![Llama3 1-70b-chat](https://github.com/user-attachments/assets/af051234-d1df-4e25-95e7-17a5ce98f3ea)


### Huggingface Transformers API Example:

**The Hugging Face Transformers main branch now supports VPTQ**:

```python
#! pip install git+https://github.com/huggingface/transformers.git -U
#! pip install vptq -U

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "VPTQ-community/Meta-Llama-3.3-70B-Instruct-v16-k65536-65536-woft"
# Load VPTQ-quantized model directly from HuggingFace Hub
model = AutoModelForCausalLM.from_pretrained("VPTQ-community/Meta-Llama-3.3-70B-Instruct-v16-k65536-65536-woft", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("VPTQ-community/Meta-Llama-3.3-70B-Instruct-v16-k65536-65536-woft")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Simple inference
prompt = "Explain: Do not go gentle into that good night."
output = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device))
output = model.generate(
**tokenizer(prompt, return_tensors="pt").to(model.device)
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```


### Python API Example from VPTQ package:
Using the Python API from the VPTQ package:

```python
import vptq
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft")
m = vptq.AutoModelForCausalLM.from_pretrained("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft", device_map='auto')

inputs = tokenizer("Explain: Do Not Go Gentle into That Good Night", return_tensors="pt").to("cuda")
out = m.generate(**inputs, max_new_tokens=100, pad_token_id=2)
model_name = "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
m = vptq.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

prompt = "Explain: Do Not Go Gentle into That Good Night"
out = m.generate(
**tokenizer(prompt, return_tensors="pt").to("cuda"),
max_new_tokens=100,
pad_token_id=2
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

113 changes: 55 additions & 58 deletions csrc/dequant_impl_packed.cu
@@ -33,7 +33,8 @@ __global__ void WqA16WithOutliers_PackIndice(
const scalar_t* bias, int out_features, int in_features,
int outliers_infeatures, const int index_stride_0, const int index_stride_1,
const int centroids_stride_0, const int group_nums) {
static_assert((GROUPSIZE & 1) == 0, "GROUPSIZE must be even ");
static_assert((GROUPSIZE & 1) == 0, "GROUPSIZE must be even.");

int bidx = blockIdx.x; // out_features//base_groupsize
int bidy = blockIdx.y; // batch
int bidz = blockIdx.z; // segment in_features
@@ -46,6 +47,7 @@ __global__ void WqA16WithOutliers_PackIndice(
__shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / WARP_SIZE + 1];
scalar_t tmp_output[GROUPSIZE];
const scalar_t zero_value = ZERO_VALUE(scalar_t());

#pragma unroll
for (int i = 0; i < GROUPSIZE; i++) {
tmp_output[i] = zero_value;
@@ -147,12 +149,7 @@ __global__ void WqA16WithOutliers_PackIndice(
} else {
hres_ptr = (VecType*)base;
}
// scalar_t* res = (scalar_t*)hres;
// #pragma unroll
// for (int gi=0;gi<GROUPSIZE;gi++){
// tmp_output[gi] = __hfma(res[gi], input_v, tmp_output[gi]);
// tmp_output[gi] += bias;
// }

VecType* h2_tmp_output = (VecType*)tmp_output;
#pragma unroll
for (int gi = 0; gi < GROUPSIZE / 2; gi++) {
@@ -315,18 +312,25 @@ __global__ void DequantizeWithOutliers_PackIndice(
}
}

torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, //[out_f, in_f]
const torch::Tensor& q_indice, //[num_cen, o_c_size, in_inf]
const torch::Tensor& centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>&
q_indice_residual, //[num_cen, o_c_size, in_inf]
const c10::optional<torch::Tensor>&
residual_centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>&
outliers_indices, //[num_cen, c_size, ol_in_f]
const c10::optional<torch::Tensor>&
outliers_centroids, //[num_c, c_size, out_vec_len]
// @brief launch_deqantize_outliers_cuda_packkernel
// @param outf_x_inf [out_f, in_f]
// @param q_indice [num_cen, o_c_size, in_inf]
// @param centroids [num_c, c_size, vec_len]
// @param q_indice_residual [num_cen, o_c_size, in_inf]
// @param residual_centroids [num_c, c_size, vec_len]
// @param outliers_indices [num_cen, c_size, ol_in_f]
// @param outliers_centroids [num_c, c_size, out_vec_len]
// @param perm optional permutation applied along the in_features dimension
// @param weight_scale scale applied to the dequantized weights
// @param weight_bias bias added to the dequantized weights
// @return the dequantized weight tensor
torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, const torch::Tensor& q_indice,
const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
const c10::optional<torch::Tensor>& outliers_indices,
const c10::optional<torch::Tensor>& outliers_centroids,
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias) {
OptionalCUDAGuard cudaguard(q_indice.device().index());
@@ -335,8 +339,9 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
residual_centroids.has_value() ? residual_centroids.value().size(-1) : 0;
TORCH_CHECK(((res_groupsize == base_groupsize) || (res_groupsize == 0)),
"res_groupsize==base_groupsize is false, must be true");
int index_bits =
log2(centroids.size(1)); // how many bits to index quantization vector

// how many bits to index quantization vector
int index_bits = log2(centroids.size(1));
int res_index_bits = residual_centroids.has_value()
? log2(residual_centroids.value().size(1))
: 0;
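  // e.g., a 65536-entry codebook needs log2(65536) == 16 index bits; the residual
  // codebook, when present, contributes res_index_bits more bits per vector.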
@@ -346,8 +351,10 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
cuda::kBlockSize));
dim3 threads(cuda::kBlockSize);
torch::Tensor output;
constexpr bool out_ouf_inf = true; // why =false is 10 times slow?
if (out_ouf_inf) { // out_ouf_inf

// FIXME: why =false is 10 times slow?
constexpr bool out_ouf_inf = true;
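  // When out_ouf_inf is true, the output is allocated in the [out_features, in_features]
  // layout of outf_x_inf; when false, it is allocated transposed.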
if (out_ouf_inf) { // out_ouf_inf
output = at::empty({out_size[0], out_size[1]}, centroids.options());
} else {
output = at::empty({out_size[1], out_size[0]}, centroids.options());
@@ -364,7 +371,9 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
outliers_indices.has_value()
? outliers_indices.value().data_ptr<int16_t>()
: nullptr;

auto stream = at::cuda::getCurrentCUDAStream().stream();

#define callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, \
ResidualBits) \
{ \
@@ -510,18 +519,27 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
}
}

torch::Tensor lauch_gemv_outliers_cuda_packkernel(
// @brief launch_gemv_outliers_cuda_packkernel
// @param out_features number of output features of the layer
// @param input input activation tensor, shape [..., in_features]
// @param q_indice [num_cen, o_c_size, in_inf]
// @param centroids [num_c, c_size, vec_len]
// @param q_indice_residual [num_cen, o_c_size, in_inf]
// @param residual_centroids [num_c, c_size, vec_len]
// @param outliers_indices [num_cen, c_size, ol_in_f]
// @param outliers_centroids [num_c, c_size, out_vec_len]
// @param perm optional permutation applied along the in_features dimension
// @param weight_scale scale applied to the dequantized weights
// @param weight_bias bias added to the dequantized weights
// @param bias optional bias added to the layer output
// @return output activations, shape [..., out_features]
torch::Tensor launch_gemv_outliers_cuda_packkernel(
const int out_features, const torch::Tensor& input,
const torch::Tensor& q_indice, //[num_cen, o_c_size, in_inf]
const torch::Tensor& centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>&
q_indice_residual, //[num_cen, o_c_size, in_inf]
const c10::optional<torch::Tensor>&
residual_centroids, //[num_c, c_size, vec_len]
const c10::optional<torch::Tensor>&
outliers_indices, //[num_cen, c_size, ol_in_f]
const c10::optional<torch::Tensor>&
outliers_centroids, //[num_c, c_size, out_vec_len]
const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
const c10::optional<torch::Tensor>& outliers_indices,
const c10::optional<torch::Tensor>& outliers_centroids,
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias) {
@@ -533,17 +551,16 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
: 0;

const int in_features = input.size(-1);
// const int out_features = output.size(-1);

auto output_shape = input.sizes().vec();
output_shape[input.dim() - 1] = out_features;
torch::Tensor output;
// blocks = (out_features, batch)

dim3 blocks(cuda::ceil_div(out_features, base_groupsize),
input.numel() / in_features);
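  // grid.x tiles out_features in chunks of base_groupsize; grid.y enumerates the
  // flattened batch rows (input.numel() / in_features).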
dim3 threads(cuda::kBlockSize);
auto stream = at::cuda::getCurrentCUDAStream().stream();
// using scalar_t = c10::Half;
// c10::BFloat16

int shared_memory_size = 2 * in_features * 2;
const int outliers_indices_size_n1 =
outliers_indices.has_value() ? outliers_indices.value().size(-1) : 0;
@@ -681,26 +698,6 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
}

if (in_features <= cuda::kBlockSize) {
// output = at::empty(output_shape, centroids.options());
// switch (base_groupsize){
// case 16:
// gpuErrchk(cudaFuncSetAttribute(WqA16WithOutliers<scalar_t, 16, 4,
// false>,
// cudaFuncAttributeMaxDynamicSharedMemorySize,
// shared_memory_size));
// DispatchWqA16Kernel(output, 16, false);
// break;
// case 12:
// gpuErrchk(cudaFuncSetAttribute(WqA16WithOutliers<scalar_t, 12, 4,
// false>,
// cudaFuncAttributeMaxDynamicSharedMemorySize,
// shared_memory_size));
// DispatchWqA16Kernel(output, 12, false);
// break;
// default:
// TORCH_CHECK(false, "un-supported
// base_groupsize:"+std::to_string(base_groupsize));
// }
TORCH_CHECK(false, "un-supported yet");
} else {
constexpr int do_reduce = 4;
Expand Down

