[torch] Add OnnxToTorch lowering for onnx.HannWindow (#3276)
Adds OnnxToTorch lowering for the `onnx.HannWindow` op. Also factors out
common implementation between the window functions.
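
The factored-out helper computes the generalized cosine window w[n] = a0 + a1*cos(2*pi*n/N) + a2*cos(4*pi*n/N), where N is size for periodic windows and size - 1 for symmetric ones; BlackmanWindow supplies the coefficients (0.42, -0.5, 0.08) and HannWindow (0.5, -0.5, 0.0), with a1 pre-negated so the usual subtraction becomes an addition. A minimal host-side sketch of the same formula, for reference only (cosineWindow is a name invented here, not part of the commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference-only model of windowFunctionImpl:
// w[n] = a0 + a1*cos(theta) + a2*cos(2*theta), theta = 2*pi*n/N,
// with N = size (periodic) or N = size - 1 (symmetric).
static std::vector<double> cosineWindow(int size, double a0, double a1,
                                        double a2, bool periodic) {
  constexpr double pi = 3.14159265358979323846;
  const double denom = periodic ? size : size - 1;
  std::vector<double> window(size);
  for (int i = 0; i < size; ++i) {
    const double theta = 2.0 * pi * i / denom;
    window[i] = a0 + a1 * std::cos(theta) + a2 * std::cos(2.0 * theta);
  }
  return window;
}

int main() {
  // HannWindow coefficients from the lowering below: (0.5, -0.5, 0.0);
  // BlackmanWindow uses (0.42, -0.5, 0.08).
  for (double w : cosineWindow(8, 0.5, -0.5, 0.0, /*periodic=*/true))
    std::printf("%f\n", w);
}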
vinayakdsci authored May 3, 2024
1 parent a46fe2c commit 67d6a66
Showing 2 changed files with 215 additions and 95 deletions.
231 changes: 136 additions & 95 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
return success();
}

namespace {
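// Shared implementation for the ONNX cosine-window ops. Computes
// a0 + a1*cos(2*pi*n/N) + a2*cos(4*pi*n/N) for n in [0, size), where
// N = size when `periodic` is 1 and N = size - 1 otherwise, then casts
// the result to `output_datatype` and replaces the matched op.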
LogicalResult windowFunctionImpl(OpBinder binder,
ConversionPatternRewriter &rewriter,
Value size, Value a0, Value a1, Value a2,
Torch::ValueTensorType resultType,
int64_t output_datatype, int64_t periodic) {

Location loc = binder.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);

double isPeriodicFp = static_cast<double>(periodic);

Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));

constexpr double pi = llvm::numbers::pi;
Value tau = b.create<Torch::ConstantFloatOp>(
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));

Value noneVal = b.create<Torch::ConstantNoneOp>();
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
Value float32Type = b.create<Torch::ConstantIntOp>(
rewriter.getI64IntegerAttr(/*float32Type*/ 6));

// Create an f32 ValueTensorType with the same shape as the `size`
// operand.
auto shapeOfOperand =
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
shapeOfOperand, rewriter.getF32Type());
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, one, one);

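// Branchless select of the window denominator N: exactly one of
// isPeriodic and isSymmetricFloat is 1.0, so the sum below yields
// sizeFloat = size when periodic and size - 1 when symmetric.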
Value isPeriodic =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
rewriter.getF64FloatAttr(1.0 - isPeriodicFp));

Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
symmetricComponent.getType(), symmetricComponent, periodicComponent, one);

// Here, size can be used in the place of periodicSizeFloat, as the
// latter is just a float representation of the former.
Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);

Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);

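// theta[n] = 2*pi*n / N for n = 0, ..., size - 1.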
Value rangeTimesTau =
b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
Value rangeAngular =
b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
Value twoRangeAngular =
b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);

Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
Value cosTwoRangeAngular =
b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);

Value a1Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
Value a2Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);

// AtenSubScalarOp requires a tensor as its LHS (operand #1), so the
// ONNX formula a0 - a1*cos(theta) cannot be written directly with the
// scalar a0 on the left. Instead, callers pass a1 pre-negated (e.g.
// -0.5 for HannWindow) and AtenSubScalarOp is replaced with
// AtenAddScalarOp, which is valid because addition is commutative.
Value subA1Component =
b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
a2Component, one);

std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(output_datatype);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given dtype conversion");
}
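// Materialize the mapped Torch dtype and cast the computed window to
// the requested output type.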
Value outputDtype = b.create<Torch::ConstantIntOp>(
rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch.value()));

rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, result, outputDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/noneVal);

return success();
}

} // namespace

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
@@ -2252,7 +2354,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType)) {
return failure();
}
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
@@ -2262,104 +2363,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);

if (failed(windowFunctionResult))
return failure();

return success();
});

patterns.onOp(
"HannWindow", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value size;
Torch::ValueTensorType resultType;
int64_t periodic, output_datatype;
if (binder.tensorOperand(size) ||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
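// Hann coefficients: a0 = 0.5, a1 = -0.5 (pre-negated, see the comment
// in windowFunctionImpl), a2 = 0.0.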
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));

auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);

if (failed(windowFunctionResult))
return failure();

return success();
});
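
// Hypothetical sketch, not part of this commit: with the shared helper,
// each additional cosine window needs only its own coefficients. ONNX
// HammingWindow (alpha = 25/46, so a0 = 25/46 and a1 = -21/46) could be
// registered like this:
patterns.onOp(
    "HammingWindow", 17,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value size;
      Torch::ValueTensorType resultType;
      int64_t periodic, output_datatype;
      if (binder.tensorOperand(size) ||
          binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
          binder.s64IntegerAttr(periodic, "periodic", 1) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }
      // alpha = 25/46; a1 is pre-negated as in the other windows.
      Value a0 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 25.0 / 46.0));
      Value a1 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(),
          rewriter.getFloatAttr(rewriter.getF64Type(), -(21.0 / 46.0)));
      Value a2 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
      if (failed(windowFunctionImpl(binder, rewriter, size, a0, a1, a2,
                                    resultType, output_datatype, periodic)))
        return failure();
      return success();
    });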
}
