Skip to content

Commit

Permalink
[Stablehlo] fix lowering gelu(x, tanh) (#3307)
Browse files Browse the repository at this point in the history
* lowering gelu("none") to erf
* lowering gelu("tanh") to tanh
  • Loading branch information
qingyunqu authored May 9, 2024
1 parent 0f0f57c commit 5213557
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
46 changes: 38 additions & 8 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <cmath>
#include <numeric>
#include <type_traits>

Expand Down Expand Up @@ -1064,7 +1065,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
}

// Convert an Aten::GELU op to StableHLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
// Gelu(x, "none") = x * 0.5 * (1 + erf(x/(sqrt(2))))
// Gelu(x, "tanh") = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
AtenGeluOp op, OpAdaptor adaptor,
Expand All @@ -1076,16 +1078,44 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return op.emitError("only ranked tensor type is supported.");
}

std::string approximate;
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) {
return op.emitError("approximate must be constant string");
}
if (approximate != "none" && approximate != "tanh") {
return op.emitError("unsupported approximate: ") << approximate;
}

Value one = getConstantLike(rewriter, loc, 1.0, input);
Value two = getConstantLike(rewriter, loc, 2.0, input);
Value three = getConstantLike(rewriter, loc, 3.0, input);
Value half = getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
return success();
// 2/pi
Value twoDivPi = getConstantLike(rewriter, loc, M_2_PI, input);
Value t = getConstantLike(rewriter, loc, 0.044715, input);

// x * 0.5
auto inputMulHalf = rewriter.create<stablehlo::MulOp>(loc, input, half);
if (approximate == "none") {
auto rsqrtTwo = rewriter.create<stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, erfAdd, inputMulHalf);
return success();
} else {
auto sqrtTwoPi = rewriter.create<stablehlo::SqrtOp>(loc, twoDivPi);
// x^3
auto powThree = rewriter.create<stablehlo::PowOp>(loc, input, three);
// x + 0.044715 * x^3
auto add = rewriter.create<stablehlo::AddOp>(
loc, input, rewriter.create<stablehlo::MulOp>(loc, t, powThree));
auto tanh = rewriter.create<stablehlo::TanhOp>(
loc, rewriter.create<stablehlo::MulOp>(loc, sqrtTwoPi, add));
auto tanhAdd = rewriter.create<stablehlo::AddOp>(loc, tanh, one);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, tanhAdd, inputMulHalf);
return success();
}
}

// AtenLog2Op
Expand Down
13 changes: 6 additions & 7 deletions test/Conversion/TorchToStablehlo/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
// CHECK-LABEL: func.func @torch.aten.gelu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "none"
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.multiply %[[T0]], %[[T3]]
// CHECK: %[[T5:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T0]], %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = chlo.erf %[[T6]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[T8:.*]] = stablehlo.add %[[T7]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T8]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
Expand Down

0 comments on commit 5213557

Please sign in to comment.