summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaksim Levental <maksim.levental@gmail.com>2023-12-21 11:20:29 -0600
committerGitHub <noreply@github.com>2023-12-21 11:20:29 -0600
commit537b2aa264c5a9879a80289c8d123b39e520eb15 (patch)
treefa1fcdb1a320263c6255fdeea9ba58b788e37702
parent11140cc238b8c4124e6f9efacb1601f81da096a0 (diff)
[mlir][python] meta region_op (#75673)
-rw-r--r--mlir/python/CMakeLists.txt9
-rw-r--r--mlir/python/mlir/dialects/arith.py8
-rw-r--r--mlir/python/mlir/dialects/builtin.py23
-rw-r--r--mlir/python/mlir/dialects/func.py3
-rw-r--r--mlir/python/mlir/dialects/pdl.py10
-rw-r--r--mlir/python/mlir/dialects/scf.py2
-rw-r--r--mlir/python/mlir/dialects/tensor.py7
-rw-r--r--mlir/python/mlir/dialects/transform/__init__.py13
-rw-r--r--mlir/python/mlir/dialects/transform/extras/__init__.py15
-rw-r--r--mlir/python/mlir/extras/meta.py83
-rw-r--r--mlir/test/python/dialects/arith_dialect.py6
-rw-r--r--mlir/test/python/dialects/tensor.py35
-rw-r--r--mlir/test/python/dialects/transform_extras.py73
-rw-r--r--mlir/test/python/integration/dialects/transform.py155
14 files changed, 429 insertions, 13 deletions
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e5..3c9cf304d88a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
- extras/types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/_mlir/passmanager.pyi
)
+declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ ADD_TO_PARENT MLIRPythonSources.Core.Python
+ SOURCES
+ extras/types.py
+ extras/meta.py
+)
+
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 83aca0d58bf2..663a53660a64 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -11,6 +11,8 @@ try:
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
+ SubClassValueT as _SubClassValueT,
)
from typing import Any, List, Union
@@ -75,3 +77,9 @@ class ConstantOp(ConstantOp):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
+
+
+def constant(
+ result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+) -> _SubClassValueT:
+ return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index b71cc2466d46..1c69d6d7c3a0 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -2,8 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Dict, Optional
+
from ._builtin_ops_gen import *
from ._builtin_ops_gen import _Dialect
+from ..extras.meta import region_op
try:
from ..ir import *
@@ -23,3 +26,23 @@ class ModuleOp(ModuleOp):
@property
def body(self):
return self.regions[0].blocks[0]
+
+
+@region_op
+def module(
+ *,
+ sym_name=None,
+ sym_visibility=None,
+ attrs: Optional[Dict[str, Attribute]] = None,
+ loc=None,
+ ip=None,
+):
+ mod = ModuleOp.__base__(
+ sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
+ )
+ if attrs is None:
+ attrs = {}
+ for attr_name, attr in attrs.items():
+ mod.operation.attributes[attr_name] = attr
+
+ return mod
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6599f67b7078..24fdcbcd85b2 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -243,6 +243,9 @@ class FuncOp(FuncOp):
return decorator
+func = FuncOp.from_py_func
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
"""Specialization for the call op class."""
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index 90d7d706238e..db07dc50aabd 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -5,6 +5,7 @@
from ._pdl_ops_gen import *
from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
+from .._mlir_libs._mlirDialectsPDL import OperationType
try:
@@ -13,7 +14,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Union, Optional, Sequence, Mapping
+from typing import Union, Optional, Sequence, Mapping, NewType
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
@@ -220,3 +221,10 @@ class TypesOp(TypesOp):
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+
+
+OperationTypeT = NewType("OperationType", OperationType)
+
+
+def op_t() -> OperationTypeT:
+ return OperationTypeT(OperationType.get())
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 20bbed9bc93d..dad7377987e5 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,7 +120,7 @@ def for_(
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
- p = constant(IntegerAttr.get(IndexType.get(), p))
+ p = constant(IndexType.get(), p)
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 67248748eaf3..79dd9476ad0f 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -4,6 +4,7 @@
from ._tensor_ops_gen import *
from ._tensor_ops_gen import _Dialect
+from ..extras.meta import region_op
try:
from ..ir import *
@@ -40,3 +41,9 @@ class EmptyOp(EmptyOp):
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
+
+
+generate = region_op(
+ lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
+ terminator=lambda args: YieldOp(args[0]),
+)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 175634c7d458..5b158ec6b65f 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,7 +18,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, NewType
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -175,7 +175,7 @@ class NamedSequenceOp(NamedSequenceOp):
result_types: Sequence[Type],
sym_visibility=None,
arg_attrs=None,
- res_attrs=None
+ res_attrs=None,
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
@@ -183,7 +183,7 @@ class NamedSequenceOp(NamedSequenceOp):
function_type=TypeAttr.get(function_type),
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
- res_attrs=res_attrs
+ res_attrs=res_attrs,
)
self.regions[0].blocks.append(*input_types)
@@ -212,3 +212,10 @@ class YieldOp(YieldOp):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+
+
+AnyOpTypeT = NewType("AnyOpType", AnyOpType)
+
+
+def any_op_t() -> AnyOpTypeT:
+ return AnyOpTypeT(AnyOpType.get())
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index c715dac1ef7e..e4d47e9064f2 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -4,8 +4,16 @@
from typing import Callable, Optional, Sequence, Union
+from ....extras.meta import region_op
from .... import ir
-from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import (
+ AnyOpType,
+ OperationType,
+ NamedSequenceOp,
+ YieldOp,
+ SequenceOp,
+ ApplyPatternsOp,
+)
from .. import structured
@@ -147,3 +155,8 @@ def insert_transform_script(
if dump_script:
print(named_sequence_op)
+
+
+sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
+named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
+apply_patterns = region_op(ApplyPatternsOp)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
new file mode 100644
index 000000000000..3f2defadf794
--- /dev/null
+++ b/mlir/python/mlir/extras/meta.py
@@ -0,0 +1,83 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import inspect
+from functools import wraps
+
+from ..dialects._ods_common import get_op_result_or_op_results
+from ..ir import Type, InsertionPoint
+
+
+def op_region_builder(op, op_region, terminator=None):
+ def builder_wrapper(body_builder):
+ # Add a block with block args having types determined by type hints on the wrapped function.
+ if len(op_region.blocks) == 0:
+ sig = inspect.signature(body_builder)
+ types = [p.annotation for p in sig.parameters.values()]
+ if not (
+ len(types) == len(sig.parameters)
+ and all(isinstance(t, Type) for t in types)
+ ):
+ raise ValueError(
+ f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
+ )
+
+ op_region.blocks.append(*types)
+
+ with InsertionPoint(op_region.blocks[0]):
+ results = body_builder(*list(op_region.blocks[0].arguments))
+
+ with InsertionPoint(list(op_region.blocks)[-1]):
+ if terminator is not None:
+ res = []
+ if isinstance(results, (tuple, list)):
+ res.extend(results)
+ elif results is not None:
+ res.append(results)
+ terminator(res)
+
+ return get_op_result_or_op_results(op)
+
+ return builder_wrapper
+
+
+def region_op(op_constructor, terminator=None):
+ """Decorator to define an MLIR Op specified as a python function.
+
+ Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+ active for the current thread (i.e. established in a `with` block).
+
+ Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
+
+ When applied as a decorator to a Python function, an entry block will
+ be constructed for the Op with types as specified **as type hints on the args of the function**.
+ The block arguments will be passed positionally to the Python function.
+
+ If a terminator is specified then the return from the decorated function will be passed
+ to the terminator as the last statement in the entry block. Note, the API for the terminator
+ is a (possibly empty) list; terminator accepting single values should be wrapped in a
+ `lambda args: term(args[0])`
+
+ The identifier (name) of the function will become:
+ 1. A single value result if the Op returns a single value;
+ 2. An OpResultList (as a list) if the Op returns multiple values;
+ 3. The Operation if the Op returns no results.
+
+ See examples in tensor.py and transform.extras.
+ """
+
+ def op_decorator(*args, **kwargs):
+ op = op_constructor(*args, **kwargs)
+ op_region = op.regions[0]
+
+ return op_region_builder(op, op_region, terminator)
+
+ @wraps(op_decorator)
+ def maybe_no_args(*args, **kwargs):
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+ return op_decorator()(args[0])
+ else:
+ return op_decorator(*args, **kwargs)
+
+ return maybe_no_args
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f80f2c084a0f..8bb80eed2b81 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -75,7 +75,7 @@ def testArithValue():
f64_t = F64Type.get()
with InsertionPoint(module.body):
- a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+ a = arith.constant(f16_t, 42.42)
# CHECK: ArithValue(%cst = arith.constant 4.240
print(a)
@@ -83,12 +83,12 @@ def testArithValue():
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
print(b)
- a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+ a = arith.constant(f32_t, 42.42)
b = a - a
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
print(b)
- a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+ a = arith.constant(f64_t, 42.42)
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py
index b690c934dc46..ca9066b23911 100644
--- a/mlir/test/python/dialects/tensor.py
+++ b/mlir/test/python/dialects/tensor.py
@@ -4,6 +4,7 @@ from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.tensor as tensor
+from mlir.extras import types as T
def run(f):
@@ -139,3 +140,37 @@ def testFromElementsOp():
t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
# CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
print(t)
+
+
+# CHECK-LABEL: TEST: testGenerateRegionOp
+@run
+def testGenerateRegionOp():
+ S = ShapedType.get_dynamic_size()
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
+ one = arith.constant(T.index(), 1)
+ two = arith.constant(T.index(), 2)
+
+ @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
+ def generate_one(i: T.index(), j: T.index(), k: T.index()):
+ ij = arith.addi(i, j)
+ ijk = arith.addi(ij, k)
+ return ijk
+
+ assert (
+ isinstance(generate_one, Value)
+ and generate_one.owner.name == "tensor.generate"
+ )
+
+ # CHECK: %[[GENERATED:.*]] = tensor.generate
+ # CHECK-SAME: %[[VAL_0]],
+ # CHECK-SAME: %[[VAL_1]] {
+ # CHECK: ^bb0(%[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+ # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index
+ # CHECK: tensor.yield %[[VAL_5]] : index
+ # CHECK: } : tensor<?x3x?xindex>
+ print(module)
diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
index e7b43ea63c31..358f8c32f75c 100644
--- a/mlir/test/python/dialects/transform_extras.py
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -2,9 +2,34 @@
from typing import Callable
from mlir import ir
-from mlir.dialects import scf
-from mlir.dialects.transform import structured
-from mlir.dialects.transform.extras import OpHandle, insert_transform_script
+from mlir.dialects import scf, pdl
+from mlir.dialects.transform import (
+ structured,
+ get_parent_op,
+ apply_patterns_canonicalization,
+ apply_cse,
+ any_op_t,
+)
+from mlir.dialects.transform import FailurePropagationMode
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import (
+ OpHandle,
+ insert_transform_script,
+ sequence,
+ apply_patterns,
+)
+from mlir.extras import types as T
+
+
+def construct_and_print_in_module(f):
+ print("\nTEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
def build_transform_script(script: Callable[[OpHandle], None]):
@@ -93,3 +118,45 @@ def test_match_ops_mixed(op: OpHandle):
# CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
# CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
# CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+@construct_and_print_in_module
+def test_sequence_region():
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+ # CHECK: }
+ @sequence([], FailurePropagationMode.Propagate, [])
+ def basic(target: any_op_t()):
+ m = structured_match(any_op_t(), target, ops=["arith.addi"])
+ loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+ loop_unroll(loop, 4)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns():
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: apply_patterns to %[[VAL_2]] {
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: } : !pdl.operation
+ # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
+ # CHECK: }
+ @sequence([], FailurePropagationMode.Propagate, [])
+ def basic(variant_op: any_op_t()):
+ matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+ top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+ @apply_patterns(top_func)
+ def pats():
+ apply_patterns_canonicalization()
+
+ top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+ apply_cse(top_func)
diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py
new file mode 100644
index 000000000000..bc88a61314d0
--- /dev/null
+++ b/mlir/test/python/integration/dialects/transform.py
@@ -0,0 +1,155 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.passmanager import PassManager
+from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+ get_parent_op,
+ apply_patterns_canonicalization,
+ apply_cse,
+ any_op_t,
+)
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import named_sequence, apply_patterns
+from mlir.extras import types as T
+from mlir.dialects.builtin import module, ModuleOp
+
+
+def construct_and_print_in_module(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ module = f(module)
+ if module is not None:
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: test_named_sequence
+@construct_and_print_in_module
+def test_named_sequence(module_):
+ # CHECK-LABEL: func.func @loop_unroll_op() {
+ # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+ # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.func()
+ def loop_unroll_op():
+ for i in scf.for_(0, 42, 5):
+ v = arith.addi(i, i)
+ scf.yield_([])
+
+ # CHECK-LABEL: module attributes {transform.with_named_sequence} {
+ # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+ # CHECK: transform.yield
+ # CHECK: }
+ # CHECK: }
+ @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+ def mod():
+ @named_sequence("__transform_main", [any_op_t()], [])
+ def basic(target: any_op_t()):
+ m = structured_match(any_op_t(), target, ops=["arith.addi"])
+ loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+ loop_unroll(loop, 4)
+
+ # The identifier (name) of the function becomes the Operation
+ assert isinstance(mod.opview, ModuleOp)
+
+ print(module_)
+
+ pm = PassManager.parse("builtin.module(transform-interpreter)")
+ pm.run(module_.operation)
+
+ # CHECK-LABEL: func.func @loop_unroll_op() {
+ # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 40 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 20 : index
+ # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
+ # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
+ # CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
+ # CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
+ # CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
+ # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
+ # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+ # CHECK: %[[VAL_16:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
+ # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
+ # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
+ # CHECK: }
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+ # CHECK: return
+ # CHECK: }
+ print(module_)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns(module_):
+ M, N, K = 3, 5, 3
+
+ # CHECK-LABEL: func.func @matmul(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
+ # CHECK: %[[VAL_5:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+ # CHECK: return %[[VAL_5]] : tensor<3x3xf32>
+ # CHECK: }
+ @func.func(
+ T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+ )
+ def matmul(A, B, C):
+ i = arith.constant(T.i32(), 1)
+ v = arith.addi(i, i)
+ return linalg.matmul(A, B, outs=[C])
+
+ # CHECK-LABEL: module attributes {transform.with_named_sequence} {
+ # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+ # CHECK: transform.apply_patterns to %[[VAL_2]] {
+ # CHECK: transform.apply_patterns.canonicalization
+ # CHECK: } : !pdl.operation
+ # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: transform.apply_cse to %[[VAL_3]] : !transform.any_op
+ # CHECK: transform.yield
+ # CHECK: }
+ # CHECK: }
+ @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+ def mod():
+ @named_sequence("__transform_main", [any_op_t()], [])
+ def basic(variant_op: any_op_t()):
+ matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+ top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+ @apply_patterns(top_func)
+ def pats():
+ apply_patterns_canonicalization()
+
+ top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+ apply_cse(top_func)
+
+ print(module_)
+
+ pm = PassManager.parse("builtin.module(transform-interpreter)")
+ pm.run(module_.operation)
+
+ # CHECK-LABEL: func.func @matmul(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ # CHECK: %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+ # CHECK: return %[[VAL_3]] : tensor<3x3xf32>
+ # CHECK: }
+ print(module_)