summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorOleksandr "Alex" Zinenko <zinenko@google.com>2024-01-09 13:18:57 +0100
committerGitHub <noreply@github.com>2024-01-09 13:18:57 +0100
commit633d9184f5f8ab227ab22fd7a7db366b843a02d2 (patch)
treeff018e38ab464db5f3826a319aca7462d352e697
parent4f7c402d9ff1b2c908b97b78baf84157f08745e8 (diff)
[mlir] introduce transform.collect_matching (#76724)
Introduce a new match combinator into the transform dialect. This operation collects all operations that are yielded by a satisfactory match into its results. This is a simpler version of `foreach_match` that can be inserted directly into existing transform scripts.
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td35
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp150
-rw-r--r--mlir/test/Dialect/Transform/ops-invalid.mlir68
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir44
4 files changed, 279 insertions, 18 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fcdb21d21503..fe2c28f45aea 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
let hasVerifier = 1;
}
+def CollectMatchingOp : TransformDialectOp<"collect_matching", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Collects all payload ops that match the given named matcher";
+ let description = [{
+ Collects operations or other payload IR objects nested under `root`
+ (inclusive) that match the given matcher expressed as a named sequence. The
+ matcher sequence must accept exactly one argument that it is not allowed to
+ modify. It must yield as many values as this op has results. Each of the
+ yielded values must be associated with exactly one payload object. If any
+ operation in the matcher sequence produces a silenceable failure, the
+ matcher advances to the next payload operation in the walk order without
+ finishing the sequence.
+
+ The i-th result of this operation is constructed by concatenating the i-th
+ yielded payload IR objects of all successful matcher sequence applications.
+ All results are guaranteed to be mapped to the same number of payload IR
+ objects.
+
+ The operation succeeds unless the matcher sequence produced a definite
+ failure for any invocation.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$root,
+ SymbolRefAttr:$matcher);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+
+ let assemblyFormat = [{
+ $matcher `in` $root attr-dict `:` functional-type($root, $results)
+ }];
+}
+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
- NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+ NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index aa4694c88d3b..b80fc09751d2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
//===----------------------------------------------------------------------===//
-// ForeachMatchOp
+// CollectMatchingOp
//===----------------------------------------------------------------------===//
/// Applies matcher operations from the given `block` assigning `op` as the
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
+/// Returns `true` if both types implement one of the interfaces provided as
+/// template parameters.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+ return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+/// Returns `true` if both types implement one of the transform dialect
+/// interfaces.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+ return implementSameInterface<transform::TransformHandleTypeInterface,
+ transform::TransformParamTypeInterface,
+ transform::TransformValueHandleTypeInterface>(
+ t1, t2);
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+ getOperation(), getMatcher());
+ if (matcher.isExternal()) {
+ return emitDefiniteFailure()
+ << "unresolved external symbol " << getMatcher();
+ }
+
+ SmallVector<SmallVector<MappedValue>, 2> rawResults;
+ rawResults.resize(getOperation()->getNumResults());
+ std::optional<DiagnosedSilenceableFailure> maybeFailure;
+ for (Operation *root : state.getPayloadOps(getRoot())) {
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ DEBUG_MATCHER({
+ DBGS_MATCHER() << "matching ";
+ op->print(llvm::dbgs(),
+ OpPrintingFlags().assumeVerified().skipRegions());
+ llvm::dbgs() << " @" << op << "\n";
+ });
+
+ // Try matching.
+ SmallVector<SmallVector<MappedValue>> mappings;
+ DiagnosedSilenceableFailure diag =
+ matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+ if (diag.isDefiniteFailure())
+ return WalkResult::interrupt();
+ if (diag.isSilenceableFailure()) {
+ DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage());
+ return WalkResult::advance();
+ }
+
+ // If succeeded, collect results.
+ for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+ if (mapping.size() != 1) {
+ maybeFailure.emplace(emitSilenceableError()
+ << "result #" << i << ", associated with "
+ << mapping.size()
+ << " payload objects, expected 1");
+ return WalkResult::interrupt();
+ }
+ rawResults[i].push_back(mapping[0]);
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return std::move(*maybeFailure);
+ assert(!maybeFailure && "failure set but the walk was not interrupted");
+
+ for (auto &&[opResult, rawResult] :
+ llvm::zip_equal(getOperation()->getResults(), rawResults)) {
+ results.setMappedValues(opResult, rawResult);
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CollectMatchingOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getRoot(), effects);
+ producesHandle(getResults(), effects);
+ onlyReadsPayload(effects);
+}
+
+LogicalResult transform::CollectMatchingOp::verifySymbolUses(
+ SymbolTableCollection &symbolTable) {
+ auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
+ symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
+ if (!matcherSymbol ||
+ !isa<TransformOpInterface>(matcherSymbol.getOperation()))
+ return emitError() << "unresolved matcher symbol " << getMatcher();
+
+ ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
+ if (argumentTypes.size() != 1 ||
+ !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
+ return emitError()
+ << "expected the matcher to take one operation handle argument";
+ }
+ if (!matcherSymbol.getArgAttr(
+ 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
+ return emitError() << "expected the matcher argument to be marked readonly";
+ }
+
+ ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
+ if (resultTypes.size() != getOperation()->getNumResults()) {
+ return emitError()
+ << "expected the matcher to yield as many values as op has results ("
+ << getOperation()->getNumResults() << "), got "
+ << resultTypes.size();
+ }
+
+ for (auto &&[i, matcherType, resultType] :
+ llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
+ if (implementSameTransformInterface(matcherType, resultType))
+ continue;
+
+ return emitError()
+ << "mismatching type interfaces for matcher result and op result #"
+ << i;
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ForeachMatchOp
+//===----------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
return success();
}
-/// Returns `true` if both types implement one of the interfaces provided as
-/// template parameters.
-template <typename... Tys>
-static bool implementSameInterface(Type t1, Type t2) {
- return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
-}
-
-/// Returns `true` if both types implement one of the transform dialect
-/// interfaces.
-static bool implementSameTransformInterface(Type t1, Type t2) {
- return implementSameInterface<transform::TransformHandleTypeInterface,
- transform::TransformParamTypeInterface,
- transform::TransformValueHandleTypeInterface>(
- t1, t2);
-}
-
/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 5123958b02bf..233dbbcb6804 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{unresolved matcher symbol @missing_symbol}}
+ transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher to take one operation handle argument}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher() {
+ transform.yield
+ }
+}
+
+// -----
+
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher argument to be marked readonly}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op) {
+ transform.yield
+ }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.yield %arg0 : !transform.any_op
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3bbf875ef309..4ecd731ce417 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
transform.yield
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
+ transform.yield %0 : !transform.any_op
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{unresolved external symbol @matcher}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-remark @below {{matched}}
+ %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below {{matched}}
+ transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+}