summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp36
1 files changed, 18 insertions, 18 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7219fef87c..9c5c58fa1fab 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -373,14 +373,15 @@ namespace {
class RegionBuilderHelper {
public:
- RegionBuilderHelper(MLIRContext *context, Block &block)
- : context(context), block(block) {}
+ RegionBuilderHelper(OpBuilder &builder, Block &block)
+ : builder(builder), block(block) {}
// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
if (!isFloatingPoint(arg))
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ public:
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
@@ -481,29 +483,32 @@ public:
}
void yieldOutputs(ValueRange values) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
builder.create<YieldOp>(loc, values);
}
Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
+ return IntegerType::get(builder.getContext(), width);
}
- Type getFloat32Type() { return Float32Type::get(context); }
- Type getFloat64Type() { return Float64Type::get(context); }
+ Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+ Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
private:
// Generates operations to cast the given operand to a specified type.
@@ -511,7 +516,8 @@ private:
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
@@ -526,13 +532,7 @@ private:
return llvm::isa<IntegerType>(value.getType());
}
- OpBuilder getBuilder() {
- OpBuilder builder(context);
- builder.setInsertionPointToEnd(&block);
- return builder;
- }
-
- MLIRContext *context;
+ OpBuilder &builder;
Block &block;
};