Skip to content

Commit

Permalink
[linalg] Fix bug for conversion of complex dtype (#3269)
Browse files Browse the repository at this point in the history
The conversion of complex type wasn't supported or checked; the support
and required tests were added.

Fixes:
iree-org/iree#17226 (comment)
  • Loading branch information
pashu123 authored May 1, 2024
1 parent 0a2d21b commit 8c48135
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
21 changes: 21 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}

if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
auto dtypeElemType = dtypeComplex.getElementType();

// Extract the real and imaginary parts of the scalar.
// Cast them to the target element type, and create a new complex
// value with the target complex type.
Value realVal = b.create<complex::ReOp>(loc, scalar);
Value imgVal = b.create<complex::ImOp>(loc, scalar);

realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);

return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
}
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
}

llvm_unreachable("convertScalarToDtype should handle all the types");
}

Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
Expand Down Expand Up @@ -2314,6 +2315,7 @@
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
Expand Down
28 changes: 28 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
# ==============================================================================


# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.complex64, True),
([-1], torch.complex128, True),
]
)
def forward(self, a, b):
return torch.mul(a, b)


@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.complex64),
tu.randint(4, high=10).type(torch.complex128),
)


# ==============================================================================


class ElementwiseMishModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 8c48135

Please sign in to comment.