diff options
author | Oleksandr "Alex" Zinenko <zinenko@google.com> | 2024-01-09 13:18:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 13:18:57 +0100 |
commit | 633d9184f5f8ab227ab22fd7a7db366b843a02d2 (patch) | |
tree | ff018e38ab464db5f3826a319aca7462d352e697 | |
parent | 4f7c402d9ff1b2c908b97b78baf84157f08745e8 (diff) |
[mlir] introduce transform.collect_matching (#76724)
Introduce a new match combinator into the transform dialect. This
operation collects all operations that are yielded by a satisfactory
match into its results. This is a simpler version of `foreach_match`
that can be inserted directly into existing transform scripts.
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 35 | ||||
-rw-r--r-- | mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 150 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops-invalid.mlir | 68 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/test-interpreter.mlir | 44 |
4 files changed, 279 insertions, 18 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index fcdb21d21503..fe2c28f45aea 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations", let hasVerifier = 1; } +def CollectMatchingOp : TransformDialectOp<"collect_matching", [ + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<SymbolUserOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let summary = "Collects all payload ops that match the given named matcher"; + let description = [{ + Collects operations or other payload IR objects nested under `root` + (inclusive) that match the given matcher expressed as a named sequence. The + matcher sequence must accept exactly one argument that it is not allowed to + modify. It must yield as many values as this op has results. Each of the + yielded values must be associated with exactly one payload object. If any + operation in the matcher sequence produces a silenceable failure, the + matcher advances to the next payload operation in the walk order without + finishing the sequence. + + The i-th result of this operation is constructed by concatenating the i-th + yielded payload IR objects of all successful matcher sequence applications. + All results are guaranteed to be mapped to the same number of payload IR + objects. + + The operation succeeds unless the matcher sequence produced a definite + failure for any invocation. + }]; + + let arguments = (ins TransformHandleTypeInterface:$root, + SymbolRefAttr:$matcher); + let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results); + + let assemblyFormat = [{ + $matcher `in` $root attr-dict `:` functional-type($root, $results) + }]; +} + def ForeachMatchOp : TransformDialectOp<"foreach_match", [ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>, @@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op", def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand", [DeclareOpInterfaceMethods<TransformOpInterface>, - NavigationTransformOpTrait, MemoryEffectsOpInterface]> { + NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> { let summary = "Get handle to the producer of this operation's operand number"; let description = [{ The handle defined by this Transform op corresponds to operation that diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index aa4694c88d3b..b80fc09751d2 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" @@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } //===----------------------------------------------------------------------===// -// ForeachMatchOp +// CollectMatchingOp //===----------------------------------------------------------------------===// /// Applies matcher operations from the given `block` assigning `op` as the @@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state, return DiagnosedSilenceableFailure::success(); } +/// Returns `true` if both types implement one of the interfaces provided as +/// template parameters. +template <typename... Tys> +static bool implementSameInterface(Type t1, Type t2) { + return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); +} + +/// Returns `true` if both types implement one of the transform dialect +/// interfaces. +static bool implementSameTransformInterface(Type t1, Type t2) { + return implementSameInterface<transform::TransformHandleTypeInterface, + transform::TransformParamTypeInterface, + transform::TransformValueHandleTypeInterface>( + t1, t2); +} + +//===----------------------------------------------------------------------===// +// CollectMatchingOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( + getOperation(), getMatcher()); + if (matcher.isExternal()) { + return emitDefiniteFailure() + << "unresolved external symbol " << getMatcher(); + } + + SmallVector<SmallVector<MappedValue>, 2> rawResults; + rawResults.resize(getOperation()->getNumResults()); + std::optional<DiagnosedSilenceableFailure> maybeFailure; + for (Operation *root : state.getPayloadOps(getRoot())) { + WalkResult walkResult = root->walk([&](Operation *op) { + DEBUG_MATCHER({ + DBGS_MATCHER() << "matching "; + op->print(llvm::dbgs(), + OpPrintingFlags().assumeVerified().skipRegions()); + llvm::dbgs() << " @" << op << "\n"; + }); + + // Try matching. + SmallVector<SmallVector<MappedValue>> mappings; + DiagnosedSilenceableFailure diag = + matchBlock(matcher.getFunctionBody().front(), op, state, mappings); + if (diag.isDefiniteFailure()) + return WalkResult::interrupt(); + if (diag.isSilenceableFailure()) { + DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() + << " failed: " << diag.getMessage()); + return WalkResult::advance(); + } + + // If succeeded, collect results. + for (auto &&[i, mapping] : llvm::enumerate(mappings)) { + if (mapping.size() != 1) { + maybeFailure.emplace(emitSilenceableError() + << "result #" << i << ", associated with " + << mapping.size() + << " payload objects, expected 1"); + return WalkResult::interrupt(); + } + rawResults[i].push_back(mapping[0]); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return std::move(*maybeFailure); + assert(!maybeFailure && "failure set but the walk was not interrupted"); + + for (auto &&[opResult, rawResult] : + llvm::zip_equal(getOperation()->getResults(), rawResults)) { + results.setMappedValues(opResult, rawResult); + } + } + return DiagnosedSilenceableFailure::success(); +} + +void transform::CollectMatchingOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getRoot(), effects); + producesHandle(getResults(), effects); + onlyReadsPayload(effects); +} + +LogicalResult transform::CollectMatchingOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( + symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); + if (!matcherSymbol || + !isa<TransformOpInterface>(matcherSymbol.getOperation())) + return emitError() << "unresolved matcher symbol " << getMatcher(); + + ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); + if (argumentTypes.size() != 1 || + !isa<TransformHandleTypeInterface>(argumentTypes[0])) { + return emitError() + << "expected the matcher to take one operation handle argument"; + } + if (!matcherSymbol.getArgAttr( + 0, transform::TransformDialect::kArgReadOnlyAttrName)) { + return emitError() << "expected the matcher argument to be marked readonly"; + } + + ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); + if (resultTypes.size() != getOperation()->getNumResults()) { + return emitError() + << "expected the matcher to yield as many values as op has results (" + << getOperation()->getNumResults() << "), got " + << resultTypes.size(); + } + + for (auto &&[i, matcherType, resultType] : + llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { + if (implementSameTransformInterface(matcherType, resultType)) + continue; + + return emitError() + << "mismatching type interfaces for matcher result and op result #" + << i; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ForeachMatchOp +//===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() { return success(); } -/// Returns `true` if both types implement one of the interfaces provided as -/// template parameters. -template <typename... Tys> -static bool implementSameInterface(Type t1, Type t2) { - return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); -} - -/// Returns `true` if both types implement one of the transform dialect -/// interfaces. -static bool implementSameTransformInterface(Type t1, Type t2) { - return implementSameInterface<transform::TransformHandleTypeInterface, - transform::TransformParamTypeInterface, - transform::TransformValueHandleTypeInterface>( - t1, t2); -} - /// Checks that the attributes of the function-like operation have correct /// consumption effect annotations. If `alsoVerifyInternal`, checks for /// annotations being present even if they can be inferred from the body. diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 5123958b02bf..233dbbcb6804 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -704,3 +704,71 @@ transform.sequence failures(propagate) { // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}} transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32> } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{unresolved matcher symbol @missing_symbol}} + transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher to take one operation handle argument}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher() { + transform.yield + } +} + +// ----- + + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher argument to be marked readonly}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op) { + transform.yield + } +} + + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) { + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{mismatching type interfaces for matcher result and op result #0}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.yield %arg0 : !transform.any_op + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 3bbf875ef309..4ecd731ce417 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{result #0, associated with 2 payload objects, expected 1}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op + transform.yield %0 : !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{unresolved external symbol @matcher}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-remark @below {{matched}} + %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-remark @below {{matched}} + transform.test_print_remark_at_operand %0, "matched" : !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op + transform.yield %arg0 : !transform.any_op + } +} |