[RFC] torchao Contributor Guide #391
Regarding 1, apart from the feedback I gave in #384, I'm starting to think of another alternative:
quantizer = Int4WeightOnlyQuantizer(groupsize=32)
quantizer.quantize(model)
But then this feels like the old API. Personally I don't really like a function returning a function, like the current
Another option is to expose
from functools import partial
quantize(model, partial(apply_int4wo_quant, groupsize=32))
Also, since the quantization is in-place, I think it's good to use |
For the manual API why have both a string and a |
Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries. If not, happy to work on one. |
The quantizer API is actually what I had been thinking of before as the "Unified Quantization API": https://github.com/pytorch/ao/blob/main/torchao/quantization/unified.py, and these two APIs will cover most of the current quant flows. It's also used by the QAT prototype (ao/torchao/quantization/prototype/qat.py, line 22 at commit d0af941).
The partial function idea has been raised in our meetings before as well, but it also doesn't seem very straightforward to use. For now I'm planning to just use
Also, in the ideal future I think we'd expect modeling users to just use autoquant and not worry about all these details |
so the motivation for string is so that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names |
Not yet. My understanding is that this doc talks about how we build the fundamental "dtype" of quantization, which can serve as a building block for more sophisticated quantization methods that use the "dtype" as a data representation. I'm planning to put up an example of static quant (with module swap) that could help demonstrate how these other techniques (e.g. ones that require calibration) can be implemented in similar ways. Please feel free to work on a tutorial showing what a real world end to end quantization example looks like, utilizing the "dtype" that we build with tensor subclasses in this doc. We also plan to build out hqq with this design (#255, cc @HDCharles); this one doesn't require calibration either, though. |
But they are already importing the |
yeah, we are thinking of just removing these for now, it would be better for people to also see the docstrings for these things, and an extra import doesn't seem to be a big issue |
About subclasses: I hope there would still be a way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (seemingly used mainly for representation/expressiveness/dispatch purposes) will not make this more complicated. I'm also hoping that it will work nicely with profiling/tracing, so we know exactly which kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes). This is kind of similar to what was originally done with the quint8 dtype, right? (except now it will allow user-powered extension, and dispatch is based on subclass type instead of dtype) |
yeah I think we should still be able to register inductor fusion passes, but one thing here is, q/dq ops are no longer large ops in the torch.compile path, we are planning to keep them as smaller aten ops (sub/mul etc.) so these can participate in normal inductor optimization directly, so the optimization story will be a bit different for inductor/torch.compile I think. However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434
yeah we can definitely provide additional information on what kernel is picked for autoquant, cc @HDCharles
yes, this is similar to quint8, except it's built in python with tensor subclasses extension point, this allows us to stay out of core and have faster iteration speed as well. for dispatch, I feel it could also continue to use dtype as well, after we sort out the dtype story: #442 |
Summary: Addressing feedback for `quantize` API from pytorch#391 (comment) this is an API that changes model inplace, so we want to change the name to reflect that. inplace model quantization is important especially for LLM since it will be hard to load the model to memory. we typically load the model to meta device and then load the quantized weight. Test Plan: python test/quantization/test_quant_api.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags:
Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right? |
So what I understand from this proposal, as far as wrapping LayoutTensor and DTypeTensor is concerned is that, A. Static quantization (both activation and weights are quantized) It is not clear how the proposed API addresses 1, but I presume you have ideas so I will assume it will work. Tensor subclass as I understand does/can do two things: 1) override representation of the tensor, e.g. linear.weight changed from torch.Tensor to DTypeTensor and 2) also change the dispatch behavior to dictate how an op with DTypeTensor should be executed. On the DTypeLayout: I feel that having each backend or kernel that has its own special layout for execution should be its own tensor subclass, however this can also result in proliferation, e.g. DTypeLayoutCUDA, DTypeLayoutCUDAMySecialPacking, DTypeLayoutMetalDefault etc. I actually liked PT2E workflow in this regard where representation was canonical and execution semantics, arising from weight packing etc, were done as a separate transform. If I were to think of the same here, then I would say for 4-bit there is DTypeTensor and DTypeDefaultLayout and subsequent transforms can replace the tensor subclass with their backend specific tensor subclass. Separate from above: For the comment on using q-dq based dispatch vs. fused op, I think we can allow overriding behavior where users can plugin their own implementation, including custom fused ops, for a specific DTypeTensor subclass that uses a specific DTypeLayout tensor. |
yeah this is correct
yeah working on an example for this right now
I should probably add more docs for this one. Right now it's implemented by applying a LinearActQuantizedTensor (ao/torchao/quantization/quant_api.py, lines 355 to 356 at commit a895699). When dispatching to the linear op, we apply input_quant_func to the input and then continue the dispatch (ao/torchao/quantization/subclass.py, line 657 at commit a895699; ao/torchao/dtypes/affine_quantized_tensor.py, lines 550 to 554 at commit a895699).
also I want to highlight that dynamic quant, static quant is not considered as purely a dtype problem, since this also involves flows (how to convert my model to use these quantized tensors?), I'm also working on giving more details/examples of how to do that as well.
yeah I think so, user should be able to customize what they would like to say by implementing a new LayoutTensor type I think, although I guess the difference here is user has to reason through different dispatch layers to figure out what is the final representation they will see in the end, like the dynamic quant example. |
@jerryzh168 please note that my questions/responses are not motivated by whether it works for executorch or not. My comment on canonical representation was to borrow the same concept from PT2E, where quantization and execution of quantized ops are separated. In the currently proposed APIs that is not the case, and that's what I was highlighting |
And this I mean for eager model not for export. Basically in exported graph there is a) quant and b) lowering. What is the equivalent of that in eager mode subclass based API and whether it is useful to have that |
@kimishpatel, I see, yeah I think separation of quant and lowering makes sense for executorch stack, but for eager it is not really applicable, since in eager people would just expect to quantize a model and get acceleration, require eager mode use case to do an extra lowering step seems to change the UX for eager mode? what do you think? |
Summary: Moved notebook: https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA from pytorch#391 to `tutorials` folder so that code can be executed while we develop new APIs/utils and being kept up to date Test Plan: python Reviewers: python tutorials/developer_api_guide.py regression tests: python test/quantization/test_quant_api.py python test/integration/test_integraton.py Subscribers: Tasks: Tags:
Summary: 1. updated torchao api reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py and removed the safe_int_mm and int_scaled_matmul from quant_primitives.py 2. added pytorch#391 to torchao docs Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
* Update torchao api reference and add contributor guide Summary: 1. updated torchao api reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py and removed the safe_int_mm and int_scaled_matmul from quant_primitives.py 2. added #391 to torchao docs Test Plan: CI Reviewers: Subscribers: Tasks: Tags: * format * typo * renaming * comma * format * comments
Summary: In executorch this path was slower compared to using optimized ops. We did not debug that further, but for now just disabled it. Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
Added to torchao docs: https://pytorch.org/ao/stable/contributor_guide.html
Status: Draft
Updated: 09/18/2024
Objective
In this doc we’ll talk about how different optimization techniques are structured in torchao and how to contribute to torchao.
torchao Stack Overview
First we want to lay out the torchao stack:
Any quantization algorithm will be using some components from the above stack. For example, int4_weight_only quantization uses:
(1) the weight only quantization flow
(2) the tinygemm bf16 activation + int4 weight kernel and quant primitive ops
(3) the AffineQuantizedTensor tensor subclass with TensorCoreTiledLayout
(4) the torch.uint4 dtype (simulated with quant_min/quant_max right now)
Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section.
Basic DTypes
dtype is a bit of an overloaded term. By basic dtype, we mean dtypes that make sense without any extra metadata (e.g. they make sense when people call torch.empty(.., dtype)); for more details please check out: https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833
No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data. The dtypes we aim to support in torchao are:
Note some of the above are prototype only for now. We'll consider adding them to PyTorch core when they become popular and have hardware support.
Current Support
In terms of actual implementation, there are two parts:
1). In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2 (example: pytorch/pytorch#117208), but these are just placeholders so that we can use torch.uint2.
2). Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses; a standard packing format is also needed.
Adding placeholder dtype in PyTorch
As mentioned in https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criterion for adding a dtype to PyTorch is that it shows wide adoption. For the above mentioned fundamental dtypes, the ones that are supported in PyTorch are:
We may add torch.int2 to torch.int7 to PyTorch soon due to requests from the edge team, but for the other types we plan to wait until there is more evidence of wide adoption and hardware support.
Implementing tensor operations for these dtypes with Tensor subclasses
For this, the requirement is that we decide on a "standard" packing format, ideally one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrated enough kernels to decide on this yet. So current packing implementations (e.g. ao/torchao/dtypes/uintx/uintx.py, line 36 at commit d2bce6a)
Quantization Primitive Ops / Efficient Kernels
Quantization Primitive Ops
Quantization primitive ops are the operators used to convert between low precision quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators:
choose_qparams ops: choose quantization parameters (e.g. scale and zero_point for affine quantization) based on the original Tensor; typically used in dynamic quantization
quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters
dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters
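For intuition, here is a minimal sketch of these three primitives for per-tensor asymmetric affine quantization, written in plain PyTorch. It assumes round-to-nearest, clamping to [quant_min, quant_max] and a range that always includes 0; the function names and signatures are illustrative only and are not torchao's quant_primitives API.

```python
import torch


def choose_qparams(x: torch.Tensor, quant_min: int = 0, quant_max: int = 255):
    """Pick scale/zero_point so that [min(x), max(x)] maps onto [quant_min, quant_max]."""
    # Extend the range to include 0 so that 0.0 is exactly representable.
    min_val = torch.clamp(x.min(), max=0.0)
    max_val = torch.clamp(x.max(), min=0.0)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = quant_min - torch.round(min_val / scale)
    zero_point = torch.clamp(zero_point, quant_min, quant_max).to(torch.int32)
    return scale, zero_point


def quantize(x, scale, zero_point, quant_min=0, quant_max=255, dtype=torch.uint8):
    # low_precision_val = round(high_precision_val / scale) + zero_point, clamped to range
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)


def dequantize(q, scale, zero_point, dtype=torch.float32):
    # high_precision_val ~= (low_precision_val - zero_point) * scale
    return (q.to(torch.int32) - zero_point).to(dtype) * scale


x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
scale, zp = choose_qparams(x)
q = quantize(x, scale, zp)
x_hat = dequantize(q, scale, zp)
```

The round trip error of each element is bounded by half a quantization step (scale / 2) as long as the value lies inside the observed range.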
There could be variations of the above to accommodate specific use cases; for example, for static quantization we may have choose_qparams_affine_with_min_max, which chooses quantization parameters based on min/max values derived from the observation process.
Efficient kernels
We'll also have efficient kernels that work with the low precision tensors, for example:
_weight_int4pack_mm the tinygemm int4 kernel (bf16 activation + int4 weight)
int_matmul that takes two int8 tensors and outputs an int32 tensor
int_scaled_matmul that does matmul and also applies a scale to the result.
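To make the semantics of the last two kernels concrete, the following is a slow CPU reference written with plain PyTorch ops. It assumes int32 accumulation and a per-row scale of shape [M, 1]; the ref_* names are hypothetical, and the real kernels in torchao may differ in signature and shape requirements.

```python
import torch


def ref_int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """int8 x int8 -> int32 matmul; accumulate in int32 to avoid int8 overflow."""
    assert a.dtype == torch.int8 and b.dtype == torch.int8
    return torch.matmul(a.to(torch.int32), b.to(torch.int32))


def ref_int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
    """int matmul followed by a per-row scale (scales1 has shape [M, 1])."""
    c = ref_int_matmul(a, b)
    return c.to(scales1.dtype) * scales1


a = torch.randint(-100, 100, (4, 8), dtype=torch.int8)
b = torch.randint(-100, 100, (8, 3), dtype=torch.int8)
scales1 = torch.rand(4, 1)
out = ref_int_scaled_matmul(a, b, scales1)
```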
Note: We can also rely on torch.compile to generate kernels (through triton); for example, the current int8 weight only quantization kernel just relies on torch.compile to get speedup. In this case there is no specific "efficient kernel" corresponding to the type of quantization.
Quantized Tensors (derived dtypes)
On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor. It can be constructed from a high precision Tensor and some parameters that configure the specific quantization the user wants. We can also call these derived dtypes, since they can be represented with Tensors of basic dtypes plus some extra metadata like scale.
An existing example in torchao is AffineQuantizedTensor, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: low_precision_val = high_precision_val / scale + zero_point, where scale/zero_point are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward: when we try to map from higher precision values to lower precision values, we do an affine transformation (high_precision_val / scale + zero_point). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit), is look up table based quantization.
Layout and Packing
Native tensors have a hardcoded list of layouts: https://github.com/pytorch/pytorch/blob/647815049ec28a72dc1bb6a977791927bba058d5/c10/core/Layout.h#L11. The most common one is the strided layout, which provides a strided, multi-dimensional view of storage; there are also some sparse and mkldnn layouts.
The idea of packing the tensor into different formats fits nicely with the layout concept, that’s why we want to reuse this for packing. And the extension of layout can be achieved at python level tensor subclasses without modifying C++ pytorch core code.
We use this to support different ways that the same quantized Tensor can be packed for efficient execution. For example, for _weight_int4pack_mm we need to pack the weight into a format that is friendly for Tensor Core; we call it TensorCoreTiledLayoutType. We add a layout_tensor to the quantized tensor to store the packed (or unpacked) weight, and we use a layout_type to store the different parameters that are relevant for packing.
Note that layout is an abstraction not only for custom data representation; it is also used for how the layout Tensor interacts with different operators. For example, the same data representation can have different implementations when running the same operator (e.g. transpose, quantized_linear), even though the operator semantics stay the same.
Quantize + Sparse Tensor can also be supported through the Layout abstraction, for example, int4 weight only quantization + sparse. We also provide some common utils that help people add different layouts to a quantized tensor; please check out the developer guide below for code examples.
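As a toy illustration of what packing means, the sketch below packs pairs of 4-bit values into single uint8 bytes and unpacks them again. The nibble order (even positions in the high nibble) is an arbitrary assumption for this example; it is not the format used by TensorCoreTiledLayout or any torchao kernel.

```python
import torch


def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit values (stored one per uint8, last dim even) into one byte each."""
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    assert int(x.max()) < 16, "values must fit in 4 bits"
    high = x[..., ::2]  # even positions go to the upper nibble
    low = x[..., 1::2]  # odd positions go to the lower nibble
    return (high << 4) | low


def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_uint4: split every byte back into two 4-bit values."""
    high = packed >> 4
    low = packed & 0x0F
    # Interleave (high, low) pairs back into the original order along the last dim.
    return torch.stack([high, low], dim=-1).flatten(start_dim=-2)


w = torch.randint(0, 16, (2, 6), dtype=torch.uint8)
packed = pack_uint4(w)  # half the bytes of w
```

A real layout would choose the packing order to match what the target kernel reads, which is exactly the kind of detail the layout abstraction is meant to hide from the rest of the stack.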
Quantization Algorithms/Flows
On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up.
For demonstration purposes, let's say after the previous step we have an AffineQuantizedTensor and a to_affine_quantized factory function defined. For simplicity, let's say to_affine_quantized takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an AffineQuantizedTensor with the corresponding dtype.
Note: the below is all for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the Tensor Subclass Developer Guide section.
Weight Only Quantization
This is the simplest form of quantization, and it's easy to apply weight only quantization to the model, especially since we have the Quantized Tensor. All we need to do is:
Apply the above to all linear modules in the model and we'll get a weight only quantized model.
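A minimal sketch of this idea, assuming a hypothetical Int8WeightOnlyTensor in place of AffineQuantizedTensor / to_affine_quantized: it stores symmetric per-output-channel int8 data plus a scale, and intercepts F.linear in __torch_function__ to dequantize the weight on the fly. It only handles what this example needs (F.linear, plus detach so the tensor can be wrapped in nn.Parameter); a real subclass has to support many more ops.

```python
import torch
import torch.nn.functional as F


class Int8WeightOnlyTensor(torch.Tensor):
    """Hypothetical weight-only quantized tensor: int8 data + per-output-channel scale."""

    @staticmethod
    def __new__(cls, int_data, scale, orig_dtype):
        # The wrapper only advertises shape/dtype/device; the real data lives in attributes.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=orig_dtype, device=int_data.device
        )

    def __init__(self, int_data, scale, orig_dtype):
        self.int_data = int_data
        self.scale = scale
        self.orig_dtype = orig_dtype

    @classmethod
    def from_float(cls, w):
        # Symmetric per-row (output channel) quantization to [-127, 127].
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        int_data = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
        return cls(int_data, scale, w.dtype)

    def dequantize(self):
        return self.int_data.to(self.orig_dtype) * self.scale

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Weight-only: dequantize the weight, keep the activation in high precision.
            return F.linear(x, w.dequantize(), bias)
        if func is torch.Tensor.detach:
            # nn.Parameter(...) calls detach(); rebuild the subclass so the payload survives.
            w = args[0]
            return cls(w.int_data, w.scale, w.orig_dtype)
        # Everything else (e.g. reading .grad_fn, calling requires_grad_) keeps default behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The wrapper has no real storage, so refuse any aten op we have not implemented.
        raise NotImplementedError(f"{cls.__name__} does not implement {func}")


def apply_int8_weight_only(model: torch.nn.Module) -> torch.nn.Module:
    # Swap the weight (not the module) of every nn.Linear for a quantized tensor.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            qweight = Int8WeightOnlyTensor.from_float(module.weight.detach())
            module.weight = torch.nn.Parameter(qweight, requires_grad=False)
    return model


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
x = torch.randn(2, 16)
with torch.no_grad():
    ref = model(x)
    apply_int8_weight_only(model)
    out = model(x)
```

Note that the model code itself is untouched: nn.Linear still calls F.linear, and the subclass decides what that call means.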
Dynamic Activation Quantization + Weight Quantization
This used to be called "dynamic quantization"; it means we quantize activations dynamically at runtime, and also quantize the weights. Compared to weight only quantization, the main question is how we apply the quantization to the activation. In torchao, the common pattern we use is to apply to_linear_activation_quantized on top of the quantized weight:
to_linear_activation_quantized is used to apply quantization to the activation. It takes an input_quant_func that will quantize the activation, and the original weight; at runtime, when it encounters an F.linear op, it will apply the stored input_quant_func to the activation and redispatch to F.linear with the quantized activation and weight.
If the above does not work, users can also do module swaps, or use torch.export.unflatten() to get a traced module that they can modify.
But using a tensor subclass is preferred because it makes serialization/deserialization easier: if we use tensor subclasses to support dynamic quantization, we can load the quantized weights directly without further preparation of the model. Otherwise, we'd need to do a module swap or other modifications to the model before loading the quantized weights.
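The to_linear_activation_quantized pattern can be sketched as follows. LinearActivationQuantizedSketch and per_token_int8_fake_quant are hypothetical names; to keep the example short, the wrapped weight is a plain float tensor and the activation is fake-quantized (quantized, then dequantized). In torchao the wrapped weight would itself be a quantized tensor, so the redispatched F.linear would reach the quantized kernel.

```python
import torch
import torch.nn.functional as F


def per_token_int8_fake_quant(x):
    """Hypothetical input_quant_func: symmetric per-row int8 quantize, then dequantize."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(x / scale), -127, 127) * scale


class LinearActivationQuantizedSketch(torch.Tensor):
    """Wraps a weight and remembers how the linear's input activation should be quantized."""

    @staticmethod
    def __new__(cls, weight, input_quant_func):
        return torch.Tensor._make_wrapper_subclass(
            cls, weight.shape, dtype=weight.dtype, device=weight.device
        )

    def __init__(self, weight, input_quant_func):
        self.original_weight = weight
        self.input_quant_func = input_quant_func

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Quantize the activation at runtime, then redispatch F.linear on the inner weight.
            return F.linear(w.input_quant_func(x), w.original_weight, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"{cls.__name__} does not implement {func}")


torch.manual_seed(0)
weight = torch.randn(4, 8)
x = torch.randn(3, 8)
wrapped = LinearActivationQuantizedSketch(weight, per_token_int8_fake_quant)
out = F.linear(x, wrapped)
```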
Static Quantization
Static quantization means activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data so that we can figure out the appropriate quantization parameters.
At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model
Insert Observers
In the insert observers step, we need to add observer modules to the input (and output) activations and the weight of the operator to collect statistics of the Tensor. So there are two things we need to address: how to define an observer module, and how to add observer modules to the model.
How to define observer module
Observers are specific to: (1) type of quantization (e.g. affine quantization, look up table based quantization) (2) type of stats we want to track, e.g. min max observer, moving average observer.
Generally an observer module should define forward and calculate_qparams.
For affine quantization, we defined AffineQuantizedMinMaxObserver that records min_val/max_val based on the granularity of affine quantization, and also defines how to calculate_qparams based on the recorded stats.
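A minimal observer along these lines might look like the following. ToyMinMaxObserver is a hypothetical per-tensor stand-in, not torchao's AffineQuantizedMinMaxObserver; it assumes int8 asymmetric affine quantization and tracks only a running min/max.

```python
import torch


class ToyMinMaxObserver(torch.nn.Module):
    """Hypothetical per-tensor observer: tracks running min/max, derives affine qparams."""

    def __init__(self, quant_min: int = -128, quant_max: int = 127):
        super().__init__()
        self.quant_min, self.quant_max = quant_min, quant_max
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, x):
        # Record statistics and pass the input through unchanged.
        self.min_val = torch.minimum(self.min_val, x.detach().min())
        self.max_val = torch.maximum(self.max_val, x.detach().max())
        return x

    def calculate_qparams(self):
        # Include 0 in the range so that 0.0 stays exactly representable.
        min_val = torch.clamp(self.min_val, max=0.0)
        max_val = torch.clamp(self.max_val, min=0.0)
        scale = (max_val - min_val) / float(self.quant_max - self.quant_min)
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        zero_point = self.quant_min - torch.round(min_val / scale)
        zero_point = torch.clamp(zero_point, self.quant_min, self.quant_max).to(torch.int32)
        return scale, zero_point


obs = ToyMinMaxObserver()
obs(torch.tensor([-1.0, 0.5]))
obs(torch.tensor([2.0]))
scale, zero_point = obs.calculate_qparams()
```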
How to add observer module to the model
Use Tensor Subclasses
If the only operator you are interested in quantizing is linear, you can use the linear activation weight observer; we also have a corresponding insert_observer_ API that handles modifying the weight of linear.
Module swap
Alternatively, you could also define an ObservedLinear module (or other module types) and swap the non-observed module with the observed module.
Calibration
The calibration step is typically straightforward: we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in the next section.
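The module swap route plus a calibration loop can be sketched as below. ObservedLinear, ToyObserver and insert_observers_ are hypothetical names defined here for illustration (they are not torchao's insert_observer_ API), and the observers only track per-tensor min/max.

```python
import torch


class ToyObserver(torch.nn.Module):
    """Records the running min/max of whatever passes through (per-tensor)."""

    def __init__(self):
        super().__init__()
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, x):
        self.min_val = torch.minimum(self.min_val, x.detach().min())
        self.max_val = torch.maximum(self.max_val, x.detach().max())
        return x


class ObservedLinear(torch.nn.Linear):
    """nn.Linear that observes its input activation and its weight."""

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new.weight, new.bias = mod.weight, mod.bias  # reuse the original parameters
        new.act_obs, new.weight_obs = ToyObserver(), ToyObserver()
        return new

    def forward(self, x):
        self.weight_obs(self.weight)
        return torch.nn.functional.linear(self.act_obs(x), self.weight, self.bias)


def insert_observers_(model: torch.nn.Module) -> torch.nn.Module:
    # Recursively swap plain nn.Linear children for ObservedLinear.
    for name, child in list(model.named_children()):
        if type(child) is torch.nn.Linear:
            setattr(model, name, ObservedLinear.from_float(child))
        else:
            insert_observers_(child)
    return model


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
insert_observers_(model)
calibration_data = [torch.randn(4, 8) for _ in range(10)]
with torch.no_grad():
    for batch in calibration_data:
        model(batch)  # calibration: just run the data through the observed model
```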
Quantize
We can reuse the quantize_ API but provide a different apply_tensor_subclass function that converts the observed linear module to a linear module with quantized weight and statically quantized input activation. This can be done in the same manner as dynamic quantization (with to_linear_activation_quantized); see example.
Alternatively, users can do a module swap as well.
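As a sketch of the module swap alternative (not the quantize_ + to_linear_activation_quantized route), the converted module below fixes the activation qparams at conversion time from calibration stats, so no min/max is computed at runtime. StaticQuantLinear and affine_qparams are hypothetical; for readability the int8 values are dequantized and the matmul runs in float, whereas a real implementation would call an int8 kernel.

```python
import torch
import torch.nn.functional as F


def affine_qparams(min_val, max_val, quant_min=-128, quant_max=127):
    """Per-tensor asymmetric qparams from calibration stats (plain Python floats)."""
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = max((max_val - min_val) / (quant_max - quant_min), 1e-8)
    zero_point = int(round(quant_min - min_val / scale))
    return scale, max(quant_min, min(quant_max, zero_point))


class StaticQuantLinear(torch.nn.Module):
    """Hypothetical static-quant linear: int8 weight + activation qparams fixed at convert time."""

    def __init__(self, linear: torch.nn.Linear, act_min: float, act_max: float):
        super().__init__()
        w = linear.weight.detach()
        # Weight: symmetric per-tensor int8.
        self.w_scale = max(w.abs().max().item(), 1e-8) / 127.0
        w_int8 = torch.clamp(torch.round(w / self.w_scale), -127, 127).to(torch.int8)
        self.register_buffer("w_int8", w_int8)
        bias = None if linear.bias is None else linear.bias.detach().clone()
        self.register_buffer("bias", bias)
        # Activation: qparams come from calibration, not from the runtime input.
        self.act_scale, self.act_zero_point = affine_qparams(act_min, act_max)

    def forward(self, x):
        x_q = torch.clamp(torch.round(x / self.act_scale) + self.act_zero_point, -128, 127)
        x_dq = (x_q - self.act_zero_point) * self.act_scale
        w_dq = self.w_int8.to(x.dtype) * self.w_scale
        return F.linear(x_dq, w_dq, self.bias)


torch.manual_seed(0)
float_linear = torch.nn.Linear(8, 4)
calibration_batch = torch.randn(64, 8)
act_min, act_max = calibration_batch.min().item(), calibration_batch.max().item()
qlinear = StaticQuantLinear(float_linear, act_min, act_max)
x = torch.randn(5, 8).clamp(act_min, act_max)
with torch.no_grad():
    err = (qlinear(x) - float_linear(x)).abs().max().item()
```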
Other Quantization Flows
For other quantization flows/algorithms that do not fit into any of the above, we also intend to provide examples for common patterns. For example, the GPTQ-like quantization flow adopted by Autoround uses MultiTensor and module hooks to optimize the module.
If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details.
Training
The above flows are mainly focused on inference, but low bit dtype Tensors can be used in training as well.
Quantization Aware Training
Low Bit Optimizers
Today we have some prototype low bit optimizers (https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) that implement a specific type of 4 bit, 8 bit and float8, and are also composable with FSDP (with look up table quantization). We can extend our AffineQuantizedTensor to be used in optimizers as well, following the example.
Quantized Training
Similar to low bit optimizers, we have a quantized training prototype in https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training, and we could extend AffineQuantizedTensor to support training as well. Initial enablement is in progress, but a lot of follow up work will be needed, including making it work for different kernels.
Case Study: How int4 weight only quantization works in torchao?
To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao.
High Level Summary
During Quantization
First we start with the API call:
quantize_(model, int4_weight_only())
What this does is convert the weights of nn.Linear modules in the model to int4 quantized tensors (AffineQuantizedTensor that is int4 dtype, asymmetric, per group quantized), using the layout for the tinygemm kernel: the "tensor_core_tiled" layout type.
During Model Execution
When we run the quantized model
model(inputs)
we'll run through the functional linear operator in nn.Linear:
return F.linear(input, self.weight, self.bias)
where input is a bfloat16 floating point Tensor and weight is an int4 AffineQuantizedTensor. This calls into the __torch_function__ of the AffineQuantizedTensor subclass, which ends up in an implementation for F.linear when one of the inputs is an AffineQuantizedTensor, so it calls:
The _quantized_linear_op goes through the _AQT_QLINEAR_DISPATCH_TABLE and checks each dispatch condition; if a dispatch condition passes, it calls the corresponding implementation with input/weight/bias. Please check out this doc for the explanation of dispatch_condition and impl.
In this case the dispatch_condition for the int4 weight only quantization kernel will be this, and the implementation we are using will be this: the function takes a bfloat16 input Tensor and an int4 AffineQuantizedTensor, and calls torch.ops.aten._weight_int4pack_mm with the input Tensor and the packed weight that's stored in weight_tensor.layout_tensor.
During Save/Load
Since the AffineQuantizedTensor weight is still a torch.Tensor, save/load works the same way as for the original high precision floating point model.
Tensor Subclass Developer Guide
We have covered the high level overview and how everything is connected together in the previous section. This section will focus on Tensor Subclasses, the main extension point we rely on to provide the flexibility of supporting inference, training and fine tuning with low precision Tensors, and composability with torch.compile, autograd and distributed primitives in these scenarios.
Prerequisites
Some externally available resources for tensor subclasses:
Why Tensor Subclass?
There are multiple ways people can implement quantization techniques or new dtypes. Our main motivations for recommending the tensor subclass based approach are:
(1). It's natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclasses means we are not introducing new concepts but reusing existing concepts, such as dtype and layout, that already exist in PyTorch core
(2). Since tensor subclass intercepts computation at torch function or aten ops level, as long as the same function/operator is used, we will be able to quantize the model. This allows the model that’s using variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization
(3). Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclass would make it easier for it to be composable with these techniques
Example Code for a new Quantization Technique or DType
Please feel free to start with https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py for an end-to-end working example that combines everything we talked about, and come back to this doc for clarifications and documentation.
Basic Structure
A tensor subclass needs to define a few basic methods: __new__, __init__, __tensor_flatten__, __tensor_unflatten__, and also dispatch functions for torch functions (__torch_function__) and aten ops (__torch_dispatch__).
Here is an example of basic structure:
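The following is a minimal, hypothetical skeleton in the spirit of the example file linked above (MyDTypeTensor here is our own simplified version, not the tutorial's code). It uses the wrapper subclass pattern, disables __torch_function__ so every op reaches __torch_dispatch__ as an aten op, handles aten.detach.default explicitly, and falls back to dequantizing for any other op.

```python
import torch

aten = torch.ops.aten


class MyDTypeTensor(torch.Tensor):
    """Hypothetical skeleton: int8 payload + per-tensor scale, presented as a float tensor."""

    @staticmethod
    def __new__(cls, int_data, scale, orig_dtype):
        # Wrapper subclass: the outer tensor only carries metadata (shape, dtype, device).
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=orig_dtype, device=int_data.device
        )

    def __init__(self, int_data, scale, orig_dtype):
        self.int_data = int_data
        self.scale = scale
        self.orig_dtype = orig_dtype

    def __tensor_flatten__(self):
        # (names of inner tensor attributes, extra non-tensor metadata)
        return ["int_data", "scale"], [self.orig_dtype]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        (orig_dtype,) = tensor_attributes
        return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], orig_dtype)

    def dequantize(self):
        return self.int_data.to(self.orig_dtype) * self.scale

    # Skip the torch-function layer so every op arrives in __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            t = args[0]
            return cls(t.int_data, t.scale, t.orig_dtype)
        # Fallback for this sketch: dequantize top-level subclass args, run in high precision.
        args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)


x = MyDTypeTensor(torch.tensor([[1, -2], [3, 4]], dtype=torch.int8), torch.tensor(0.5), torch.float32)
names, attrs = x.__tensor_flatten__()
y = MyDTypeTensor.__tensor_unflatten__({n: getattr(x, n) for n in names}, attrs, x.shape, None)
```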
Operator Support
There are two types of operator support: torch functions and aten ops. For torch functions (e.g. torch.nn.functional.linear), we'll need to override the __torch_function__ callback in the Tensor subclass; for aten ops (e.g. torch.ops.aten.mm), we'll need to override the __torch_dispatch__ callback function.
For a new dtype, we'd like people to define the following decorator:
And we can implement the operator dispatch with the following:
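A minimal sketch of the decorator-plus-table pattern, assuming a hypothetical ToyQuantTensor with an int8 payload and a per-tensor scale: implements registers one Python function per aten op in _ATEN_OP_TABLE, and __torch_dispatch__ looks the op up there. The table and tensor names are illustrative only.

```python
import torch

aten = torch.ops.aten
_ATEN_OP_TABLE = {}


def implements(aten_ops):
    """Decorator: register `fn` as the implementation of each aten op in `aten_ops`."""
    def decorator(fn):
        for op in aten_ops:
            _ATEN_OP_TABLE[op] = fn
        return fn
    return decorator


class ToyQuantTensor(torch.Tensor):
    """Hypothetical int8 payload + per-tensor scale, dispatched through _ATEN_OP_TABLE."""

    @staticmethod
    def __new__(cls, int_data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    def dequantize(self):
        return self.int_data.to(self.scale.dtype) * self.scale

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func in _ATEN_OP_TABLE:
            return _ATEN_OP_TABLE[func](func, args, kwargs or {})
        raise NotImplementedError(f"{cls.__name__} dispatch: {func} is not implemented")


@implements([aten.detach.default])
def _(func, args, kwargs):
    return ToyQuantTensor(args[0].int_data, args[0].scale)


@implements([aten.t.default])
def _(func, args, kwargs):
    # Transposing only touches the payload; a per-tensor scale is unaffected.
    return ToyQuantTensor(args[0].int_data.t(), args[0].scale)


@implements([aten.mm.default])
def _(func, args, kwargs):
    # Reference path: dequantize quantized operands, then run the original op.
    a, b = (x.dequantize() if isinstance(x, ToyQuantTensor) else x for x in args)
    return func(a, b)


q = ToyQuantTensor(torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int8), torch.tensor(0.5))
qt = q.t()
```

Unregistered ops fail loudly, which makes it easy to discover which ops still need an implementation.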
What ops do we need to override? This depends on the model we are trying to quantize. Commonly overridden ops are:
__torch_function__: torch.nn.functional.linear
__torch_dispatch__: torch.ops.aten.addmm.default, torch.ops.aten.mm.default, torch.ops.aten.detach.default, torch.ops.aten.t.default
You can also find the ops that can be overridden in __torch_function__ or __torch_dispatch__ with the following code. Start with a model that you want to optimize, override just the important ops like linear at first, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details).
Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP, etc.), discover the missing ops, and add them until the test passes.
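A sketch of both discovery approaches, with a toy model standing in for yours: torch.overrides.get_overridable_functions lists the functions that __torch_function__ can intercept, and a TorchDispatchMode (from the private torch.utils._python_dispatch module) records the aten ops a forward pass actually reaches. The AtenOpRecorder name is made up for this sketch.

```python
import torch
from torch.overrides import get_overridable_functions
from torch.utils._python_dispatch import TorchDispatchMode

# 1) Torch-function level: functions that __torch_function__ can intercept
overridable = get_overridable_functions()
fn_names = sorted(getattr(f, "__name__", str(f)) for f in overridable.get(torch.nn.functional, []))
print(len(fn_names), "overridable functions in torch.nn.functional, e.g.", fn_names[:5])


# 2) Aten level: record which aten ops the model actually hits
class AtenOpRecorder(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(func)  # an OpOverload such as aten.addmm.default
        return func(*args, **(kwargs or {}))


model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))
with torch.no_grad(), AtenOpRecorder() as recorder:
    model(torch.randn(4, 16))

seen = sorted({str(op) for op in recorder.ops})
print(seen)
```

The printed aten ops are a starting point for what a subclass used as this model's weights would need to handle in __torch_dispatch__.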
We are still working on a table that lists, for each feature, the operators that need to be supported.
Adding Efficient Kernels
Custom Triton kernels
Custom Triton kernels can be implemented and registered in https://github.com/pytorch/ao/tree/main/torchao/kernel
Implementation example: ao/torchao/kernel/intmm_triton.py, lines 270 to 302 (commit 0bdde92)
Registration as a custom op: ao/torchao/kernel/intmm_triton.py, lines 337 to 364 (commit 0bdde92)
Custom hand-written kernels
Custom kernels (implementations) for CPU/CUDA/MPS can be implemented in https://github.com/pytorch/ao/tree/main/torchao/csrc (e.g. int4 CUDA) and are then accessible through torch.ops.my_custom_op.
Dispatches
To dispatch to optimized kernels for CPU/CUDA/MPS devices, we can check the dispatch conditions in __torch_function__ or __torch_dispatch__ and dispatch to the target operators, for example: ao/torchao/dtypes/aqt.py, lines 348 to 355 (commit cbc74ee)
Specifically for AffineQuantizedTensor, we also allow people to extend the quantized linear op to use a new efficient kernel or implementation by defining two functions: dispatch_condition (defines the condition for dispatching to the kernel) and impl (the actual implementation, which takes the activation, (quantized) weight and bias tensors and runs the efficient kernel). Both take input_tensor, weight_tensor and bias as arguments, and they can be registered into the quantized linear dispatch of AffineQuantizedTensor with register_aqt_quantized_linear_dispatch. Here is an example showing how it works: ao/test/dtypes/test_affine_quantized.py, lines 92 to 113 (commit e283743)
Packing/Layout
Sometimes the quantized weights have to be packed in order to yield optimal performance. To support this, we want to extend the “layout” concept in Tensor and introduce an indirection for tensor data storage; see #278 for more details.
Here is an example (see notebook for full code):
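As a concrete, hypothetical example of a layout-specific storage format, here is a simple int4 packing scheme that stores two 4-bit values per uint8 byte along the last dimension. It is not torchao's packing format; real kernel-specific layouts are typically more involved. The point is that the packed storage can differ from the logical shape, and a layout object in the tensor subclass decides how to pack and unpack.

```python
import torch


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    """Pack signed int4 values (int8 tensor with values in [-8, 7]) two per uint8 byte."""
    assert x.shape[-1] % 2 == 0, "last dimension must be even"
    u = (x + 8).to(torch.uint8)  # shift to the unsigned range [0, 15]
    high, low = u[..., ::2], u[..., 1::2]
    return (high << 4) | low


def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4: returns an int8 tensor with twice the last dimension."""
    high = (packed >> 4).to(torch.int8) - 8
    low = (packed & 0xF).to(torch.int8) - 8
    # Interleave back to the original order: h0, l0, h1, l1, ...
    return torch.stack((high, low), dim=-1).flatten(start_dim=-2)
```

A packed weight of logical shape (N, K) is stored as (N, K // 2) bytes, halving memory relative to int8 storage.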
Flow
After the tensor subclass is implemented, we can also wrap it into factory functions.
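A sketch of a factory function and of a model-level helper that applies it to the weights of selected nn.Linear modules, similar in spirit to (but not the implementation of) torchao.quantization.quantize_. Both names are hypothetical, and the (module, fqn) signature of filter_fn is an assumption. To keep the block self-contained, the factory returns a plain fake-quantized tensor; a real factory would return the tensor subclass instead.

```python
from typing import Callable, Optional

import torch


def int8_weight_factory(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical factory: per-row int8 fake quantization (quantize, then dequantize)."""
    scale = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    return torch.round(weight / scale).clamp(-128, 127) * scale


def apply_to_linear_weights(
    model: torch.nn.Module,
    factory: Callable[[torch.Tensor], torch.Tensor],
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
) -> torch.nn.Module:
    """Hypothetical model-level helper: convert the weight of every selected nn.Linear in place."""
    for fqn, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and (filter_fn is None or filter_fn(module, fqn)):
            new_weight = factory(module.weight.detach())
            module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    return model
```

For example, apply_to_linear_weights(model, int8_weight_factory, filter_fn=lambda m, fqn: fqn == "0") converts only the first layer of an nn.Sequential.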
For the model-level API, people can reuse torchao.quantization.quantize_, which applies a tensor subclass conversion to the weights of linear layers and accepts a filtering function: https://github.com/pytorch/ao/blob/aeee551b15eebeaabf98ffab9a00addc675a12a9/torchao/quantization/quant_api.py (TODO: replace this with torchao doc website link when that's ready)
See the Quantization Algorithms/Flows section for examples of weight-only, dynamic and static quantization and other types of model-level APIs based on the factory function.
Using torch.compile for Performance
Note: for PyTorch 2.4 and below, we need to use the following in order to be compatible with torch.compile:
To optimize for performance, we should first run through torch.compile in fullgraph mode and remove any unnecessary graph breaks. You can set TORCH_LOGS="output_code" when you run the script to see the Inductor-generated code, e.g.
TORCH_LOGS="output_code" python example.py
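A minimal example.py to try this on, assuming a toy model as a placeholder for your quantized model. It compiles with fullgraph=True, which raises an error on any graph break instead of silently splitting the graph, and uses backend="eager" so that only Dynamo's graph capture is exercised; removing the backend argument compiles with the default Inductor backend, whose generated code TORCH_LOGS="output_code" prints.

```python
import torch

# Placeholder model: substitute your quantized model here
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)).eval()
x = torch.randn(4, 16)

# fullgraph=True: fail loudly on graph breaks.
# backend="eager": only check graph capture; drop it to compile with Inductor.
compiled = torch.compile(model, fullgraph=True, backend="eager")

with torch.no_grad():
    out = compiled(x)
    ref = model(x)

print("max abs diff vs. eager:", (out - ref).abs().max().item())
```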
Serialization
This test shows how we expect save/load to work for a model quantized with tensor subclass based API:
You can check out the serialization doc for more details.
Note: we are also integrated with Hugging Face and support serialization/deserialization through the Hugging Face save_pretrained/push_to_hub/from_pretrained APIs, available after huggingface/transformers#33456 is landed.
Other Feature Support
The above only covers basic feature support. We also provide examples of how to add support for training, tensor parallelism and FSDP by extending MyDTypeTensor, and we'll put more examples in the developer_api_guide folder covering the following use cases.
General Guide on Extending torchao
For a new use case, for example a training dtype (like fp4 training), it's fine to start by adding a new tensor subclass in the prototype folder https://github.com/pytorch/ao/tree/main/torchao/prototype, but you could also take a look at AffineQuantizedTensor if what you want to do is mostly supported there, e.g. adding an int3 kernel for the exact same affine quantization. Please feel free to open an issue if you have questions about what to do for a specific new use case.
To contribute to the existing code base:
Adding features to AffineQuantizedTensor, e.g. making it trainable, adding tensor parallelism support, etc.: https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py
Integrating a custom kernel with AffineQuantizedTensor (maybe a new layout as well): https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py for now, though we plan to split this file up a bit more; see Add sparse marlin AQT layout #621 as an example
Tensor Subclass Functionality/Composability Testing
We are also working on test suites to test the functionality of tensor subclasses and their composability with different systems (torch.compile, DTensor, etc.):
Kernel Microbenchmarks
Before testing performance on models, we can also run microbenchmarks on a single linear operator (or other compute-intensive/memory-intensive operators) with different input dimensions to get a sense of the speedup. For a specific kernel that you'd like to benchmark, you can create a benchmark file like https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_aq.py and run the benchmark with the shapes that are important for the target model. A quick way to get the relevant shapes for the linear op and other ops is to run the example with this:
Replace the model with the one you are interested in optimizing, and run the following:
Example output:
The output of all linear shapes can be copy-pasted into the microbenchmarking script under benchmarks/benchmark_your_kernel.py for benchmarking.
For benchmark helper functions, right now we have:
ao/torchao/utils.py, line 55 (commit 0bdde92)
ao/torchao/utils.py, line 139 (commit 0bdde92)
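For reference, one way to collect linear shapes and time them, sketched with public PyTorch utilities (torch.overrides.TorchFunctionMode and torch.utils.benchmark.Timer) rather than the torchao helpers linked above. LinearShapeRecorder and bench_linear are hypothetical names, and the toy model stands in for the model you want to optimize.

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode
from torch.utils.benchmark import Timer


class LinearShapeRecorder(TorchFunctionMode):
    """Hypothetical helper: records (input shape, weight shape, has_bias) for every F.linear call."""

    def __init__(self):
        super().__init__()
        self.shapes = set()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            self.shapes.add((tuple(args[0].shape), tuple(args[1].shape), bias is not None))
        return func(*args, **kwargs)


def bench_linear(x_shape, w_shape, has_bias, min_run_time=0.2):
    """Median runtime in seconds of F.linear for one recorded shape (CPU tensors)."""
    x, w = torch.randn(x_shape), torch.randn(w_shape)
    b = torch.randn(w_shape[0]) if has_bias else None
    timer = Timer(stmt="F.linear(x, w, b)", globals={"F": F, "x": x, "w": w, "b": b})
    return timer.blocked_autorange(min_run_time=min_run_time).median


# Placeholder model: substitute the model you are interested in optimizing
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))
with torch.no_grad(), LinearShapeRecorder() as recorder:
    model(torch.randn(4, 16))

for shapes in sorted(recorder.shapes):
    print(shapes, f"{bench_linear(*shapes, min_run_time=0.05) * 1e6:.1f} us")
```

The recorded shape tuples are what you would copy into a kernel-specific microbenchmark script, replacing F.linear with the kernel under test.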
Model Benchmarks and Eval
After you have the quantization flow implemented, you can run benchmarks and evals on the llama (llama2/llama3) or sam models, which are already modified to be friendly to torch.compile, and compare with existing techniques in torchao.
Note: llama (llama2/llama3) is our representative model for memory-bound models, and sam is our representative model for compute-bound models.
Please check out the --help option of each script to understand the supported options, e.g. you can use --profile=profile_path to get a Chrome trace of the run for detailed analysis: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality
Please let us know if there are any new important models that make sense to add to the torchao model benchmark/eval folder.