diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 40dce001a3b2..3532785c31b9 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -54,6 +54,31 @@ public: return success(); } }; + +class SelectOpConversion : public OpConversionPattern<arith::SelectOp> { +public: + using OpConversionPattern<arith::SelectOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type dstType = getTypeConverter()->convertType(selectOp.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(selectOp, "type conversion failed"); + + if (!adaptor.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + selectOp, + "can only be converted if condition is a scalar of type i1"); + + rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType, + adaptor.getOperands()); + + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, ArithOpConversion<arith::AddFOp, emitc::AddOp>, ArithOpConversion<arith::DivFOp, emitc::DivOp>, ArithOpConversion<arith::MulFOp, emitc::MulOp>, - ArithOpConversion<arith::SubFOp, emitc::SubOp> + ArithOpConversion<arith::SubFOp, emitc::SubOp>, + SelectOpConversion >(typeConverter, ctx); // clang-format on } |