diff options
author | Matthias Springer <me@m-sp.org> | 2024-05-08 15:39:08 +0200 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2024-05-08 15:39:08 +0200 |
commit | 7b1d6a75b9f048e85b8b752218999eb39bf14d22 (patch) | |
tree | 43e42406464c0997320368ff2191085d145e1e2a | |
parent | bbd6a2d85c44d99e66b471d251a742f7551a0c61 (diff) |
[mlir][IR] Support op interfaces in `HasParent` traitupstream/users/matthias-springer/has_parent_interface
This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent.
To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations.
This commit addresses a TODO in `TensorOps.td`.
-rw-r--r-- | mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 3 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OpBase.td | 2 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OpDefinition.h | 23 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/invalid.mlir | 9 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 10 |
6 files changed, 41 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index a403e89a39f9..2d9f4c29f7aa 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1463,8 +1463,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [ AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, - // TODO: Cannot use an interface here atm, verify this manually for now. - // HasParent<"ParallelCombiningOpInterface"> + HasParent<"ParallelCombiningOpInterface"> ]> { let summary = [{ Specify the tensor slice update of a single thread of a parent diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 7866ac24c1cc..b089e72fe892 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -133,7 +133,7 @@ class SingleBlockImplicitTerminator<string op> // Op's regions don't have terminator. def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait; -// Op's parent operation is the provided one. +// Op's parent operation or op interface is the provided one. class HasParent<string op> : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 59f094d66909..550f04d9a373 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1298,7 +1298,9 @@ struct HasParent { return op->emitOpError() << "expects parent op " << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'") - << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'"; + << llvm::ArrayRef( + {getOperationOrInterfaceName<ParentOpTypes>()...}) + << "'"; } template <typename ParentOpType = @@ -1309,6 +1311,25 @@ struct HasParent { return llvm::cast<ParentOpType>(parent); } }; + +private: + /// A class is an op interface if it has a `getInterfaceName` function. + template <typename T, typename = int> + struct IsInterface : std::false_type {}; + template <typename T> + struct IsInterface<T, decltype((void)T::getInterfaceName(), 0)> + : std::true_type {}; + + /// Helper function that returns the name of the given operation or interface + /// as a string literal. + template <typename T> + static constexpr StringLiteral getOperationOrInterfaceName() { + if constexpr (IsInterface<T>::value) { + return T::getInterfaceName(); + } else { + return T::getOperationName(); + } + } }; /// A trait for operations that have an attribute specifying operand segments. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 7a13f7a7d135..f45c2e4efdf5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3455,10 +3455,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, } LogicalResult ParallelInsertSliceOp::verify() { - if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp())) - return this->emitError("expected ParallelCombiningOpInterface parent, got:") - << *(getOperation()->getParentOp()); - RankedTensorType expectedType; SliceVerificationResult result = verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(), diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 41b6529f64af..4205d9c3dcd3 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -698,3 +698,12 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape( %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } + +// ----- + +func.func @parallel_insert_slice_out_of_context(%a: tensor<5xf32>, %b: tensor<100xf32>) { + // expected-error@+1 {{expects parent op 'ParallelCombiningOpInterface'}} + tensor.parallel_insert_slice %a into %b[0][5][1] + : tensor<5xf32> into tensor<100xf32> + return +} diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 2a7406f42f34..17babee913f0 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -537,7 +537,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { // Emit the derived trait for the interface. os << "template <typename " << valueTemplate << ">\n"; - os << "struct " << interface.getName() << "Trait;\n"; + os << "struct " << interfaceName << "Trait;\n"; os << "\n} // namespace detail\n"; @@ -548,6 +548,11 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); + // Insert function that returns the name of the interface as a string. + os << " static constexpr ::llvm::StringLiteral getInterfaceName() {\n" + << " return \"" << interfaceName << "\";\n" + << " }\n\n"; + // Emit a utility wrapper trait class. os << llvm::formatv(" template <typename {1}>\n" " struct Trait : public detail::{0}Trait<{1}> {{};\n", @@ -588,7 +593,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { << " auto* interface = getInterfaceFor(base);\n" << " if (!interface)\n" " return false;\n" - " " << interfaceName << " odsInterfaceInstance(base, interface);\n" + " " + << interfaceName << " odsInterfaceInstance(base, interface);\n" << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt) << "\n }\n"; } |