summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp')
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b2..3532785c31b9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,31 @@ public:
return success();
}
};
+
+class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+public:
+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type dstType = getTypeConverter()->convertType(selectOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
+
+ if (!adaptor.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(
+ selectOp,
+ "can only be converted if condition is a scalar of type i1");
+
+ rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
+ adaptor.getOperands());
+
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
- ArithOpConversion<arith::SubFOp, emitc::SubOp>
+ ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+ SelectOpConversion
>(typeConverter, ctx);
// clang-format on
}