author    | Maksim Levental <maksim.levental@gmail.com> | 2023-12-21 11:20:29 -0600
committer | GitHub <noreply@github.com>                 | 2023-12-21 11:20:29 -0600
commit    | 537b2aa264c5a9879a80289c8d123b39e520eb15 (patch)
tree      | fa1fcdb1a320263c6255fdeea9ba58b788e37702
parent    | 11140cc238b8c4124e6f9efacb1601f81da096a0 (diff)
[mlir][python] meta region_op (#75673)
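In brief, per the diff below: this adds a `region_op` meta-decorator (new file `mlir/python/mlir/extras/meta.py`) that builds a region-carrying op from a decorated Python function, and applies it to provide `builtin.module`, `tensor.generate`, and the transform helpers `sequence`, `named_sequence`, and `apply_patterns`, plus a value-returning `arith.constant` wrapper. A usage sketch adapted from the tests added in this commit (it assumes an active `Context`, `Location`, and `InsertionPoint`, as `region_op`'s docstring requires):

    from mlir.ir import Context, InsertionPoint, Location, Module, ShapedType
    from mlir.dialects import arith, tensor
    from mlir.extras import types as T

    S = ShapedType.get_dynamic_size()
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            one = arith.constant(T.index(), 1)  # new helper: returns a Value
            two = arith.constant(T.index(), 2)

            # Entry-block argument types come from the type hints; the return
            # value is handed to the terminator (tensor.yield).
            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
            def generate_one(i: T.index(), j: T.index(), k: T.index()):
                return arith.addi(arith.addi(i, j), k)

        print(module)  # prints the module containing the tensor.generate op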
-rw-r--r-- | mlir/python/CMakeLists.txt | 9
-rw-r--r-- | mlir/python/mlir/dialects/arith.py | 8
-rw-r--r-- | mlir/python/mlir/dialects/builtin.py | 23
-rw-r--r-- | mlir/python/mlir/dialects/func.py | 3
-rw-r--r-- | mlir/python/mlir/dialects/pdl.py | 10
-rw-r--r-- | mlir/python/mlir/dialects/scf.py | 2
-rw-r--r-- | mlir/python/mlir/dialects/tensor.py | 7
-rw-r--r-- | mlir/python/mlir/dialects/transform/__init__.py | 13
-rw-r--r-- | mlir/python/mlir/dialects/transform/extras/__init__.py | 15
-rw-r--r-- | mlir/python/mlir/extras/meta.py | 83
-rw-r--r-- | mlir/test/python/dialects/arith_dialect.py | 6
-rw-r--r-- | mlir/test/python/dialects/tensor.py | 35
-rw-r--r-- | mlir/test/python/dialects/transform_extras.py | 73
-rw-r--r-- | mlir/test/python/integration/dialects/transform.py | 155
14 files changed, 429 insertions, 13 deletions
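The core of the change is `region_op` (full listing in `mlir/python/mlir/extras/meta.py` below): it wraps an op constructor so that decorating a Python function (1) appends an entry block to the op's first region, with block-argument types taken from the function's type hints, (2) runs the function at an insertion point inside that block, and (3) passes the function's return values to the optional `terminator` callback; the decorated name is then bound to the op's result(s), or to the operation itself if it has none. A sketch of the parenthesized decorator form, adapted from the new `transform_extras` test ("naked", parenthesis-free application also works when the constructor needs no arguments):

    from mlir.dialects import pdl
    from mlir.dialects.transform import FailurePropagationMode, any_op_t, get_parent_op
    from mlir.dialects.transform.extras import sequence
    from mlir.dialects.transform.loop import loop_unroll
    from mlir.dialects.transform.structured import structured_match

    # Decorator arguments are forwarded to SequenceOp's constructor; `target`
    # becomes the sequence's block argument, typed by its annotation.
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        m = structured_match(any_op_t(), target, ops=["arith.addi"])
        loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
        loop_unroll(loop, 4)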
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55c5973e40e5..3c9cf304d88a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
-    extras/types.py
     dialects/_ods_common.py

     # The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/_mlir/passmanager.pyi
 )

+declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources.Core.Python
+  SOURCES
+    extras/types.py
+    extras/meta.py
+)
+
 declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   ADD_TO_PARENT MLIRPythonSources
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 83aca0d58bf2..663a53660a64 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -11,6 +11,8 @@ try:
     from ._ods_common import (
         get_default_loc_context as _get_default_loc_context,
         _cext as _ods_cext,
+        get_op_result_or_op_results as _get_op_result_or_op_results,
+        SubClassValueT as _SubClassValueT,
     )

     from typing import Any, List, Union
@@ -75,3 +77,9 @@ class ConstantOp(ConstantOp):
             return FloatAttr(self.value).value
         else:
             raise ValueError("only integer and float constants have literal values")
+
+
+def constant(
+    result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+) -> _SubClassValueT:
+    return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index b71cc2466d46..1c69d6d7c3a0 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -2,8 +2,11 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+from typing import Dict, Optional
+
 from ._builtin_ops_gen import *
 from ._builtin_ops_gen import _Dialect
+from ..extras.meta import region_op

 try:
     from ..ir import *
@@ -23,3 +26,23 @@ class ModuleOp(ModuleOp):
     @property
     def body(self):
         return self.regions[0].blocks[0]
+
+
+@region_op
+def module(
+    *,
+    sym_name=None,
+    sym_visibility=None,
+    attrs: Optional[Dict[str, Attribute]] = None,
+    loc=None,
+    ip=None,
+):
+    mod = ModuleOp.__base__(
+        sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
+    )
+    if attrs is None:
+        attrs = {}
+    for attr_name, attr in attrs.items():
+        mod.operation.attributes[attr_name] = attr
+
+    return mod
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6599f67b7078..24fdcbcd85b2 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -243,6 +243,9 @@ class FuncOp(FuncOp):
         return decorator


+func = FuncOp.from_py_func
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class CallOp(CallOp):
     """Specialization for the call op class."""
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index 90d7d706238e..db07dc50aabd 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -5,6 +5,7 @@
 from ._pdl_ops_gen import *
 from ._pdl_ops_gen import _Dialect
 from .._mlir_libs._mlirDialectsPDL import *
+from .._mlir_libs._mlirDialectsPDL import OperationType


 try:
@@ -13,7 +14,7 @@ try:
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e

-from typing import Union, Optional, Sequence, Mapping
+from typing import Union, Optional, Sequence, Mapping, NewType
 from ._ods_common import (
     get_op_result_or_value as _get_value,
     get_op_results_or_values as _get_values,
@@ -220,3 +221,10 @@ class TypesOp(TypesOp):
             constantTypes = []
         result = pdl.RangeType.get(pdl.TypeType.get())
         super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+
+
+OperationTypeT = NewType("OperationType", OperationType)
+
+
+def op_t() -> OperationTypeT:
+    return OperationTypeT(OperationType.get())
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 20bbed9bc93d..dad7377987e5 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,7 +120,7 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(IntegerAttr.get(IndexType.get(), p))
+            p = constant(IndexType.get(), p)
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 67248748eaf3..79dd9476ad0f 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -4,6 +4,7 @@

 from ._tensor_ops_gen import *
 from ._tensor_ops_gen import _Dialect
+from ..extras.meta import region_op

 try:
     from ..ir import *
@@ -40,3 +41,9 @@ class EmptyOp(EmptyOp):
                 dynamic_sizes.append(s)
         result_type = RankedTensorType.get(static_sizes, element_type)
         super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
+
+
+generate = region_op(
+    lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
+    terminator=lambda args: YieldOp(args[0]),
+)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 175634c7d458..5b158ec6b65f 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,7 +18,7 @@ try:
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e

-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, NewType


 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -175,7 +175,7 @@ class NamedSequenceOp(NamedSequenceOp):
         result_types: Sequence[Type],
         sym_visibility=None,
         arg_attrs=None,
-        res_attrs=None
+        res_attrs=None,
     ):
         function_type = FunctionType.get(input_types, result_types)
         super().__init__(
@@ -183,7 +183,7 @@ class NamedSequenceOp(NamedSequenceOp):
             function_type=TypeAttr.get(function_type),
             sym_visibility=sym_visibility,
             arg_attrs=arg_attrs,
-            res_attrs=res_attrs
+            res_attrs=res_attrs,
         )
         self.regions[0].blocks.append(*input_types)
@@ -212,3 +212,10 @@ class YieldOp(YieldOp):
         if operands is None:
             operands = []
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+
+
+AnyOpTypeT = NewType("AnyOpType", AnyOpType)
+
+
+def any_op_t() -> AnyOpTypeT:
+    return AnyOpTypeT(AnyOpType.get())
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index c715dac1ef7e..e4d47e9064f2 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -4,8 +4,16 @@

 from typing import Callable, Optional, Sequence, Union

+from ....extras.meta import region_op
 from .... import ir
-from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
+from .. import (
+    AnyOpType,
+    OperationType,
+    NamedSequenceOp,
+    YieldOp,
+    SequenceOp,
+    ApplyPatternsOp,
+)
 from .. import structured
@@ -147,3 +155,8 @@ def insert_transform_script(

     if dump_script:
         print(named_sequence_op)
+
+
+sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
+named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
+apply_patterns = region_op(ApplyPatternsOp)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
new file mode 100644
index 000000000000..3f2defadf794
--- /dev/null
+++ b/mlir/python/mlir/extras/meta.py
@@ -0,0 +1,83 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import inspect
+from functools import wraps
+
+from ..dialects._ods_common import get_op_result_or_op_results
+from ..ir import Type, InsertionPoint
+
+
+def op_region_builder(op, op_region, terminator=None):
+    def builder_wrapper(body_builder):
+        # Add a block with block args having types determined by type hints on the wrapped function.
+        if len(op_region.blocks) == 0:
+            sig = inspect.signature(body_builder)
+            types = [p.annotation for p in sig.parameters.values()]
+            if not (
+                len(types) == len(sig.parameters)
+                and all(isinstance(t, Type) for t in types)
+            ):
+                raise ValueError(
+                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
+                )
+
+            op_region.blocks.append(*types)
+
+        with InsertionPoint(op_region.blocks[0]):
+            results = body_builder(*list(op_region.blocks[0].arguments))
+
+        with InsertionPoint(list(op_region.blocks)[-1]):
+            if terminator is not None:
+                res = []
+                if isinstance(results, (tuple, list)):
+                    res.extend(results)
+                elif results is not None:
+                    res.append(results)
+                terminator(res)
+
+        return get_op_result_or_op_results(op)
+
+    return builder_wrapper
+
+
+def region_op(op_constructor, terminator=None):
+    """Decorator to define an MLIR Op specified as a python function.
+
+    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+    active for the current thread (i.e. established in a `with` block).
+
+    Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
+
+    When applied as a decorator to a Python function, an entry block will
+    be constructed for the Op with types as specified **as type hints on the args of the function**.
+    The block arguments will be passed positionally to the Python function.
+
+    If a terminator is specified then the return from the decorated function will be passed
+    to the terminator as the last statement in the entry block. Note, the API for the terminator
+    is a (possibly empty) list; terminator accepting single values should be wrapped in a
+    `lambda args: term(args[0])`
+
+    The identifier (name) of the function will become:
+    1. A single value result if the Op returns a single value;
+    2. An OpResultList (as a list) if the Op returns multiple values;
+    3. The Operation if the Op returns no results.
+
+    See examples in tensor.py and transform.extras.
+ """ + + def op_decorator(*args, **kwargs): + op = op_constructor(*args, **kwargs) + op_region = op.regions[0] + + return op_region_builder(op, op_region, terminator) + + @wraps(op_decorator) + def maybe_no_args(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + return op_decorator()(args[0]) + else: + return op_decorator(*args, **kwargs) + + return maybe_no_args diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index f80f2c084a0f..8bb80eed2b81 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -75,7 +75,7 @@ def testArithValue(): f64_t = F64Type.get() with InsertionPoint(module.body): - a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) + a = arith.constant(f16_t, 42.42) # CHECK: ArithValue(%cst = arith.constant 4.240 print(a) @@ -83,12 +83,12 @@ def testArithValue(): # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) print(b) - a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) + a = arith.constant(f32_t, 42.42) b = a - a # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32) print(b) - a = arith.constant(value=FloatAttr.get(f64_t, 42.42)) + a = arith.constant(f64_t, 42.42) b = a * a # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) print(b) diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py index b690c934dc46..ca9066b23911 100644 --- a/mlir/test/python/dialects/tensor.py +++ b/mlir/test/python/dialects/tensor.py @@ -4,6 +4,7 @@ from mlir.ir import * import mlir.dialects.arith as arith import mlir.dialects.func as func import mlir.dialects.tensor as tensor +from mlir.extras import types as T def run(f): @@ -139,3 +140,37 @@ def testFromElementsOp(): t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1]) # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32> print(t) + + +# CHECK-LABEL: TEST: testGenerateRegionOp +@run +def testGenerateRegionOp(): + S = ShapedType.get_dynamic_size() + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_1:.*]] = arith.constant 2 : index + one = arith.constant(T.index(), 1) + two = arith.constant(T.index(), 2) + + @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two]) + def generate_one(i: T.index(), j: T.index(), k: T.index()): + ij = arith.addi(i, j) + ijk = arith.addi(ij, k) + return ijk + + assert ( + isinstance(generate_one, Value) + and generate_one.owner.name == "tensor.generate" + ) + + # CHECK: %[[GENERATED:.*]] = tensor.generate + # CHECK-SAME: %[[VAL_0]], + # CHECK-SAME: %[[VAL_1]] { + # CHECK: ^bb0(%[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index): + # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index + # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index + # CHECK: tensor.yield %[[VAL_5]] : index + # CHECK: } : tensor<?x3x?xindex> + print(module) diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py index e7b43ea63c31..358f8c32f75c 100644 --- a/mlir/test/python/dialects/transform_extras.py +++ b/mlir/test/python/dialects/transform_extras.py @@ -2,9 +2,34 @@ from typing import Callable from mlir import ir -from mlir.dialects import scf -from mlir.dialects.transform import structured -from mlir.dialects.transform.extras import OpHandle, insert_transform_script +from 
+from mlir.dialects.transform import (
+    structured,
+    get_parent_op,
+    apply_patterns_canonicalization,
+    apply_cse,
+    any_op_t,
+)
+from mlir.dialects.transform import FailurePropagationMode
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import (
+    OpHandle,
+    insert_transform_script,
+    sequence,
+    apply_patterns,
+)
+from mlir.extras import types as T
+
+
+def construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            f()
+        print(module)
+    return f


 def build_transform_script(script: Callable[[OpHandle], None]):
@@ -93,3 +118,45 @@ def test_match_ops_mixed(op: OpHandle):
     # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
     # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
     # CHECK-SAME: -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_sequence_region
+@construct_and_print_in_module
+def test_sequence_region():
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+    # CHECK: }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(target: any_op_t()):
+        m = structured_match(any_op_t(), target, ops=["arith.addi"])
+        loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+        loop_unroll(loop, 4)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns():
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
+    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+    # CHECK: apply_patterns to %[[VAL_2]] {
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: } : !pdl.operation
+    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
+    # CHECK: }
+    @sequence([], FailurePropagationMode.Propagate, [])
+    def basic(variant_op: any_op_t()):
+        matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+        top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+        @apply_patterns(top_func)
+        def pats():
+            apply_patterns_canonicalization()
+
+        top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+        apply_cse(top_func)
diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py
new file mode 100644
index 000000000000..bc88a61314d0
--- /dev/null
+++ b/mlir/test/python/integration/dialects/transform.py
@@ -0,0 +1,155 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.passmanager import PassManager
+from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
+from mlir.dialects import scf, pdl, func, arith, linalg
+from mlir.dialects.transform import (
+    get_parent_op,
+    apply_patterns_canonicalization,
+    apply_cse,
+    any_op_t,
+)
+from mlir.dialects.transform.structured import structured_match
+from mlir.dialects.transform.loop import loop_unroll
+from mlir.dialects.transform.extras import named_sequence, apply_patterns
+from mlir.extras import types as T
+from mlir.dialects.builtin import module, ModuleOp
+
+
+def construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            module = f(module)
+        if module is not None:
+            print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: test_named_sequence
+@construct_and_print_in_module
+def test_named_sequence(module_):
+    # CHECK-LABEL: func.func @loop_unroll_op() {
+    # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+    # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+    # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+    # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK: }
+    # CHECK: return
+    # CHECK: }
+    @func.func()
+    def loop_unroll_op():
+        for i in scf.for_(0, 42, 5):
+            v = arith.addi(i, i)
+            scf.yield_([])
+
+    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
+    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
+    # CHECK: transform.yield
+    # CHECK: }
+    # CHECK: }
+    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+    def mod():
+        @named_sequence("__transform_main", [any_op_t()], [])
+        def basic(target: any_op_t()):
+            m = structured_match(any_op_t(), target, ops=["arith.addi"])
+            loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
+            loop_unroll(loop, 4)
+
+    # The identifier (name) of the function becomes the Operation
+    assert isinstance(mod.opview, ModuleOp)
+
+    print(module_)
+
+    pm = PassManager.parse("builtin.module(transform-interpreter)")
+    pm.run(module_.operation)
+
+    # CHECK-LABEL: func.func @loop_unroll_op() {
+    # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+    # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+    # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+    # CHECK: %[[VAL_6:.*]] = arith.constant 40 : index
+    # CHECK: %[[VAL_7:.*]] = arith.constant 20 : index
+    # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
+    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
+    # CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
+    # CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
+    # CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
+    # CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
+    # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
+    # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+    # CHECK: %[[VAL_16:.*]] = arith.constant 3 : index
+    # CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
+    # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
+    # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
+    # CHECK: }
+    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK: return
+    # CHECK: }
+    print(module_)
+
+
+# CHECK-LABEL: TEST: test_apply_patterns
+@construct_and_print_in_module
+def test_apply_patterns(module_):
+    M, N, K = 3, 5, 3
+
+    # CHECK-LABEL: func.func @matmul(
+    # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
+    # CHECK: %[[VAL_5:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK: return %[[VAL_5]] : tensor<3x3xf32>
+    # CHECK: }
+    @func.func(
+        T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+    )
+    def matmul(A, B, C):
+        i = arith.constant(T.i32(), 1)
+        v = arith.addi(i, i)
+        return linalg.matmul(A, B, outs=[C])
+
+    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
+    # CHECK: transform.apply_patterns to %[[VAL_2]] {
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: } : !pdl.operation
+    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+    # CHECK: transform.apply_cse to %[[VAL_3]] : !transform.any_op
+    # CHECK: transform.yield
+    # CHECK: }
+    # CHECK: }
+    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+    def mod():
+        @named_sequence("__transform_main", [any_op_t()], [])
+        def basic(variant_op: any_op_t()):
+            matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+            top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
+
+            @apply_patterns(top_func)
+            def pats():
+                apply_patterns_canonicalization()
+
+            top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
+            apply_cse(top_func)
+
+    print(module_)
+
+    pm = PassManager.parse("builtin.module(transform-interpreter)")
+    pm.run(module_.operation)
+
+    # CHECK-LABEL: func.func @matmul(
+    # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+    # CHECK: %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+    # CHECK: return %[[VAL_3]] : tensor<3x3xf32>
+    # CHECK: }
+    print(module_)