diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 2d7219fef87c..9c5c58fa1fab 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -373,14 +373,15 @@ namespace { class RegionBuilderHelper { public: - RegionBuilderHelper(MLIRContext *context, Block &block) - : context(context), block(block) {} + RegionBuilderHelper(OpBuilder &builder, Block &block) + : builder(builder), block(block) {} // Build the unary functions defined by OpDSL. Value buildUnaryFn(UnaryFn unaryFn, Value arg) { if (!isFloatingPoint(arg)) llvm_unreachable("unsupported non numeric type"); - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); switch (unaryFn) { case UnaryFn::exp: return builder.create<math::ExpOp>(arg.getLoc(), arg); @@ -407,7 +408,8 @@ public: arg1.getType().getIntOrFloatBitWidth() == 1; if (!allComplex && !allFloatingPoint && !allInteger) llvm_unreachable("unsupported non numeric type"); - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); switch (binaryFn) { case BinaryFn::add: if (allComplex) @@ -481,29 +483,32 @@ public: } void yieldOutputs(ValueRange values) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); builder.create<YieldOp>(loc, values); } Value constant(const std::string &value) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr)); } Value index(int64_t dim) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); return builder.create<IndexOp>(builder.getUnknownLoc(), dim); } Type getIntegerType(unsigned width) { - return IntegerType::get(context, width); + return IntegerType::get(builder.getContext(), width); } - Type getFloat32Type() { return Float32Type::get(context); } - Type getFloat64Type() { return Float64Type::get(context); } + Type getFloat32Type() { return Float32Type::get(builder.getContext()); } + Type getFloat64Type() { return Float64Type::get(builder.getContext()); } private: // Generates operations to cast the given operand to a specified type. @@ -511,7 +516,8 @@ private: // operand returned as-is (which will presumably yield a verification // issue downstream). Value cast(Type toType, Value operand, bool isUnsignedCast) { - OpBuilder builder = getBuilder(); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); auto loc = operand.getLoc(); return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); } @@ -526,13 +532,7 @@ private: return llvm::isa<IntegerType>(value.getType()); } - OpBuilder getBuilder() { - OpBuilder builder(context); - builder.setInsertionPointToEnd(&block); - return builder; - } - - MLIRContext *context; + OpBuilder &builder; Block █ }; |