summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-05-08 15:39:08 +0200
committerMatthias Springer <me@m-sp.org>2024-05-08 15:39:08 +0200
commit7b1d6a75b9f048e85b8b752218999eb39bf14d22 (patch)
tree43e42406464c0997320368ff2191085d145e1e2a
parentbbd6a2d85c44d99e66b471d251a742f7551a0c61 (diff)
[mlir][IR] Support op interfaces in `HasParent` traitupstream/users/matthias-springer/has_parent_interface
This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent. To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations. This commit addresses a TODO in `TensorOps.td`.
-rw-r--r--mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td3
-rw-r--r--mlir/include/mlir/IR/OpBase.td2
-rw-r--r--mlir/include/mlir/IR/OpDefinition.h23
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp4
-rw-r--r--mlir/test/Dialect/Tensor/invalid.mlir9
-rw-r--r--mlir/tools/mlir-tblgen/OpInterfacesGen.cpp10
6 files changed, 41 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f9..2d9f4c29f7aa 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1463,8 +1463,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
- // TODO: Cannot use an interface here atm, verify this manually for now.
- // HasParent<"ParallelCombiningOpInterface">
+ HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7866ac24c1cc..b089e72fe892 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -133,7 +133,7 @@ class SingleBlockImplicitTerminator<string op>
// Op's regions don't have terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
-// Op's parent operation is the provided one.
+// Op's parent operation or op interface is the provided one.
class HasParent<string op>
: ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d66909..550f04d9a373 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1298,7 +1298,9 @@ struct HasParent {
return op->emitOpError()
<< "expects parent op "
<< (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
- << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
+ << llvm::ArrayRef(
+ {getOperationOrInterfaceName<ParentOpTypes>()...})
+ << "'";
}
template <typename ParentOpType =
@@ -1309,6 +1311,25 @@ struct HasParent {
return llvm::cast<ParentOpType>(parent);
}
};
+
+private:
+ /// A class is an op interface if it has a `getInterfaceName` function.
+ template <typename T, typename = int>
+ struct IsInterface : std::false_type {};
+ template <typename T>
+ struct IsInterface<T, decltype((void)T::getInterfaceName(), 0)>
+ : std::true_type {};
+
+ /// Helper function that returns the name of the given operation or interface
+ /// as a string literal.
+ template <typename T>
+ static constexpr StringLiteral getOperationOrInterfaceName() {
+ if constexpr (IsInterface<T>::value) {
+ return T::getInterfaceName();
+ } else {
+ return T::getOperationName();
+ }
+ }
};
/// A trait for operations that have an attribute specifying operand segments.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d135..f45c2e4efdf5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3455,10 +3455,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}
LogicalResult ParallelInsertSliceOp::verify() {
- if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
- return this->emitError("expected ParallelCombiningOpInterface parent, got:")
- << *(getOperation()->getParentOp());
-
RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64af..4205d9c3dcd3 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -698,3 +698,12 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape(
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func.func @parallel_insert_slice_out_of_context(%a: tensor<5xf32>, %b: tensor<100xf32>) {
+ // expected-error@+1 {{expects parent op 'ParallelCombiningOpInterface'}}
+ tensor.parallel_insert_slice %a into %b[0][5][1]
+ : tensor<5xf32> into tensor<100xf32>
+ return
+}
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34..17babee913f0 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -537,7 +537,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the derived trait for the interface.
os << "template <typename " << valueTemplate << ">\n";
- os << "struct " << interface.getName() << "Trait;\n";
+ os << "struct " << interfaceName << "Trait;\n";
os << "\n} // namespace detail\n";
@@ -548,6 +548,11 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
+ // Insert function that returns the name of the interface as a string.
+ os << " static constexpr ::llvm::StringLiteral getInterfaceName() {\n"
+ << " return \"" << interfaceName << "\";\n"
+ << " }\n\n";
+
// Emit a utility wrapper trait class.
os << llvm::formatv(" template <typename {1}>\n"
" struct Trait : public detail::{0}Trait<{1}> {{};\n",
@@ -588,7 +593,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
<< " auto* interface = getInterfaceFor(base);\n"
<< " if (!interface)\n"
" return false;\n"
- " " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+ " "
+ << interfaceName << " odsInterfaceInstance(base, interface);\n"
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
<< "\n }\n";
}