[fx] Fix importing and tests for quantized conv (#3809)
The fx tracer currently does not support tracing "real" quantized
tensors. A "real" quantized tensor here means a tensor that is created
with a method like `torch.quantize_per_tensor()` and carries its
quantization parameters (scale, zero_point, scheme) in the object
itself. However, DQ/Q-style fake quantization now seems to be the
common high-level representation of quantized operators; it is only
lowered to native quantized ops (if available) in the respective
hardware backend. In recent PyTorch versions, quantization of
floating-point modules is also performed as a graph transformation
after exporting/tracing the original module.

```python
# Examples of "real"/native quantization
tens = torch.randint(-127, 127, (1,), dtype=torch.int8)
torch._make_per_tensor_quantized_tensor(tens, 1, 0)
# tensor([90.], size=(1,), dtype=torch.qint8,
#       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

tens = torch.rand((1,))
torch.quantize_per_tensor(tens, 1, 0, torch.qint8)
# tensor([1.], size=(1,), dtype=torch.qint8,
#       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

# Example of DQ/Q quantization
import torch.ao.quantization.fx._decomposed
tens = torch.rand((1,))
torch.ops.quantized_decomposed.quantize_per_tensor.default(tens, 1, 0, -128, 127, torch.int8)
# tensor([1], dtype=torch.int8)
```

This means that a typical flow for importing a quantized network
into/through torch-mlir looks like this:
`torch.export() -> quantization transformations on fx graph ->
fx_importer`, where the tensors in the graph are ordinary float/int
tensors and the quantization parameters are carried by the DQ/Q ops.
Graphs of this kind can be traced without issues.
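
As a rough sketch of that flow (not part of this patch): the module is
exported, quantized on the fx graph, and then handed to the torch-mlir fx
importer. The model, the quantizer choice, the calibration step, and the
exact entry points are assumptions for illustration and may differ between
versions.

```python
# Hedged sketch: export -> quantize on the fx graph -> torch-mlir fx importer.
# `MyModel` is a hypothetical float module; API names follow the PT2E flow
# and torch_mlir.fx and may vary between releases.
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch_mlir import fx

model = MyModel().eval()
example_args = (torch.rand(1, 3, 224, 224),)

# 1. Export the module to an fx graph (older versions use
#    torch._export.capture_pre_autograd_graph instead).
exported = torch.export.export_for_training(model, example_args).module()

# 2. Quantization as a graph transformation: the result carries the
#    quantization parameters in DQ/Q ops (torch.ops.quantized_decomposed.*)
#    around ordinary float/int tensors.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_args)  # calibration
quantized = convert_pt2e(prepared)

# 3. Import the DQ/Q graph through the fx importer.
mlir_module = fx.export_and_import(quantized, *example_args)
```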

Currently, our quantized convolution tests use the "real" quantized
tensors. This means that, with the retirement of the `jit_ir_importer`,
these tests can no longer be imported. I see no reason to stick with
"real" quantization in these tests, since both PyTorch 2.0 and our
linalg backend use DQ/Q quantization.

This patch updates our quantized convolution tests to use DQ/Q
quantization with the ops from `torch.ops.quantized_decomposed`.
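
Condensed, the pattern the updated tests follow looks roughly like this (a
sketch mirroring `Conv2dQInt8ModuleBase._forward` in the diff below; the
helper name and the concrete scales/zero points are only illustrative):

```python
# Minimal sketch of the DQ/Q test pattern: dequantize the int8/int32
# operands, run the regular float aten op, and re-quantize the result.
import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops


def quantized_conv2d(input_int8, weight_int8, bias_int32):
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    q = torch.ops.quantized_decomposed.quantize_per_tensor.default

    x = dq(input_int8, 0.01, 7, -128, 127, torch.int8)
    w = dq(weight_int8, 0.01, 3, -128, 127, torch.int8)
    b = dq(bias_int32, 1, 0, -1000, 1000, torch.int32)

    conv = torch.ops.aten.conv2d(x, w, bias=b)
    # Quantize the result to int32 to avoid overflows.
    return q(conv, 1, 0, -(2**31), 2**31 - 1, torch.int32)
```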

Note: For future reference, there seems to be an ongoing consolidation
of the ops for the DQ/Q scheme on the PyTorch side
(pytorch/ao#986 (comment)).
ubfx authored Oct 22, 2024
1 parent 140cad5 commit 42ba541
Showing 3 changed files with 56 additions and 48 deletions.
8 changes: 0 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -420,15 +420,7 @@
"CeilFloatModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
94 changes: 55 additions & 39 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -1183,23 +1183,28 @@ def ConvTbcModule_basic(module, tu: TestUtils):
module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))


# For DQ-Q fake quantization ops
import torch.ao.quantization.fx._decomposed


class Conv2dQInt8ModuleBase(torch.nn.Module):
def __init__(self, groups=1):
self.groups = groups
super().__init__()

def _forward(self, inputVec, weight, bias):
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
inputVec = torch.dequantize(inputVec)

weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3)
weight = torch.dequantize(weight)

bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
bias = torch.dequantize(bias)
def _forward(self, input, weight, bias):
input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
input, 0.01, 7, -128, 127, torch.int8
)
weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
weight, 0.01, 3, -128, 127, torch.int8
)
bias = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
bias, 1, 0, -1000, 1000, torch.int32
)

return torch.ops.aten.conv2d(
inputVec,
conv = torch.ops.aten.conv2d(
input,
weight,
bias=bias,
stride=[1, 1],
@@ -1208,6 +1213,11 @@ def _forward(self, inputVec, weight, bias):
groups=self.groups,
)

# Use int32 to avoid overflows
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
)


class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
@export
@@ -1216,7 +1226,7 @@ class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
([-1], torch.int32, True),
]
)
def forward(self, inputVec, weight, bias):
@@ -1230,7 +1240,7 @@ class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
None,
([2, 3, 12, 12], torch.int8, True),
([3, 1, 5, 3], torch.int8, True),
([3], torch.float, True),
([3], torch.int32, True),
]
)
def forward(self, inputVec, weight, bias):
@@ -1244,7 +1254,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
None,
([2, 3, 12, 12], torch.int8, True),
([6, 1, 5, 3], torch.int8, True),
([6], torch.float, True),
([6], torch.int32, True),
]
)
def forward(self, inputVec, weight, bias):
@@ -1255,23 +1265,23 @@ def forward(self, inputVec, weight, bias):
def Conv2dQInt8Module_basic(module, tu: TestUtils):
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
module.forward(inputVec, weight, bias)


@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2))
def Conv2dQInt8Module_grouped(module, tu: TestUtils):
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(6)
bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
module.forward(inputVec, weight, bias)


@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3))
def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
module.forward(inputVec, weight, bias)


@@ -1281,7 +1291,7 @@ def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
bias = torch.rand(6)
bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
module.forward(inputVec, weight, bias)


@@ -1300,24 +1310,29 @@ def __init__(self):
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
input, 0.01, -25, -128, 127, torch.int8
)
weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
weight, 0.01, 50, -128, 127, torch.int8
)

res = torch.ops.aten.convolution(
input,
weight,
bias=bias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz

# Use int32 to avoid overflows
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
res, 1, 0, -(2**31), 2**31 - 1, torch.int32
)


@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
@@ -1342,18 +1357,14 @@ def __init__(self, groups=1):
super().__init__()

def _forward(self, inputVec, weight, scales, zeropoints, bias):
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
inputVec = torch.dequantize(inputVec)

weight = torch._make_per_channel_quantized_tensor(
weight, scales, zeropoints, axis=0
inputVec = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
inputVec, 0.01, 7, -128, 127, torch.int8
)
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
weight, scales, zeropoints, 0, -128, 127, torch.int8
)
weight = torch.dequantize(weight)

bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
bias = torch.dequantize(bias)

return torch.ops.aten.conv2d(
conv = torch.ops.aten.conv2d(
inputVec,
weight,
bias=bias,
@@ -1363,6 +1374,11 @@ def _forward(self, inputVec, weight, scales, zeropoints, bias):
groups=self.groups,
)

# Use int32 to avoid overflows
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
)


class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
@export
2 changes: 1 addition & 1 deletion python/torch_mlir/fx.py
@@ -41,7 +41,7 @@ def _module_lowering(
option_string = "{extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose,
)
