Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][emitc] Support convert arith.extf and arith.truncf to emitc #121184

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jacquesguan
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Dec 27, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Jianjian Guan (jacquesguan)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/121184.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+34-1)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+26)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index ccbc1669b7a92a..e2fbac40517e0d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -733,6 +733,37 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
   }
 };
 
+// Floating-point to floating-point conversions.
+template <typename CastOp>
+class FpCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Vectors in particular are not supported
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedFloatType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedFloatType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    Value fpCastOperand = adaptor.getIn();
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -778,7 +809,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
-    FtoICastOpConversion<arith::FPToUIOp>
+    FtoICastOpConversion<arith::FPToUIOp>,
+    FpCastOpConversion<arith::ExtFOp>,
+    FpCastOpConversion<arith::TruncFOp>
   >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 1728c3a2557e07..434f8771d58c1e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
 
   return %div : i32
 }
+
+// -----
+
+func.func @arith_extf(%arg0: f16) -> f64 {
+  // CHECK-LABEL: arith_extf
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
+  // CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
+  %extd0 = arith.extf %arg0 : f16 to f32
+  // CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
+  %extd1 = arith.extf %extd0 : f32 to f64
+
+  return %extd1 : f64
+}
+
+// -----
+
+func.func @arith_truncf(%arg0: f64) -> f16 {
+  // CHECK-LABEL: arith_truncf
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+  %truncd0 = arith.truncf %arg0 : f64 to f32
+  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
+  %truncd1 = arith.truncf %truncd0 : f32 to f16
+
+  return %truncd1 : f16
+}
\ No newline at end of file

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants