[TOSA] Add some more mixed dtype handling (#3909)
* Add integer input handling for activation functions such as erf, sigmoid,
and tanh
* Fix mixed-dtype handling for scalar comparison ops
* Add mixed-dtype handling for the pow tensor op (only floating-point
result types are supported for now)
* Add Torch to TOSA lowering for torch.aten.tan
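
The changes above can be exercised from PyTorch roughly as follows. This is an
illustrative sketch only, not part of the commit; the module names in the
comments refer to the e2e tests enabled in xfail_sets.py below, whose
definitions live in the torch-mlir e2e test suite and are not shown in this
diff.

import torch

x_int = torch.arange(6, dtype=torch.int32).reshape(2, 3)

# Activation functions on integer inputs: the lowering now casts the input to
# the floating-point result type before emitting the TOSA op
# (ElementwiseErfIntModule / ElementwiseSigmoidIntModule style cases).
y_erf = torch.erf(x_int)
y_sigmoid = torch.sigmoid(x_int)

# Scalar comparison with mixed dtypes: integer tensor vs. float scalar
# (ElementwiseIntTensorLtFloatScalarModule style case).
mask = x_int < 2.5

# Pow with mixed dtypes; only floating-point result types are supported for now
# (PowIntFloatModule style case).
p = torch.pow(x_int, 1.5)

# aten.tan now lowers to TOSA (ElementwiseTanModule / ElementwiseTanIntModule).
t = torch.tan(x_int)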


Change-Id: I3a8aa1e6febbc0e39ebdb5734f87ae171b03cd73

Signed-off-by: Justin Ngo <[email protected]>
justin-ngo-arm authored Dec 9, 2024
1 parent a99e378 commit 5077090
Showing 3 changed files with 216 additions and 40 deletions.
107 changes: 83 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -405,7 +405,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
rhsAsTensor, lhsElemTy, {})))
rhsAsTensor, rhs.getType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
@@ -414,11 +414,26 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType());
auto rhsElemTy = rhsTensorTy.getElementType();

// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>());

// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
}

// Support different types comparisons
auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);

// Support different types comparisons
if (lhsElemTy != rhsElemTy) {
if (lhsElemTy != rhsElemTy && !isBitwiseOp) {
if (isLhsElemFloat && !isRhsElemFloat) {
rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
} else if (!isLhsElemFloat && isRhsElemFloat) {
@@ -441,20 +456,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
}
}
}
// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>());

// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
}

auto resultOp = rewriter.create<TosaOpT>(op.getLoc(), resultTy,
(swapLhsRhs ? rhsTensor : lhs),
@@ -770,17 +771,24 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
auto selfTy = dyn_cast<TensorType>(self.getType());

if (!selfTy)
return rewriter.notifyMatchFailure(op, "Only Tensor types supported");

if (!isa<mlir::FloatType>(selfTy.getElementType()))
auto resultTy = dyn_cast<TensorType>(
this->getTypeConverter()->convertType(op.getType()));

if (!isa<mlir::FloatType>(resultTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
op, "Only floating-point datatype result types are supported");

rewriter.replaceOpWithNewOp<TosaOpT>(
op, this->getTypeConverter()->convertType(op.getType()), self);
// Non floating point inputs are not supported for activation functions
// (erf, sigmoid, tanh) in TOSA so we cast the input to result type
if (!isa<mlir::FloatType>(selfTy.getElementType()))
self = tosa::promoteType(rewriter, self, resultTy);

rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);

return success();
}
@@ -1283,6 +1291,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
auto outType =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

if (!isa<mlir::FloatType>(outType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

Value selfTensor;
if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
Value selfScalar = op.getSelf();
@@ -1299,9 +1311,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");

// Non floating point inputs are not supported for tosa.pow so we cast the
// input to result type
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
selfTensor = tosa::promoteType(rewriter, selfTensor, outType);
}

Value expTensor;
@@ -1319,6 +1332,11 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
if (!expTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");

// Non floating point exponents are not supported for tosa.pow so we cast
// the exponent to result type
if (!isa<mlir::FloatType>(expTy.getElementType()))
expTensor = tosa::promoteType(rewriter, expTensor, outType);
}

auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
@@ -8198,6 +8216,46 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
return success();
}

// Legalization for aten.tan
template <>
LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
AtenTanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// tan = sin / cos
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));

if (!isa<mlir::FloatType>(resultType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// Non floating point inputs are not supported in TOSA so we cast the input
// to result type
if (!isa<mlir::FloatType>(selfType.getElementType()))
self = tosa::promoteType(rewriter, self, resultType);

auto sinOp = rewriter.create<tosa::SinOp>(op->getLoc(), resultType, self);

auto cosOp = rewriter.create<tosa::CosOp>(op->getLoc(), resultType, self);

auto reciprocalOp =
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), resultType, cosOp);

auto result = rewriter.create<tosa::MulOp>(
op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
/*shift=*/0);

rewriter.replaceOp(op, {result.getResult()});

return success();
}

} // namespace

// -----------------------------------------------------------------------------
@@ -8540,6 +8598,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenTanOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
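
The aten.tan legalization added above decomposes tan as sin * reciprocal(cos).
A minimal sketch of that identity at the PyTorch level, purely illustrative and
mirroring the op sequence the lowering emits:

import torch

x = torch.linspace(-1.0, 1.0, steps=8)

# Mirror the emitted TOSA sequence: tosa.sin, tosa.cos, tosa.reciprocal,
# tosa.mul (with shift = 0).
sin_x = torch.sin(x)
cos_x = torch.cos(x)
recip_cos = torch.reciprocal(cos_x)
tan_decomposed = sin_x * recip_cos

assert torch.allclose(tan_decomposed, torch.tan(x), atol=1e-6)
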
15 changes: 7 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1717,6 +1717,13 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseErfIntModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseUnaryIntModule_basic",
"PowIntFloatModule_basic",
"Deg2radModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"L1LossMeanReductionModule_basic",
@@ -3658,22 +3665,16 @@
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
@@ -3780,7 +3781,6 @@
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
@@ -4369,7 +4369,6 @@
"ElementwiseSqrtIntModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
