diff options
author | Oleksandr "Alex" Zinenko <zinenko@google.com> | 2024-01-09 13:19:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 13:19:41 +0100 |
commit | 4cb2ef4fe372d32d1773f4dd358d6dff91518b5f (patch) | |
tree | 74224376300205c7613f6ac65321db01b848554b | |
parent | 633d9184f5f8ab227ab22fd7a7db366b843a02d2 (diff) |
[mlir] add a chapter on matchers to the transform dialect tutorial (#76725)
These operations has been available for a while, but were not described
in the tutorial. Add a new chapter on using and defining match
operations.
-rw-r--r-- | mlir/docs/Tutorials/transform/Ch4.md | 581 | ||||
-rw-r--r-- | mlir/docs/Tutorials/transform/_index.md | 1 | ||||
-rw-r--r-- | mlir/examples/transform/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/CMakeLists.txt | 21 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/include/CMakeLists.txt | 14 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/include/MyExtension.h | 30 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/include/MyExtension.td | 46 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/lib/CMakeLists.txt | 20 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/lib/MyExtension.cpp | 207 | ||||
-rw-r--r-- | mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp | 55 | ||||
-rw-r--r-- | mlir/test/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/test/Examples/transform/Ch4/features.mlir | 123 | ||||
-rw-r--r-- | mlir/test/Examples/transform/Ch4/multiple.mlir | 131 | ||||
-rw-r--r-- | mlir/test/Examples/transform/Ch4/sequence.mlir | 139 | ||||
-rw-r--r-- | mlir/test/lit.cfg.py | 5 |
16 files changed, 1375 insertions, 3 deletions
diff --git a/mlir/docs/Tutorials/transform/Ch4.md b/mlir/docs/Tutorials/transform/Ch4.md new file mode 100644 index 000000000000..77c36eab343d --- /dev/null +++ b/mlir/docs/Tutorials/transform/Ch4.md @@ -0,0 +1,581 @@ +# Chapter 4: Matching Payload with Transform Operations + +**Check the continuously-tested version of MLIR files under +[mlir/test/Examples/transform/Ch4](https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/Ch4).** + +Up until now, we were applying transform dialect scripts under the assumption +that specific payload operations are identified by the caller when the transform +dialect interpreter is invoked. This may be seen as contrary to the idea of +driving transformations from a dialect since the transformation targets must be +identified through mechanisms external to the transform dialect interpreter, for +example, when invoking the interpreter programmatically in C++ or through pass +arguments as seen in previous chapters. It also adds practical overhead due to +increased interaction with the interpreter in C++, and cognitive overhead of +manipulating two interfaces at once. To remedy this, Transform dialect proposes +a subset of operations for _matching_ payload operations that need to be +transformed. + +_Match_ operations are simply transform operations with some additional +guarantees. In particular, they are not expected to modify the payload IR and +are expected to fail if their operands (typically payload operation handles) are +not associated with payload IR objects having desired properties, such as +operation names or kinds of arguments. Using simple combinator operations, it +becomes possible to set up a higher-level match and rewrite infrastructure +directly within the transform dialect. + + +## Simple match + +Let us reconsider the “fully connected layer” example from [Chapter +1](Ch1.md#chaining-transformations-with-handles), reproduced below for +convenience. + + +```mlir +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul + ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +``` + + +In Chapter 1, we were calling the test transform interpreter pass with +additional arguments, `bind-first-extra-to-ops=linalg.matmul +bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial +associations for operation handles. Instead, we can use match operations to +discover relevant operations in the payload IR. Match operations can be combined +with “regular” transform operations using, e.g., the +`transform.collect_matching` combinator operation that leverages the concept of +named sequences to organize matchers. + + +```mlir +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.readonly}) { + // Collect operations that match the criteria specified in named sequence. + // If the named sequence fails with a silenceable failure, silences it (the + // message is forwarded to the debug stream). If the named sequence + // succeeds, appends its results to the results of this operation. + %elemwise = transform.collect_matching @match_elemwise in %root + : (!transform.any_op) -> !transform.any_op + %matmul = transform.collect_matching @match_matmul in %root + : (!transform.any_op) -> !transform.any_op + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield + } + + // This is a matcher sequence. It is given an operation to match and the + // match is considered successful unless any nested operation produces a + // failure. The values yielded by this operation will be forwarded to the + // rewriter sequence on success. + transform.named_sequence @match_elemwise( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %entry : !transform.any_op + } + transform.named_sequence @match_matmul( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op + transform.yield %entry : !transform.any_op + } + + // This is a rewriter sequence. + transform.named_sequence @print_elemwise( + %elemwise_binary: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand + %elemwise_binary, "elementwise binary" : !transform.any_op + transform.yield + } + transform.named_sequence @print_matmul( + %matmul: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op + transform.yield + } +} + +``` + + +This script can be executed using the non-test interpreter pass running on the +root operation of the translation unit without additional flags: `mlir-opt +--transform-interpreter`. It will emit corresponding remarks at +`linalg.elemwise_binary` and `linalg.matmul` operations. In debug builds, the +infrastructure provides a convenient method to understand the matching process +by passing `-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It +will print the silenceable failure messages produced by the match operations +into the debug stream, for example: + + +``` +<...> +[transform-matcher] matching %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> @0x5622eee08410 +[transform-matcher] matcher match_elemwise failed: wrong operation name +<...> +``` + + +This is now sufficient to run the rest of the transform script from Chapter 1, +substituting `%arg1` with `%matmul` and `%arg2` with `%elemwise`. + + +## Matching Chains of Operations + +The matcher above remains naive as it matches _all_ operations of the certain +kind under the payload root. These operations may or may not be related, and +may, for example, belong to different functions. Even if they are in a single +function, if there are multiple groups of such operations, we wouldn’t be able +to differentiate them with this approach. In reality, we want to match a +specific group of operations where a `matmul` operation produces a result that +is used by an elementwise operation, which in turn feeds another elementwise +operation in a similar way. + +This can be achieved using the following matcher sequence. + + +```mlir +// This is also a matcher sequence. It is similarly given an operation to +// match and nested operations must succeed in order for a match to be deemed +// successful. It starts matching from the last operation in the use-def chain +// and goes back because each operand (use) has exactly one definition. +transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op) { + // The last operation must be an elementwise binary. + transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + // Its first operand must be defined by another operation, to which we + // will get a handle here. We are guaranteed that the first operand exists + // because we know the operation is binary, but even in absence of such a + // guarantee, this operation would have produced a silenceable failure when + // `%last` does not have enough operands. + %middle = transform.get_producer_of_operand %last[0] + : (!transform.any_op) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %middle ["linalg.elemwise_binary"] + : !transform.any_op + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. + transform.yield %matmul, %middle, %last + : !transform.any_op, !transform.any_op, !transform.any_op +} +``` + +This matcher is applicable in presence of other `elemwise` and `matmul` +operations and will return the triple of _related_ operations rather than +operations in the order in which they are found. It can be exercised similarly +to the previous incarnation, as follows. + +```mlir +// Alternative entry point. +transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.readonly}) { + // Collect groups of operations that match the criteria specified in the + // named sequence. + %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op + + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield +} +``` + + +## Defining Match Operations + +The matcher of a chain of operations is correct in presence of other operations, +but is still insufficiently robust for many cases of interest. In particular, +using `transform.get_producer_of_operand %last[0]` requires that the _first_ +operand of elementwise operations is produced by another operation. The same +transformation strategy may however apply regardless of the operand position: +many binary operations are associative. Let us use this opportunity to introduce +a new match operation. Specifically, we would like this operation to succeed if +_any_ of the operands satisfies certain conditions that can be expressed as +other match operations. We also want it to return some of the state and the +position of the matched operand in the operand list. + +Match operations are defined similarly to other transform operations, with the +only difference of additionally implementing the `MatchOpInterface`. Note that +this interface has _no additional methods_ (though it may add some eventually) +and is only used as a verification contract that the operation is intended for +matching and will not attempt to transform the payload. The minimal definition +of our operation is as follows. + + +```tablegen +// Define the new operation. By convention, prefix its name with `match` +// followed by the name of the dialect extension. +def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying", + [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>, + // Indicate that the operation implements MatchOpInterface in addition to + // the TransformOpInterface. This interface is only used as a tag at this + // point and has no methods that are mandatory to implement. + MatchOpInterface, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + let summary = "Succeed if any of the operands matches all nested criteria"; + let arguments = (ins TransformHandleTypeInterface:$op); + let results = (outs TransformParamTypeInterface:$position, + Variadic<Transform_AnyHandleOrParamType>:$results); + + // Match operations can be arbitrarily complex, e.g., containing regions. + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + $op `:` functional-type($op, results) attr-dict-with-keyword $body + }]; +} +``` + + +It takes as argument the handle associated with the payload operations whose +operands it will match, has an associated single-block region containing the +match criteria, and returns the position of the matched operand as well as any +other transform value yielded from the body on the successful match. + +The matching logic is implemented in the `apply` method of the +`TransformOpInterface` and is easily composable with other transform operations. +All facilities for managing the interpreter state and recursively entering the +blocks are available in the same way as they are for “regular” transform +operations. Match operations are expected to return a silenceable failure to +indicate failure to match, and to immediately propagate definite failures. If +they have nested operations, they are expected to handle and, in most cases, +silence the silenceable failures produced when applying those operations. For +our operation, the matching is essentially a loop iterating over all operands of +the (single) payload operation and applying nested transform ops until they all +succeed for one of the operands. + + +```cpp +// Matcher ops implement `apply` similarly to other transform ops. They are not +// expected to modify payload, but use the tri-state result to signal failure or +// success to match, as well as potential irrecoverable errors. +mlir::DiagnosedSilenceableFailure +mlir::transform::HasOperandSatisfyingOp::apply( + mlir::transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &results, + mlir::transform::TransformState &state) { + // For simplicity, only handle a single payload op. Actual implementations + // can use `SingleOpMatcher` trait to simplify implementation and document + // this expectation. + auto payloadOps = state.getPayloadOps(getOp()); + if (!llvm::hasSingleElement(payloadOps)) + return emitSilenceableError() << "expected single payload"; + + // Iterate over all operands of the payload op to see if they can be matched + // using the body of this op. + Operation *payload = *payloadOps.begin(); + for (OpOperand &operand : payload->getOpOperands()) { + // Create a scope for transform values defined in the body. This corresponds + // to the syntactic scope of the region attached to this op. Any values + // associated with payloads from now on will be automatically dissociated + // when this object is destroyed, i.e. at the end of the iteration. + // Associate the block argument handle with the operand. + auto matchScope = state.make_region_scope(getBody()); + if (failed(state.mapBlockArgument(getBody().getArgument(0), + {operand.get()}))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + + // Iterate over all nested matchers with the current mapping and see if they + // succeed. + bool matchSucceeded = true; + for (Operation &matcher : getBody().front().without_terminator()) { + // Matcher ops are applied similarly to any other transform op. + DiagnosedSilenceableFailure diag = + state.applyTransform(cast<TransformOpInterface>(matcher)); + + // Definite failures are immediately propagated as they are irrecoverable. + if (diag.isDefiniteFailure()) + return diag; + + // On success, keep checking the remaining conditions. + if (diag.succeeded()) + continue; + + // Report failure-to-match for debugging purposes and stop matching this + // operand. + assert(diag.isSilenceableFailure()); + DEBUG_MATCHER(DBGS_MATCHER() + << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage()); + (void)diag.silence(); + matchSucceeded = false; + break; + } + // If failed to match this operand, try other operands. + if (!matchSucceeded) + continue; + + // If we reached this point, the matching succeeded for the current operand. + // Remap the values associated with terminator operands to be associated + // with op results, and also map the parameter result to the operand's + // position. Note that it is safe to do here despite the end of the scope + // as `results` are integrated into `state` by the interpreter after `apply` + // returns rather than immediately. + SmallVector<SmallVector<MappedValue>> yieldedMappings; + transform::detail::prepareValueMappings( + yieldedMappings, getBody().front().getTerminator()->getOperands(), + state); + results.setParams(getPosition().cast<OpResult>(), + {rewriter.getI32IntegerAttr(operand.getOperandNumber())}); + for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings)) + results.setMappedValues(result, mapping); + return DiagnosedSilenceableFailure::success(); + } + + // If we reached this point, none of the operands succeeded the match. + return emitSilenceableError() + << "none of the operands satisfied the conditions"; +} + +``` + + +By convention, operations implementing `MatchOpInterface` must not modify +payload IR and must therefore specify that they only read operand handles and +payload as their effects. + + +```cpp +void transform::CollectMatchingOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getRoot(), effects); + producesHandle(getResults(), effects); + onlyReadsPayload(effects); +} +``` + + +This operation can now be included in a transform dialect extension, loaded and +used in our matcher. Specifically, we will use it to indicate that either of the +operands of the “max” elementwise operation in our example can be produced by +the previous elementwise operation. The previous operation will still require +the matmul to produce the first operand for simplicity. The updated matcher +sequence looks as follows. + + +```mlir +transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op, + !transform.param<i32>) { + // The last operation must be an elementwise binary. + transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + + // One of its operands must be defined by another operation, to which we + // will get a handle here. This is achieved thanks to a newly defined + // operation that tries to match operands one by one using the match + // operations nested in its region. + %pos, %middle = transform.match.my.has_operand_satisfying %last + : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) { + ^bb0(%operand: !transform.any_value): + // The operand must be defined by an operation. + %def = transform.get_defining_op %operand + : (!transform.any_value) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %def ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %def : !transform.any_op + } + + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. + transform.yield %matmul, %middle, %last, %pos + : !transform.any_op, !transform.any_op, !transform.any_op, + !transform.param<i32> +} +``` + + +This achieves the desired effect and matches both `max(add(matmul(...), bias), +0)` and `max(0, add(matmul(...), bias))` in the same values. The `%pos` value is +a transform dialect _parameter_, which is used to store lists of entities known +to be constant throughout the transform application. Most often, parameters are +numeric values, but they can generally be any MLIR attributes. + +In order to demonstrate that groups of operations are matched independently of +each other, let us use the `transform.foreach_match` operation that allows one +to implement a simple high-level pattern rewriting approach within the transform +dialect (for advanced or lower-level pattern rewriting, consider PDL(L) or C++ +rewriting APIs). It maps a matcher named sequence to an action named sequence, +and the latter gets invoked whenever the former succeeds. + + +```mlir +// Traverses the payload IR associated with the operand handle, invoking +// @match_matmul_elemwise on each of the operations. If the named sequence +// succeeds, i.e., if none of the nested match (transform) operations +// produced a silenceable failure, invokes @print_matmul_elemwise and +// forwards the values yielded as arguments of the new invocation. If the +// named sequence fails with a silenceable failure, silences it (the message +// is forwarded to the debug stream). Definite failures are propagated +// immediately and unconditionally, as usual. +transform.foreach_match in %root + @match_matmul_elemwise -> @print_matmul_elemwise + : (!transform.any_op) -> !transform.any_op +``` + + +The `@print_matmul_elemwise` named sequence, available in `multiple.mlir`, will +use the parameter with the position of the operand to differentiate the two +groups. + + +## Matchers for Inferred Features + +The matcher sequences described above, although useful to drive transformations +from within the transform dialect interpreter, are rather basic since they +mostly rely on operation names and use-def chains. Alternative implementations +using APIs or various declarative rewrite rules are barely less expressive and +sometimes more concise. The real power of transform dialect matcher ops lies in +the possibility to define matchers of _inferred properties_ of payloads, i.e., +properties that are not directly accessible as an attribute of an operation or +any straightforward relation between IR components. + +The utility of such matchers can be easily demonstrated by slightly modifying +our original example. If matrix multiplication is expressed as a special case of +tensor contraction using `linalg.generic` instead of `linalg.matmul`, the +operation name-based matcher no longer applies. Yet such a representation is +very common and can appear both in the original input and during the course of +transformation, e.g., where a higher-dimensional contraction is decomposed into +loops around a matrix multiplication. + +In order to be a (potentially transposed) matrix multiplication, the +`linalg.generic` operation must have the following features: + + + +* Total rank of 3. +* Two inputs accessed as projected permutation of iteration dimensions. +* One output accessed as projected permutation of iteration dimensions. +* Iteration dimensions can be subdivided into LHS parallel, RHS parallel and reduction dimensions. +* The body block consists of a multiplication and an addition. + +Most of these features can be derived from the properties of the operation, +e.g., the total rank corresponds to the number of entries in the `iterators` +attribute, but almost none of them are immediately accessible in the IR or in +any declarative form, which is usually limited to checking the presence or the +exact match of an attribute or a type. The transform dialect allows these +features to be implemented in the `apply` method of a matcher op and reused +across multiple matching cases. For structured linear algebra payload +operations, many such match operations are readily available in the `structured` +extension. They are sufficient to implement a matrix multiplication matcher +using the features listed above almost verbatim. + + +```mlir +transform.named_sequence @match_generic_matmul( + %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op { + // Match a structured linear algebra operation. + transform.match.structured %candidate : !transform.any_op { + ^bb0(%c: !transform.any_op): + // With a rank equal to 3. + %rank = transform.match.structured.rank %c + : (!transform.any_op) -> !transform.param<i64> + %c3 = transform.param.constant 3 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64> + + // With 2 inputs. + %n_ins = transform.match.structured.num_inputs %c + : (!transform.any_op) -> !transform.param<i64> + %c2 = transform.param.constant 2 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64> + + // With 1 output (note that structured ops in destination passing style + // has as many inits as outputs). + %n_inits = transform.match.structured.num_inits %c + : (!transform.any_op) -> !transform.param<i64> + %c1 = transform.param.constant 1 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64> + + // All inputs and inits are accessed with a projected permutation. + transform.match.structured.input %c[all] {projected_permutation} + : !transform.any_op + transform.match.structured.init %c[0] {projected_permutation} + : !transform.any_op + + // The body is a mulf/addf contraction with appropriate dimensions. + transform.match.structured.body %c + { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %batch, %lhs, %rhs, %reduction = + transform.match.structured.classify_contraction_dims %c + : (!transform.any_op) + -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, + !transform.param<i64>) + + + // There is one of lhs, rhs and reduction dimensions and zero batch + // dimensions. + %n_batch = transform.num_associations %batch + : (!transform.param<i64>) -> !transform.param<i64> + %n_lhs = transform.num_associations %lhs + : (!transform.param<i64>) -> !transform.param<i64> + %n_rhs = transform.num_associations %rhs + : (!transform.param<i64>) -> !transform.param<i64> + %n_reduction = transform.num_associations %reduction + : (!transform.param<i64>) -> !transform.param<i64> + %c0 = transform.param.constant 0 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64> + transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64> + transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64> + transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64> + } + transform.yield %candidate : !transform.any_op +} +``` + + +While this example leverages the contraction-specific matchers that have a +rather non-trivial C++ implementation, the transform dialect is sufficiently +flexible to implement this reasoning directly if desired. One could, for +example, obtain the access map of each input as a parameter and extract the +accessed dimensions as other parameters that can be compared with each other to +ensure the subscripts are `m,k` for LHS, `k,n` for RHS and `m,n` for the +init/result given the `m,n,k` notation for loops. + diff --git a/mlir/docs/Tutorials/transform/_index.md b/mlir/docs/Tutorials/transform/_index.md index b508a5d1d535..8a5af5625450 100644 --- a/mlir/docs/Tutorials/transform/_index.md +++ b/mlir/docs/Tutorials/transform/_index.md @@ -26,6 +26,7 @@ The tutorial is divided into the following chapters. - [Chapter #1](Ch1.md): Combining Existing Transformations - [Chapter #2](Ch2.md): Adding a Simple New Transformation Operation - [Chapter #3](Ch3.md): More than Simple Transform Operations +- [Chapter #4](Ch4.md): Matching Payload with Transform Operations - [Chapter H](ChH.md): Reproducing Halide Schedule The code corresponding to this tutorial is located under diff --git a/mlir/examples/transform/CMakeLists.txt b/mlir/examples/transform/CMakeLists.txt index 3f3740ad2a8d..b688aa7461d6 100644 --- a/mlir/examples/transform/CMakeLists.txt +++ b/mlir/examples/transform/CMakeLists.txt @@ -2,3 +2,4 @@ add_custom_target(TransformExample) add_subdirectory(Ch2) add_subdirectory(Ch3) +add_subdirectory(Ch4) diff --git a/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp index 1e4367ad4690..3c348c663aba 100644 --- a/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp +++ b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This is the top-level file for the Transform dialect tutorial chapter 2. +// This is the top-level file for the Transform dialect tutorial chapter 3. // //===----------------------------------------------------------------------===// diff --git a/mlir/examples/transform/Ch4/CMakeLists.txt b/mlir/examples/transform/Ch4/CMakeLists.txt new file mode 100644 index 000000000000..c070a04a35a8 --- /dev/null +++ b/mlir/examples/transform/Ch4/CMakeLists.txt @@ -0,0 +1,21 @@ +# For a better top-level template to copy, see examples/standalone. + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_dependencies(TransformExample transform-opt-ch4) +add_llvm_example(transform-opt-ch4 + transform-opt/transform-opt.cpp) + +target_link_libraries(transform-opt-ch4 + PRIVATE + MLIRIR + MLIRMlirOptMain + MLIRSideEffectInterfaces + MLIRTransformDialectTransforms + MyExtensionCh4 +) diff --git a/mlir/examples/transform/Ch4/include/CMakeLists.txt b/mlir/examples/transform/Ch4/include/CMakeLists.txt new file mode 100644 index 000000000000..1f960e590529 --- /dev/null +++ b/mlir/examples/transform/Ch4/include/CMakeLists.txt @@ -0,0 +1,14 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the +# compilation. +add_public_tablegen_target(MyExtensionCh4IncGen) + +# Don't forget to generate the documentation, this will produce a +# MyExtensionCh4.md under Tutorials/transform +add_mlir_doc(MyExtension MyExtensionCh4 Tutorials/transform/ -gen-op-doc) diff --git a/mlir/examples/transform/Ch4/include/MyExtension.h b/mlir/examples/transform/Ch4/include/MyExtension.h new file mode 100644 index 000000000000..13e5b3c04b02 --- /dev/null +++ b/mlir/examples/transform/Ch4/include/MyExtension.h @@ -0,0 +1,30 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" + +namespace mlir { +class CallOpInterface; +namespace func { +class CallOp; +} // namespace func +} // namespace mlir + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. +void registerMyExtension(::mlir::DialectRegistry ®istry); diff --git a/mlir/examples/transform/Ch4/include/MyExtension.td b/mlir/examples/transform/Ch4/include/MyExtension.td new file mode 100644 index 000000000000..ae58dc37db43 --- /dev/null +++ b/mlir/examples/transform/Ch4/include/MyExtension.td @@ -0,0 +1,46 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "mlir/Dialect/Transform/IR/MatchInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with `match` +// followed by the name of the dialect extension. +def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying", + [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>, + // Indicate that the operation implements MatchOpInterface in addition to + // the TransformOpInterface. This interface is only used as a tag at this + // point and has no methods that are mandatory to implement. + MatchOpInterface, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + let summary = "Succeed if any of the operands matches all nested criteria"; + let arguments = (ins TransformHandleTypeInterface:$op); + let results = (outs TransformParamTypeInterface:$position, + Variadic<Transform_AnyHandleOrParamType>:$results); + + // Match operations can be arbitrarily complex, e.g., containing regions. + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + $op `:` functional-type($op, results) attr-dict-with-keyword $body + }]; +} + +#endif // MY_EXTENSION diff --git a/mlir/examples/transform/Ch4/lib/CMakeLists.txt b/mlir/examples/transform/Ch4/lib/CMakeLists.txt new file mode 100644 index 000000000000..33338a679af3 --- /dev/null +++ b/mlir/examples/transform/Ch4/lib/CMakeLists.txt @@ -0,0 +1,20 @@ +# Outside examples, this should be `add_mlir_library`. +add_mlir_example_library( + # Library called MyExtension. + MyExtensionCh4 + + # Built from the following source files. + MyExtension.cpp + + # Make includes visible without top-level path. + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/examples/transform/Ch4/include + + # Make sure ODS declaration and definitions are generated before compiling this. + DEPENDS + MyExtensionCh4IncGen + + # Link in the transform dialect, an all generated dialects. + LINK_LIBS PRIVATE + MLIRTransformDialect +) diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp new file mode 100644 index 000000000000..26e348f2a30e --- /dev/null +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -0,0 +1,207 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 4 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE_MATCHER "transform-matcher" +#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") +#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +//===---------------------------------------------------------------------===// +// MyExtension +//===---------------------------------------------------------------------===// + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension<MyExtension> { +public: + // The extension must derive the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void MyExtension::init() { + // Register the additional match operations with the dialect similarly to + // other transform operations. List all operations generated from ODS. This + // call will perform additional checks that the operations implement the + // transform and memory effect interfaces required by the dialect interpreter + // and assert if they do not. + registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +//===---------------------------------------------------------------------===// +// HasOperandSatisfyingOp +//===---------------------------------------------------------------------===// + +/// Returns `true` if both types implement one of the interfaces provided as +/// template parameters. +template <typename... Tys> +static bool implementSameInterface(mlir::Type t1, mlir::Type t2) { + return ((llvm::isa<Tys>(t1) && llvm::isa<Tys>(t2)) || ... || false); +} + +/// Returns `true` if both types implement one of the transform dialect +/// interfaces. +static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) { + return implementSameInterface< + mlir::transform::TransformHandleTypeInterface, + mlir::transform::TransformParamTypeInterface, + mlir::transform::TransformValueHandleTypeInterface>(t1, t2); +} + +// Matcher ops implement `apply` similarly to other transform ops. They are not +// expected to modify payload, but use the tri-state result to signal failure or +// success to match, as well as potential irrecoverable errors. +mlir::DiagnosedSilenceableFailure +mlir::transform::HasOperandSatisfyingOp::apply( + mlir::transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &results, + mlir::transform::TransformState &state) { + // For simplicity, only handle a single payload op. Actual implementations + // can use `SingleOpMatcher` trait to simplify implementation and document + // this expectation. + auto payloadOps = state.getPayloadOps(getOp()); + if (!llvm::hasSingleElement(payloadOps)) + return emitSilenceableError() << "expected single payload"; + + // Iterate over all operands of the payload op to see if they can be matched + // using the body of this op. + Operation *payload = *payloadOps.begin(); + for (OpOperand &operand : payload->getOpOperands()) { + // Create a scope for transform values defined in the body. This corresponds + // to the syntactic scope of the region attached to this op. Any values + // associated with payloads from now on will be automatically dissociated + // when this object is destroyed, i.e. at the end of the iteration. + // Associate the block argument handle with the operand. + auto matchScope = state.make_region_scope(getBody()); + if (failed(state.mapBlockArgument(getBody().getArgument(0), + {operand.get()}))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + + // Iterate over all nested matchers with the current mapping and see if they + // succeed. + bool matchSucceeded = true; + for (Operation &matcher : getBody().front().without_terminator()) { + // Matcher ops are applied similarly to any other transform op. + DiagnosedSilenceableFailure diag = + state.applyTransform(cast<TransformOpInterface>(matcher)); + + // Definite failures are immediately propagated as they are irrecoverable. + if (diag.isDefiniteFailure()) + return diag; + + // On success, keep checking the remaining conditions. + if (diag.succeeded()) + continue; + + // Report failure-to-match for debugging purposes and stop matching this + // operand. + assert(diag.isSilenceableFailure()); + DEBUG_MATCHER(DBGS_MATCHER() + << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage()); + (void)diag.silence(); + matchSucceeded = false; + break; + } + // If failed to match this operand, try other operands. + if (!matchSucceeded) + continue; + + // If we reached this point, the matching succeeded for the current operand. + // Remap the values associated with terminator operands to be associated + // with op results, and also map the parameter result to the operand's + // position. Note that it is safe to do here despite the end of the scope + // as `results` are integrated into `state` by the interpreter after `apply` + // returns rather than immediately. + SmallVector<SmallVector<MappedValue>> yieldedMappings; + transform::detail::prepareValueMappings( + yieldedMappings, getBody().front().getTerminator()->getOperands(), + state); + results.setParams(getPosition().cast<OpResult>(), + {rewriter.getI32IntegerAttr(operand.getOperandNumber())}); + for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings)) + results.setMappedValues(result, mapping); + return DiagnosedSilenceableFailure::success(); + } + + // If we reached this point, none of the operands succeeded the match. + return emitSilenceableError() + << "none of the operands satisfied the conditions"; +} + +// By convention, operations implementing MatchOpInterface must not modify +// payload IR and must therefore specify that they only read operand handles and +// payload as their effects. +void mlir::transform::HasOperandSatisfyingOp::getEffects( + llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) { + onlyReadsPayload(effects); + onlyReadsHandle(getOp(), effects); + producesHandle(getPosition(), effects); + producesHandle(getResults(), effects); +} + +// Verify well-formedness of the operation and emit diagnostics if it is +// ill-formed. +mlir::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() { + mlir::Block &bodyBlock = getBody().front(); + if (bodyBlock.getNumArguments() != 1 || + !isa<TransformValueHandleTypeInterface>( + bodyBlock.getArgument(0).getType())) { + return emitOpError() + << "expects the body to have one value handle argument"; + } + if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) { + return emitOpError() << "expects the body to yield " + << (getNumResults() - 1) << " values, got " + << bodyBlock.getTerminator()->getNumOperands(); + } + for (auto &&[i, operand, result] : + llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(), + getResults().getTypes())) { + if (implementSameTransformInterface(operand, result)) + continue; + return emitOpError() << "expects terminator operand #" << i + << " and result #" << (i + 1) + << " to implement the same transform interface"; + } + + for (Operation &op : bodyBlock.without_terminator()) { + if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) { + InFlightDiagnostic diag = emitOpError() + << "expects body to contain match ops"; + diag.attachNote(op.getLoc()) << "non-match operation"; + return diag; + } + } + + return success(); +} + +void registerMyExtension(::mlir::DialectRegistry ®istry) { + registry.addExtensions<MyExtension>(); +} diff --git a/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp new file mode 100644 index 000000000000..10190664b51c --- /dev/null +++ b/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp @@ -0,0 +1,55 @@ +//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the Transform dialect tutorial chapter 4. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" + +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include <cstdlib> + +namespace test { +void registerTestTransformDialectExtension(mlir::DialectRegistry &); +} // namespace test + +int main(int argc, char **argv) { + // Register all "core" dialects and our transform dialect extension. + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllExtensions(registry); + registerMyExtension(registry); + + // Register a handful of cleanup passes that we can run to make the output IR + // look nicer. + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerSymbolDCEPass(); + mlir::transform::registerInterpreterPass(); + + // Register the test passes. +#ifdef MLIR_INCLUDE_TESTS + test::registerTestTransformDialectExtension(registry); +#else + llvm::errs() << "warning: MLIR built without test extension, interpreter " + "testing will not be available\n"; +#endif // MLIR_INCLUDE_TESTS + + // Delegate to the MLIR utility for parsing and pass management. + return mlir::MlirOptMain(argc, argv, "transform-opt-ch4", registry) + .succeeded() + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 7ec4c8f0963a..8ce030feeded 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -166,6 +166,8 @@ if(LLVM_BUILD_EXAMPLES) list(APPEND MLIR_TEST_DEPENDS transform-opt-ch2 transform-opt-ch3 + transform-opt-ch4 + mlir-minimal-opt ) if(MLIR_ENABLE_EXECUTION_ENGINE) list(APPEND MLIR_TEST_DEPENDS diff --git a/mlir/test/Examples/transform/Ch4/features.mlir b/mlir/test/Examples/transform/Ch4/features.mlir new file mode 100644 index 000000000000..9a2af474aa4f --- /dev/null +++ b/mlir/test/Examples/transform/Ch4/features.mlir @@ -0,0 +1,123 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics + +// Matmul as a named operation. +func.func @named( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // expected-remark @below {{matmul}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %matmul : tensor<512x512xf32> +} + +// Matmul as a generic operation. +func.func @generic( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // expected-remark @below {{matmul}} + %matmul = linalg.generic { + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>] + } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %0, %arg2 : f32 + linalg.yield %1 : f32 + } -> tensor<512x512xf32> + return %matmul : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.consumed}) { + + // Traverses the payload IR associated with the operand handle, invoking + // @match_matmul_elemwise on each of the operations. If the named sequence + // succeeds, i.e., if none of the nested match (transform) operations + // produced a silenceable failure, invokes @print_matmul_elemwise and + // forwards the values yielded as arguments of the new invocation. If the + // named sequence fails with a silenceable failure, silences it (the message + // is forwarded to the debug stream). Definite failures are propagated + // immediately and unconditionally, as usual. + transform.foreach_match in %root + @match_generic_matmul -> @print_generic_matmul + : (!transform.any_op) -> !transform.any_op + + transform.yield + } + + // This is an action sequence. + transform.named_sequence @print_generic_matmul( + %matmul: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_generic_matmul( + %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op { + // Match a structured linear algebra operation. + transform.match.structured %candidate : !transform.any_op { + ^bb0(%c: !transform.any_op): + // With a rank equal to 3. + %rank = transform.match.structured.rank %c + : (!transform.any_op) -> !transform.param<i64> + %c3 = transform.param.constant 3 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64> + + // With 2 inputs. + %n_ins = transform.match.structured.num_inputs %c + : (!transform.any_op) -> !transform.param<i64> + %c2 = transform.param.constant 2 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64> + + // With 1 output (note that structured ops in destination passing style + // has as many inits as outputs). + %n_inits = transform.match.structured.num_inits %c + : (!transform.any_op) -> !transform.param<i64> + %c1 = transform.param.constant 1 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64> + + // All inputs and inits are accessed with a projected permutation. + transform.match.structured.input %c[all] {projected_permutation} + : !transform.any_op + transform.match.structured.init %c[0] {projected_permutation} + : !transform.any_op + + // The body is a mulf/addf contraction with appropriate dimensions. + transform.match.structured.body %c + { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %batch, %lhs, %rhs, %reduction = + transform.match.structured.classify_contraction_dims %c + : (!transform.any_op) + -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, + !transform.param<i64>) + + // There is one of lhs, rhs and reduction dimensions and zero batch + // dimensions. + %n_batch = transform.num_associations %batch + : (!transform.param<i64>) -> !transform.param<i64> + %n_lhs = transform.num_associations %lhs + : (!transform.param<i64>) -> !transform.param<i64> + %n_rhs = transform.num_associations %rhs + : (!transform.param<i64>) -> !transform.param<i64> + %n_reduction = transform.num_associations %reduction + : (!transform.param<i64>) -> !transform.param<i64> + %c0 = transform.param.constant 0 : i64 -> !transform.param<i64> + transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64> + transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64> + transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64> + transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64> + } + transform.yield %candidate : !transform.any_op + } +} diff --git a/mlir/test/Examples/transform/Ch4/multiple.mlir b/mlir/test/Examples/transform/Ch4/multiple.mlir new file mode 100644 index 000000000000..22ef7c99f86a --- /dev/null +++ b/mlir/test/Examples/transform/Ch4/multiple.mlir @@ -0,0 +1,131 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics + +// Matmul+ReLU. +func.func @fc_relu_operands_00( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul # 0}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + // expected-remark @below {{add # 0}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{max # 0}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// Matmul+ReLU with swapped operands. +func.func @fc_relu_operands_01( + %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul # 1}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + // expected-remark @below {{add # 1}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{max # 1}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%c0f, %biased : f32, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.consumed}) { + + // Traverses the payload IR associated with the operand handle, invoking + // @match_matmul_elemwise on each of the operations. If the named sequence + // succeeds, i.e., if none of the nested match (transform) operations + // produced a silenceable failure, invokes @print_matmul_elemwise and + // forwards the values yielded as arguments of the new invocation. If the + // named sequence fails with a silenceable failure, silences it (the message + // is forwarded to the debug stream). Definite failures are propagated + // immediately and unconditionally, as usual. + transform.foreach_match in %root + @match_matmul_elemwise -> @print_matmul_elemwise + : (!transform.any_op) -> !transform.any_op + + transform.yield + } + + // This is an action sequence. + transform.named_sequence @print_matmul_elemwise( + %matmul: !transform.any_op {transform.readonly}, + %add: !transform.any_op {transform.readonly}, + %max: !transform.any_op {transform.readonly}, + %pos: !transform.param<i32> {transform.readonly}) { + transform.test_print_param %pos, "matmul #" at %matmul + : !transform.param<i32>, !transform.any_op + transform.test_print_param %pos, "add #" at %add + : !transform.param<i32>, !transform.any_op + transform.test_print_param %pos, "max #" at %max + : !transform.param<i32>, !transform.any_op + transform.yield + } + + // This is also a matcher sequence. It is similarly given an operation to + // match and nested operations must succeed in order for a match to be deemed + // successful. It starts matching from the last operation in the use-def chain + // and goes back because each operand (use) has exactly one definition. + transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op, + !transform.param<i32>) { + // The last operation must be an elementwise binary. + transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + + // One of its operands must be defined by another operation, to which we + // will get a handle here. This is achieved thanks to a newly defined + // operation that tries to match operands one by one using the match + // operations nested in its region. + %pos, %middle = transform.match.my.has_operand_satisfying %last + : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) { + ^bb0(%operand: !transform.any_value): + // The operand must be defined by an operation. + %def = transform.get_defining_op %operand + : (!transform.any_value) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %def ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %def : !transform.any_op + } + + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. + transform.yield %matmul, %middle, %last, %pos + : !transform.any_op, !transform.any_op, !transform.any_op, + !transform.param<i32> + } +} diff --git a/mlir/test/Examples/transform/Ch4/sequence.mlir b/mlir/test/Examples/transform/Ch4/sequence.mlir new file mode 100644 index 000000000000..28c3e9649bd9 --- /dev/null +++ b/mlir/test/Examples/transform/Ch4/sequence.mlir @@ -0,0 +1,139 @@ +// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics +// +// RUN: transform-opt-ch4 %s \ +// RUN: --transform-interpreter='entry-point=__transform_main_v2' \ +// RUN: --verify-diagnostics + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-remark @below {{matmul}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + // expected-remark @below {{elementwise binary}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise binary}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// The module containing named sequences must have an attribute allowing them +// to enable verification. +module @transforms attributes { transform.with_named_sequence } { + // Entry point. This takes as the only argument the root operation (typically + // pass root) given to the transform interpreter. + transform.named_sequence @__transform_main( + %root: !transform.any_op {transform.readonly}) { + // Collect operations that match the criteria specified in the named + // sequence. If the named sequence fails with a silenceable failure, + // silences it (the message is forwarded to the debug stream). If the named + // sequence succeeds, appends its results to the results of this operation. + %elemwise = transform.collect_matching @match_elemwise in %root + : (!transform.any_op) -> !transform.any_op + %matmul = transform.collect_matching @match_matmul in %root + : (!transform.any_op) -> !transform.any_op + + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield + } + + // Alternative entry point. + transform.named_sequence @__transform_main_v2( + %root: !transform.any_op {transform.readonly}) { + // Collect groups of operations that match the criteria specified in the + // named sequence. + %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op + + transform.include @print_elemwise failures(propagate) (%elemwise) + : (!transform.any_op) -> () + transform.include @print_matmul failures(propagate) (%matmul) + : (!transform.any_op) -> () + + transform.yield + } + + // This is a matcher sequence. It is given an operation to match and the + // match is considered successful unless any nested operation produces a + // failure. The values yielded by this operation will be forwarded to the + // rewriter sequence on success. + transform.named_sequence @match_elemwise( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.elemwise_binary"] + : !transform.any_op + transform.yield %entry : !transform.any_op + } + transform.named_sequence @match_matmul( + %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op + transform.yield %entry : !transform.any_op + } + + // This is an action sequence. + transform.named_sequence @print_elemwise( + %elemwise_binary: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand + %elemwise_binary, "elementwise binary" : !transform.any_op + transform.yield + } + transform.named_sequence @print_matmul( + %matmul: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op + transform.yield + } + + // This is also a matcher sequence. It is similarly given an operation to + // match and nested operations must succeed in order for a match to be deemed + // successful. It starts matching from the last operation in the use-def chain + // and goes back because each operand (use) has exactly one definition. + transform.named_sequence @match_matmul_elemwise( + %last: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op) { + // The last operation must be an elementwise binary. + transform.match.operation_name %last ["linalg.elemwise_binary"] + : !transform.any_op + // Its first operand must be defined by another operation, to which we + // will get a handle here. We are guaranteed that the first operand exists + // because we know the operation is binary, but even in absence of such a + // guarantee, this operation would have produced a silenceable failure when + // `%last` does not have enough operands. + %middle = transform.get_producer_of_operand %last[0] + : (!transform.any_op) -> !transform.any_op + // The defining operation must itself be an elementwise binary. + transform.match.operation_name %middle ["linalg.elemwise_binary"] + : !transform.any_op + // And the first operand of that operation must be defined by yet another + // operation. + %matmul = transform.get_producer_of_operand %middle[0] + : (!transform.any_op) -> !transform.any_op + // And that operation is a matmul. + transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op + // We will yield the handles to the matmul and the two elementwise + // operations separately. + transform.yield %matmul, %middle, %last + : !transform.any_op, !transform.any_op, !transform.any_op + } +} diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 5b92491175e5..0a1ea1d16da4 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -154,8 +154,9 @@ tools.extend( ToolSubst("toyc-ch5", unresolved="ignore"), ToolSubst("toyc-ch6", unresolved="ignore"), ToolSubst("toyc-ch7", unresolved="ignore"), - ToolSubst('transform-opt-ch2', unresolved='ignore'), - ToolSubst('transform-opt-ch3', unresolved='ignore'), + ToolSubst("transform-opt-ch2", unresolved="ignore"), + ToolSubst("transform-opt-ch3", unresolved="ignore"), + ToolSubst("transform-opt-ch4", unresolved="ignore"), ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"), ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"), ] |