summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/docs/Interfaces.md24
-rw-r--r--mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h11
-rw-r--r--mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h11
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h9
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td61
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td10
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td4
-rw-r--r--mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h119
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp12
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp88
-rw-r--r--mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp8
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp2
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp32
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Padding.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp6
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp5
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp26
-rw-r--r--mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp17
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp36
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Utils/Utils.cpp4
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp338
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir54
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir12
-rw-r--r--mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir24
-rw-r--r--mlir/test/Dialect/ArmSME/vector-legalization.mlir11
-rw-r--r--mlir/test/Dialect/OpenMP/invalid.mlir126
-rw-r--r--mlir/test/Dialect/OpenMP/ops.mlir260
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir8
-rw-r--r--mlir/test/Dialect/Vector/invalid.mlir7
-rw-r--r--mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir112
-rw-r--r--mlir/test/Target/LLVMIR/Import/intrinsic.ll2
-rw-r--r--mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp37
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td16
39 files changed, 1014 insertions, 552 deletions
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 536e7613e509..51747db546bb 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -299,6 +299,30 @@ owner of the dialect containing the object nor the owner of the interface are
aware of an interface implementation, which can lead to duplicate or
diverging implementations.
+Forgetting to register an external model can lead to bugs which are hard to
+track down. The `declarePromisedInterface` function can be used to declare that
+an external model implementation for an operation must eventually be provided.
+
+```
+ void MyDialect::initialize() {
+ declarePromisedInterface<SomeInterface, SomeOp>();
+ ...
+ }
+```
+
+Now attempting to use the interface, e.g in a cast, without a prior registration
+of the external model will lead to a runtime error that will look similar to
+this:
+
+```
+LLVM ERROR: checking for an interface (`SomeInterface`) that was promised by dialect 'mydialect' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered.
+```
+
+If you encounter this error for a dialect and an interface provided by MLIR, you
+may look for a method that will be named like
+`register<Dialect><Interface>ExternalModels(DialectRegistry &registry);` ; try
+to find it with `git grep 'register.*SomeInterface.*Model' mlir`.
+
#### Dialect Fallback for OpInterface
Some dialects have an open ecosystem and don't register all of the possible
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 8e840e744064..1ea737522081 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition,
+ bool closedUB = false);
+
/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
index 970a52a06a11..bbc7e5d3e0dd 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -24,6 +24,17 @@ enum class BoundType;
namespace arith {
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition,
+ bool closedUB = false);
+
/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 304a9740d91e..27a766aceb31 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -284,11 +284,10 @@ using TaskgroupClauseOps =
detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
using TaskloopClauseOps =
- detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps,
- GrainsizeClauseOps, IfClauseOps, InReductionClauseOps,
- LoopRelatedOps, MergeableClauseOps, NogroupClauseOps,
- NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps,
- ReductionClauseOps, UntiedClauseOps>;
+ detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
+ IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+ NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
+ PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 3abdbe3adfd0..82be7ad31a15 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -840,7 +840,8 @@ def YieldOp : OpenMP_Op<"yield",
//===----------------------------------------------------------------------===//
def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
- RecursiveMemoryEffects]> {
+ RecursiveMemoryEffects,
+ SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "distribute construct";
let description = [{
The distribute construct specifies that the iterations of one or more loops
@@ -855,15 +856,28 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
The distribute loop construct specifies that the iterations of the loop(s)
will be executed in parallel by threads in the current context. These
iterations are spread across threads that already exist in the enclosing
- region. The lower and upper bounds specify a half-open range: the
- range includes the lower bound but does not include the upper bound. If the
- `inclusive` attribute is specified then the upper bound is also included.
+ region.
+
+ The body region can contain a single block which must contain a single
+ operation and a terminator. The operation must be another compatible loop
+ wrapper or an `omp.loop_nest`.
The `dist_schedule_static` attribute specifies the schedule for this
loop, determining how the loop is distributed across the parallel threads.
The optional `schedule_chunk` associated with this determines further
controls this distribution.
+ ```mlir
+ omp.distribute <clauses> {
+ omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+ %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+ %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+ %sum = arith.addf %a, %b : f32
+ store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+ omp.yield
+ }
+ }
+ ```
// TODO: private_var, firstprivate_var, lastprivate_var, collapse
}];
let arguments = (ins
@@ -1016,10 +1030,10 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
}
def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
- AutomaticAllocationScope, RecursiveMemoryEffects,
- AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
- ReductionClauseInterface]> {
+ RecursiveMemoryEffects, ReductionClauseInterface,
+ SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "taskloop construct";
let description = [{
The taskloop construct specifies that the iterations of one or more
@@ -1027,21 +1041,19 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
iterations are distributed across tasks generated by the construct and
scheduled to be executed.
- The `lowerBound` and `upperBound` specify a half-open range: the range
- includes the lower bound but does not include the upper bound. If the
- `inclusive` attribute is specified then the upper bound is also included.
- The `step` specifies the loop step.
-
- The body region can contain any number of blocks.
+ The body region can contain a single block which must contain a single
+ operation and a terminator. The operation must be another compatible loop
+ wrapper or an `omp.loop_nest`.
```
- omp.taskloop <clauses>
- for (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
- %a = load %arrA[%i1, %i2] : memref<?x?xf32>
- %b = load %arrB[%i1, %i2] : memref<?x?xf32>
- %sum = arith.addf %a, %b : f32
- store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
- omp.terminator
+ omp.taskloop <clauses> {
+ omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+ %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+ %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+ %sum = arith.addf %a, %b : f32
+ store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+ omp.yield
+ }
}
```
@@ -1118,11 +1130,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
created.
}];
- let arguments = (ins Variadic<IntLikeType>:$lowerBound,
- Variadic<IntLikeType>:$upperBound,
- Variadic<IntLikeType>:$step,
- UnitAttr:$inclusive,
- Optional<I1>:$if_expr,
+ let arguments = (ins Optional<I1>:$if_expr,
Optional<I1>:$final_expr,
UnitAttr:$untied,
UnitAttr:$mergeable,
@@ -1165,8 +1173,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
|`grain_size` `(` $grain_size `:` type($grain_size) `)`
|`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
|`nogroup` $nogroup
- ) `for` custom<LoopControl>($region, $lowerBound, $upperBound, $step,
- type($step), $inclusive) attr-dict
+ ) $region attr-dict
}];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cff3de0a69af..3687891fe4b7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -130,11 +130,11 @@ def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>]>;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 147bc2354977..332b5ad08ced 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -420,7 +420,7 @@ def Vector_ShuffleOp :
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
- Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
+ Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
@@ -444,6 +444,8 @@ def Vector_ShuffleOp :
mask values must be within range, viz. given two k-D operands v1 and v2
above, all mask values are in the range [0,s_1+t_1)
+ Note, scalable vectors are not supported.
+
Example:
```mlir
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961c..ac17ace5a976 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include <queue>
@@ -111,6 +112,39 @@ protected:
public:
static char ID;
+ /// A variable that can be added to the constraint set as a "column". The
+ /// value bounds infrastructure can compute bounds for variables and compare
+ /// two variables.
+ ///
+ /// Internally, a variable is represented as an affine map and operands.
+ class Variable {
+ public:
+ /// Construct a variable for an index-typed attribute or SSA value.
+ Variable(OpFoldResult ofr);
+
+ /// Construct a variable for an index-typed SSA value.
+ Variable(Value indexValue);
+
+ /// Construct a variable for a dimension of a shaped value.
+ Variable(Value shapedValue, int64_t dim);
+
+ /// Construct a variable for an index-typed attribute/SSA value or for a
+ /// dimension of a shaped value. A non-null dimension must be provided if
+ /// and only if `ofr` is a shaped value.
+ Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+ /// Construct a variable for a map and its operands.
+ Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+ Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+ MLIRContext *getContext() const { return map.getContext(); }
+
+ private:
+ friend class ValueBoundsConstraintSet;
+ AffineMap map;
+ ValueDimList mapOperands;
+ };
+
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ public:
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
- /// Compute a bound for the given index-typed value or shape dimension size.
- /// The computed bound is stored in `resultMap`. The operands of the bound are
- /// stored in `mapOperands`. An operand is either an index-type SSA value
- /// or a shaped value and a dimension.
+ /// Compute a bound for the given variable. The computed bound is stored in
+ /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+ /// operand is either an index-type SSA value or a shaped value and a
+ /// dimension.
///
- /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
- /// is computed in terms of values/dimensions for which `stopCondition`
- /// evaluates to "true". To that end, the backward slice (reverse use-def
- /// chain) of the given value is visited in a worklist-driven manner and the
- /// constraint set is populated according to `ValueBoundsOpInterface` for each
- /// visited value.
+ /// The bound is computed in terms of values/dimensions for which
+ /// `stopCondition` evaluates to "true". To that end, the backward slice
+ /// (reverse use-def chain) of the given value is visited in a worklist-driven
+ /// manner and the constraint set is populated according to
+ /// `ValueBoundsOpInterface` for each visited value.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
- static LogicalResult computeBound(AffineMap &resultMap,
- ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- StopConditionFn stopCondition,
- bool closedUB = false);
+ static LogicalResult
+ computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+ presburger::BoundType type, const Variable &var,
+ StopConditionFn stopCondition, bool closedUB = false);
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueDimList dependencies, bool closedUB = false);
/// Compute a bound in that is independent of all values in `independencies`.
///
@@ -161,13 +191,10 @@ public:
/// appear in the computed bound.
static LogicalResult
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueRange independencies, bool closedUB = false);
- /// Compute a constant bound for the given affine map, where dims and symbols
- /// are bound to the given operands. The affine map must have exactly one
- /// result.
+ /// Compute a constant bound for the given variable.
///
/// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ public:
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
- computeConstantBound(presburger::BoundType type, Value value,
- std::optional<int64_t> dim = std::nullopt,
+ computeConstantBound(presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
@@ -221,9 +241,8 @@ public:
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
+ bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
+ const Variable &rhs);
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
@@ -233,24 +252,12 @@ public:
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
- static bool compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands);
- static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands);
-
- /// Compute whether the given values/dimensions are equal. Return "failure" if
+ static bool compare(const Variable &lhs, ComparisonOperator cmp,
+ const Variable &rhs);
+
+ /// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
- ///
- /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
- /// index-typed.
- static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1 = std::nullopt,
- std::optional<int64_t> dim2 = std::nullopt);
+ static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +324,6 @@ protected:
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
- bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +378,7 @@ protected:
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+ int64_t insert(const Variable &var, bool isSymbol = true);
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
@@ -381,6 +386,8 @@ protected:
/// Project out all columns for which the condition holds.
void projectOut(function_ref<bool(ValueDim)> condition);
+ void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
/// Mapping of columns to values/shape dimensions.
SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412..d8dd1c93722b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+ // It is not safe to cache constants across regions.
+ // New constants could potentially violate dominance requirements.
+ IndexPool localPool;
+
// Emit 'tensor.empty' op
SmallVector<OpFoldResult> outputTensorShape;
for (auto index : llvm::seq<int64_t>(0, rank)) {
auto size = index == dim ? targetSize
- : getOrFoldTensorDim(rewriter, loc, indexPool,
+ : getOrFoldTensorDim(rewriter, loc, localPool,
operand, index);
outputTensorShape.push_back(size);
}
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) {
- size_t rank = operand.getType().cast<RankedTensorType>().getRank();
- assert(targetShape.size() == rank);
- assert(masterOperands.size() == rank);
+ int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+ assert((int64_t)targetShape.size() == rank);
+ assert((int64_t)masterOperands.size() == rank);
for (auto index : llvm::seq<int64_t>(0, rank))
operand =
broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3f39cbf03a9a..8fb8d1648656 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
#include <numeric>
#include <type_traits>
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
TypedAttr padAttr, OpBuilder &rewriter) {
- // Input should be padded if necessary.
+ // Input should be padded only if necessary.
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
SmallVector<int64_t, 4> paddedShape;
SmallVector<OpFoldResult, 8> lowIndices;
SmallVector<OpFoldResult, 8> highIndices;
- for (int i = 0, s = inputShape.size(); i < s; i++) {
+ for (size_t i : llvm::seq(inputShape.size())) {
auto lowPad = pad[i * 2];
auto highPad = pad[i * 2 + 1];
if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
- return builder.createOrFold<arith::IndexCastOp>(
- builder.getIndexType(),
- builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+ return builder.create<arith::ConstantIndexOp>(attr);
}
// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
- int64_t padBeforeAttr, int64_t padAfterAttr,
- Value kernelDim, int64_t strideAttr,
- int64_t dilationAttr, Type inputETy,
- OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+ int64_t padBeforeAttr,
+ int64_t padAfterAttr, Value kernelDim,
+ int64_t strideAttr,
+ int64_t dilationAttr,
+ OpBuilder &rewriter) {
ImplicitLocOpBuilder builder(loc, rewriter);
auto one = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
ShapedType inputTy = cast<ShapedType>(input.getType());
- Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
dynDims[inputDim] =
- getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
- stride, dilation, inputETy, rewriter);
+ getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+ kernelDynDim, stride, dilation, rewriter);
}
}
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+ // Compute the dynamic output sizes of the maxpool operation.
+ static SmallVector<Value>
+ computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+ TensorType resultTy = op.getType();
+ Location loc = op.getLoc();
+
+ TypedValue<TensorType> input = op.getInput();
+ ArrayRef<int64_t> kernel = op.getKernel();
+ ArrayRef<int64_t> pad = op.getPad();
+ ArrayRef<int64_t> stride = op.getStride();
+
+ SmallVector<Value> dynamicDims;
+
+ // Batch dimension
+ if (resultTy.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+ // Height/width dimensions
+ for (int64_t dim : {1, 2}) {
+ if (!resultTy.isDynamicDim(dim))
+ continue;
+
+ // Index into the attribute arrays
+ int64_t index = dim - 1;
+
+ // Input height/width
+ Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+ // Kernel height/width
+ Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+ // Output height/width
+ Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+ pad[index * 2 + 1], khw, stride[index],
+ /*dilationAttr=*/1, rewriter);
+ dynamicDims.push_back(ohw);
+ }
+
+ // Channel dimension
+ if (resultTy.isDynamicDim(3))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+ return dynamicDims;
+ }
+
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- Value input = op.getInput();
- ShapedType inputTy = cast<ShapedType>(input.getType());
+ TypedValue<TensorType> input = op.getInput();
+ ShapedType inputTy = input.getType();
- ShapedType resultTy = cast<ShapedType>(op.getType());
+ ShapedType resultTy = op.getType();
Type resultETy = inputTy.getElementType();
- auto dynamicDimsOr =
- checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
- if (!dynamicDimsOr.has_value())
- return failure();
- SmallVector<Value> dynamicDims = *dynamicDimsOr;
+ SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
@@ -721,6 +762,7 @@ public:
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);
+
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ public:
loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{emptyTensor})
+ rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
.result();
Value fakeWindowDims =
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f7..82a9fb0d4908 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
- ValueDimList valueDims;
- for (Value v : mapOperands)
- valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::EQ, map, valueDims);
+ presburger::BoundType::EQ,
+ ValueBoundsConstraintSet::Variable(map, mapOperands));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701a..1a266b72d1f8 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -16,16 +16,15 @@
using namespace mlir;
using namespace mlir::affine;
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d43808bc45..7cfcc4180539 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
- if (cstr.compare(trueValue, dim,
+ if (cstr.compare(/*lhs=*/{trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- falseValue, dim)) {
+ /*rhs=*/{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
- if (cstr.compare(falseValue, dim,
+ if (cstr.compare(/*lhs=*/{falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- trueValue, dim)) {
+ /*rhs=*/{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e9..f87f3d6350c0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f19..5fb7953f9370 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
return buildExpr(map.getResult(0));
}
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c62c0d6..b595c6dd8a68 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+ : public OneToNOpConversionPattern<arith::ConstantOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!vectorType || !denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(constantOp,
+ kMatchFailureNotSMETileTypeMultiple);
+
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+ auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+ auto tileSplat = rewriter.create<arith::ConstantOp>(
+ constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+ rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+ adaptor.getResultMapping());
+
+ return success();
+ }
+};
+
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
- patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+ patterns.add<LegalizeArithConstantOpsByDecomposition,
+ LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489..518d2e138c02 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, opOperand->get(),
- /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB,
+ {opOperand->get(),
+ /*dim=*/i},
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d0..71eb59d40836 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
- Value materializedSize =
- getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
- ? materializedSize
+ ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7..1f06318cbd60 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueRange independencies) {
if (ofr.is<Attribute>())
return ofr;
- Value value = ofr.get<Value>();
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
- boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+ /*closedUB=*/true)))
return failure();
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90b49b2528b7..e500d0fca741 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1656,6 +1656,17 @@ LogicalResult DistributeOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+
+ if (LoopWrapperInterface nested = getNestedWrapper()) {
+ // Check for the allowed leaf constructs that may appear in a composite
+ // construct directly after DISTRIBUTE.
+ if (!isa<ParallelOp, SimdLoopOp>(nested))
+ return emitError() << "only supported nested wrappers are 'omp.parallel' "
+ "and 'omp.simdloop'";
+ }
+
return success();
}
@@ -1818,9 +1829,8 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
TaskloopOp::build(
- builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
- clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar,
- clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars,
+ builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
+ clauses.mergeableAttr, clauses.inReductionVars,
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
@@ -1859,6 +1869,16 @@ LogicalResult TaskloopOp::verify() {
"the grainsize clause and num_tasks clause are mutually exclusive and "
"may not appear on the same taskloop directive");
}
+
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+
+ if (LoopWrapperInterface nested = getNestedWrapper()) {
+ // Check for the allowed leaf constructs that may appear in a composite
+ // construct directly after TASKLOOP.
+ if (!isa<SimdLoopOp>(nested))
+ return emitError() << "only supported nested wrapper is 'omp.simdloop'";
+ }
return success();
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a83..17a1c016ea16 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
if (cstr.populateAndCompare(
- yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
- iterArg, dim)) {
+ /*lhs=*/{yieldedValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ /*rhs=*/{iterArg, dim})) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
- cstr.bound(value) == initArg;
+ cstr.bound(value) == cstr.getExpr(initArg);
}
}
}
@@ -113,8 +114,9 @@ struct IfOpInterface
// * result <= elseValue
// * result >= thenValue
if (cstr.populateAndCompare(
- thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ /*lhs=*/{thenValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{elseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
// * result <= thenValue
// * result >= elseValue
if (cstr.populateAndCompare(
- elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ /*lhs=*/{elseValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{thenValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index fea2f659535b..7b4024b6861a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
Block *afterBody = loop.getAfterBody();
scf::YieldOp afterTerm = loop.getYieldOp();
- auto argNumber = inductionVar.getArgNumber();
- auto afterTermIndArg = afterTerm.getResults()[argNumber];
+ unsigned argNumber = inductionVar.getArgNumber();
+ Value afterTermIndArg = afterTerm.getResults()[argNumber];
- auto inductionVarAfter = afterBody->getArgument(argNumber);
-
- Value step;
+ Value inductionVarAfter = afterBody->getArgument(argNumber);
// Find suitable `addi` op inside `after` block, one of the args must be an
// Induction var passed from `before` block and second arg must be defined
// outside of the loop and will be considered step value.
// TODO: Add `subi` support?
- for (auto &use : inductionVarAfter.getUses()) {
- auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
- if (!owner)
- continue;
-
- auto other =
- (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
- if (!dom.properlyDominates(other, loop))
- continue;
-
- if (afterTermIndArg != owner.getResult())
- continue;
+ auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
- step = other;
- break;
+ Value step;
+ if (addOp.getLhs() == inductionVarAfter) {
+ step = addOp.getRhs();
+ } else if (addOp.getRhs() == inductionVarAfter) {
+ step = addOp.getLhs();
}
- if (!step)
- return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+ if (!step || !dom.properlyDominates(step, loop))
+ return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
- auto lb = loop.getInits()[argNumber];
+ Value lb = loop.getInits()[argNumber];
assert(lb.getType().isIntOrIndex());
assert(lb.getType() == ub.getType());
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c..d25efcf50ec5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
info.isAlignedToInnerTileSize = false;
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB,
- getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+ presburger::BoundType::UB, tileSize,
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49..a89ce20048df 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ independencies,
+ /*closedUB=*/true)))
return failure();
return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a17..15381ec520e2 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), srcDim, resultDim);
+ {op.getSource(), srcDim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), dim, resultDim);
+ {op.getSource(), dim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++resultDim;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ffa4c0b55cad..87937591e60a 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir
+static Operation *getOwnerOfValue(Value value) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
+ return bbArg.getOwner()->getParentOp();
+ return value.getDefiningOp();
+}
+
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+ : Variable(ofr, std::nullopt) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+ : Variable(static_cast<OpFoldResult>(indexValue)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+ : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+ std::optional<int64_t> dim) {
+ Builder b(ofr.getContext());
+ if (auto constInt = ::getConstantIntValue(ofr)) {
+ assert(!dim && "expected no dim for index-typed values");
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+ b.getAffineConstantExpr(*constInt));
+ return;
+ }
+ Value value = cast<Value>(ofr);
+#ifndef NDEBUG
+ if (dim) {
+ assert(isa<ShapedType>(value.getType()) && "expected shaped type");
+ } else {
+ assert(value.getType().isIndex() && "expected index type");
+ }
+#endif // NDEBUG
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+ b.getAffineSymbolExpr(0));
+ mapOperands.emplace_back(value, dim);
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Variable> mapOperands) {
+ assert(map.getNumResults() == 1 && "expected single result");
+
+ // Turn all dims into symbols.
+ Builder b(map.getContext());
+ SmallVector<AffineExpr> dimReplacements, symReplacements;
+ for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
+ dimReplacements.push_back(b.getAffineSymbolExpr(i));
+ for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
+ symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+ AffineMap tmpMap = map.replaceDimsAndSymbols(
+ dimReplacements, symReplacements, /*numResultDims=*/0,
+ /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+ // Inline operands.
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ for (auto [index, var] : llvm::enumerate(mapOperands)) {
+ assert(var.map.getNumResults() == 1 && "expected single result");
+ assert(var.map.getNumDims() == 0 && "expected only symbols");
+ SmallVector<AffineExpr> symReplacements;
+ for (auto valueDim : var.mapOperands) {
+ auto it = llvm::find(this->mapOperands, valueDim);
+ if (it != this->mapOperands.end()) {
+ // There is already a symbol for this operand.
+ symReplacements.push_back(b.getAffineSymbolExpr(
+ std::distance(this->mapOperands.begin(), it)));
+ } else {
+ // This is a new operand: add a new symbol.
+ symReplacements.push_back(
+ b.getAffineSymbolExpr(this->mapOperands.size()));
+ this->mapOperands.push_back(valueDim);
+ }
+ }
+ replacements[b.getAffineSymbolExpr(index)] =
+ var.map.getResult(0).replaceSymbols(symReplacements);
+ }
+ this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+ /*numResultSyms=*/this->mapOperands.size());
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Value> mapOperands)
+ : Variable(map, llvm::map_to_vector(mapOperands,
+ [](Value v) { return Variable(v); })) {}
+
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
MLIRContext *ctx, StopConditionFn stopCondition)
: builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
assert(!valueDimToPosition.contains(valueDim) && "already mapped");
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+ << " for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+ << "\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
return pos;
}
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+ return insert(var.map, var.mapOperands, isSymbol);
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
"unstructured control flow is not supported");
#endif // NDEBUG
-
+ LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
auto it =
valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
return it != valueDimToPosition.end();
}
-static Operation *getOwnerOfValue(Value value) {
- if (auto bbArg = dyn_cast<BlockArgument>(value))
- return bbArg.getOwner()->getParentOp();
- return value.getDefiningOp();
-}
-
void ValueBoundsConstraintSet::processWorklist() {
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
}
}
+void ValueBoundsConstraintSet::projectOutAnonymous(
+ std::optional<int64_t> except) {
+ int64_t nextPos = 0;
+ while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+ if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+ ++nextPos;
+ } else {
+ projectOut(nextPos);
+ // The column was projected out so another column is now at that position.
+ // Do not increase the counter.
+ }
+ }
+}
+
LogicalResult ValueBoundsConstraintSet::computeBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
- bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+ const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+ MLIRContext *ctx = var.getContext();
int64_t ubAdjustment = closedUB ? 0 : 1;
- Builder b(value.getContext());
+ Builder b(ctx);
mapOperands.clear();
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
- ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
- assert(!stopCondition(value, dim, cstr) &&
- "stop condition should not be satisfied for starting point");
- int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+ ValueBoundsConstraintSet cstr(ctx, stopCondition);
+ int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+ assert(pos == 0 && "expected first column");
cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
cstr.projectOut([&](ValueDim p) {
- // Do not project out `valueDim`.
- if (valueDim == p)
- return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
return !stopCondition(p.first, maybeDim, cstr);
});
+ cstr.projectOutAnonymous(/*except=*/pos);
// Compute lower and upper bounds for `valueDim`.
SmallVector<AffineMap> lb(1), ub(1);
- cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+ cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
/*closedUB=*/true);
// Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
LogicalResult ValueBoundsConstraintSet::computeDependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB) {
+ const Variable &var, ValueDimList dependencies, bool closedUB) {
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB) {
+ const Variable &var, ValueRange independencies, bool closedUB) {
// Return "true" if the given value is independent of all values in
// `independencies`. I.e., neither the value itself nor any value in the
// backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return isIndependent(v);
},
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
}
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, Value value, std::optional<int64_t> dim,
- StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
- AffineMap map =
- AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- Builder(value.getContext()).getAffineDimExpr(0));
- return computeConstantBound(type, map, {{value, dim}}, stopCondition,
- closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+ presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition, bool closedUB) {
- ValueDimList valueDims;
- for (Value v : operands) {
- assert(v.getType().isIndex() && "expected index type");
- valueDims.emplace_back(v, std::nullopt);
- }
- return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList operands,
- StopConditionFn stopCondition, bool closedUB) {
- assert(map.getNumResults() == 1 && "expected affine map with one result");
-
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
};
ValueBoundsConstraintSet cstr(
- map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
- pos = cstr.populateConstraints(map, operands);
+ var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ pos = cstr.populateConstraints(var.map, var.mapOperands);
assert(pos == 0 && "expected `map` is the first column");
// Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
Builder b(value1.getContext());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
- return computeConstantBound(presburger::BoundType::EQ, map,
- {{value1, dim1}, {value2, dim2}});
+ return computeConstantBound(presburger::BoundType::EQ,
+ Variable(map, {{value1, dim1}, {value2, dim2}}));
}
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp,
- OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+ ComparisonOperator cmp,
+ int64_t rhsPos) {
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -642,50 +701,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
// EQ can be expressed as LE and GE.
if (cmp == EQ)
- return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
- compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
- // Construct inequality. For the above example: lhs > rhs.
- // `IntegerRelation` inequalities are expressed in the "flattened" form and
- // with ">= 0". I.e., lhs - rhs - 1 >= 0.
- SmallVector<int64_t> eq(cstr.getNumCols(), 0);
- auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
- int64_t factor) {
- if (auto constVal = ::getConstantIntValue(ofr)) {
- eq[cstr.getNumCols() - 1] += *constVal * factor;
- } else {
- eq[getPos(cast<Value>(ofr), dim)] += factor;
- }
- };
- if (cmp == LT || cmp == LE) {
- addToEq(lhs, lhsDim, 1);
- addToEq(rhs, rhsDim, -1);
- } else if (cmp == GT || cmp == GE) {
- addToEq(lhs, lhsDim, -1);
- addToEq(rhs, rhsDim, 1);
- } else {
- llvm_unreachable("unsupported comparison operator");
- }
- if (cmp == LE || cmp == GE)
- eq[cstr.getNumCols() - 1] -= 1;
-
- // Add inequality to the constraint set and check if it made the constraint
- // set empty.
- int64_t ineqPos = cstr.getNumInequalities();
- cstr.addInequality(eq);
- bool isEmpty = cstr.isEmpty();
- cstr.removeInequality(ineqPos);
- return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
- ComparisonOperator cmp,
- int64_t rhsPos) {
- // This function returns "true" if "lhs CMP rhs" is proven to hold. For
- // detailed documentation, see `compareValueDims`.
-
- // EQ can be expressed as LE and GE.
- if (cmp == EQ)
return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
@@ -712,48 +727,17 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
-bool ValueBoundsConstraintSet::populateAndCompare(
- OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
- OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
- if (auto lhsVal = dyn_cast<Value>(lhs))
- populateConstraints(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- populateConstraints(rhsVal, rhsDim);
-
- return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs) {
+ int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
+ int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
+ return comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
- auto stopCondition = [&](Value v, std::optional<int64_t> dim,
- ValueBoundsConstraintSet &cstr) {
- // Keep processing as long as lhs/rhs are not mapped.
- if (auto lhsVal = dyn_cast<Value>(lhs))
- if (!cstr.isMapped(lhsVal, dim))
- return false;
- if (auto rhsVal = dyn_cast<Value>(rhs))
- if (!cstr.isMapped(rhsVal, dim))
- return false;
- // Keep processing as long as the relation cannot be proven.
- return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
- };
-
- ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs) {
int64_t lhsPos = -1, rhsPos = -1;
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
@@ -765,39 +749,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
return cstr.comparePos(lhsPos, cmp, rhsPos);
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- lhsPos = cstr.insert(lhs, lhsOperands);
- rhsPos = cstr.insert(rhs, rhsOperands);
- cstr.processWorklist();
+ lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+ rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
- ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands) {
- ValueDimList lhsValueDimOperands =
- llvm::map_to_vector(lhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- ValueDimList rhsValueDimOperands =
- llvm::map_to_vector(rhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
- rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1,
- std::optional<int64_t> dim2) {
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
- value2, dim2))
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
+ const Variable &var2) {
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
return true;
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
- value2, dim2) ||
- ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
- value2, dim2))
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+ ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
return false;
return failure();
}
@@ -833,7 +795,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
@@ -850,7 +812,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index e64903671e59..b4049000c50d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -215,6 +216,59 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
return
}
+// CHECK-CSE-LABEL: @max_pool_all_dynamic
+func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // Batch size
+ // CHECK-CSE: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK-CSE: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x?xf32>
+
+ // Compute output height
+ // CHECK-CSE: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-CSE: %[[IH:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x?xf32>
+ // CHECK-CSE: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IH]], %[[C0]] : index
+ // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C0]] : index
+ // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C2]], %[[C1]] : index
+ // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+ // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+ // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+ // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+ // CHECK-CSE: %[[HEIGHT:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+ // Compute output width
+ // CHECK-CSE: %[[IW:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x?xf32>
+ // CHECK-CSE: %[[C5:.+]] = arith.constant 5 : index
+ // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IW]], %[[C2]] : index
+ // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C2]] : index
+ // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C5]], %[[C1]] : index
+ // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+ // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+ // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+ // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+ // CHECK-CSE: %[[WIDTH:.+]] = arith.addi %14, %[[C1]] : index
+
+ // Channel size
+ // CHECK-CSE: %[[C3:.+]] = arith.constant 3 : index
+ // CHECK-CSE: %[[CHANNEL:.+]] = tensor.dim %arg0, %[[C3]] : tensor<?x?x?x?xf32>
+
+ // Pad the input
+ // CHECK-CSE: %[[FLOAT_MIN:.+]] = arith.constant -3.40282347E+38 : f32
+ // CHECK-CSE: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 0, 2, 0] high[0, 0, 2, 0] {
+ // CHECK-CSE: tensor.yield %[[FLOAT_MIN]] : f32
+
+ // Allocate the output and fill with minimum value
+ // CHECK-CSE: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[HEIGHT]], %[[WIDTH]], %[[CHANNEL]]) : tensor<?x?x?x?xf32>
+ // CHECK-CSE: %[[FILL:.+]] = linalg.fill ins(%[[FLOAT_MIN]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ // CHECK-CSE: %[[FAKE_WINDOW:.+]] = tensor.empty() : tensor<2x5xf32>
+
+ // Compute max pool
+ // CHECK-CSE: %[[OUT:.+]] = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PADDED]], %[[FAKE_WINDOW]] : tensor<?x?x?x?xf32>, tensor<2x5xf32>) outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ // CHECK-CSE: return %[[OUT]]
+
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
// -----
// CHECK-LABEL: @avg_pool_f32
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04..445e8be47678 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -270,7 +270,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
// CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
- // CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+ // CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
@@ -284,7 +285,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
// CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
- // CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+ // CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
// CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
@@ -298,7 +300,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
// CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
- // CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+ // CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
@@ -312,7 +315,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
// CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
- // CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+ // CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 23c6872dcebe..935c08aceff5 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -131,3 +131,27 @@ func.func @compare_affine_min(%a: index, %b: index) {
"test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
return
}
+
+// -----
+
+func.func @compare_const_map() {
+ %c5 = arith.constant 5 : index
+ // expected-remark @below{{true}}
+ "test.compare"(%c5) {cmp = "GT", rhs_map = affine_map<() -> (4)>}
+ : (index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%c5) {cmp = "LT", lhs_map = affine_map<() -> (4)>}
+ : (index) -> ()
+ return
+}
+
+// -----
+
+func.func @compare_maps(%a: index, %b: index) {
+ // expected-remark @below{{true}}
+ "test.compare"(%a, %b, %b, %a)
+ {cmp = "GT", lhs_map = affine_map<(d0, d1) -> (1 + d0 + d1)>,
+ rhs_map = affine_map<(d0, d1) -> (d0 + d1)>}
+ : (index, index, index, index) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697548c1..f43ef1cce787 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
return %cast : vector<[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+ // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+ // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+ %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+ return %0 : vector<[8]x[8]xi32>
+}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 88dca1b85ee5..7f86a7f5b318 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1580,10 +1580,11 @@ func.func @omp_cancellationpoint2() {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testmemref = "test.memref"() : () -> (memref<i32>)
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
- "omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testmemref) ({
- ^bb0(%arg3: i32, %arg4: i32):
- "omp.terminator"() : () -> ()
- }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, memref<i32>) -> ()
+ "omp.taskloop"(%testmemref) ({
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
return
}
@@ -1593,10 +1594,11 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
- "omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32, %testf32_2) ({
- ^bb0(%arg3: i32, %arg4: i32):
- "omp.terminator"() : () -> ()
- }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0>, reductions = [@add_f32]} : (i32, i32, i32, i32, i32, i32, !llvm.ptr, !llvm.ptr) -> ()
+ "omp.taskloop"(%testf32, %testf32_2) ({
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }) {operandSegmentSizes = array<i32: 0, 0, 0, 2, 0, 0, 0, 0, 0>, reductions = [@add_f32]} : (!llvm.ptr, !llvm.ptr) -> ()
return
}
@@ -1604,12 +1606,12 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
- %testf32_2 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
- "omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32) ({
- ^bb0(%arg3: i32, %arg4: i32):
- "omp.terminator"() : () -> ()
- }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0>, reductions = [@add_f32, @add_f32]} : (i32, i32, i32, i32, i32, i32, !llvm.ptr) -> ()
+ "omp.taskloop"(%testf32) ({
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }) {operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0, 0>, reductions = [@add_f32, @add_f32]} : (!llvm.ptr) -> ()
return
}
@@ -1619,10 +1621,11 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
- "omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32, %testf32_2) ({
- ^bb0(%arg3: i32, %arg4: i32):
- "omp.terminator"() : () -> ()
- }) {in_reductions = [@add_f32], operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, !llvm.ptr, !llvm.ptr) -> ()
+ "omp.taskloop"(%testf32, %testf32_2) ({
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }) {in_reductions = [@add_f32], operandSegmentSizes = array<i32: 0, 0, 2, 0, 0, 0, 0, 0, 0>} : (!llvm.ptr, !llvm.ptr) -> ()
return
}
@@ -1630,12 +1633,12 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
- %testf32_2 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
- "omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32_2) ({
- ^bb0(%arg3: i32, %arg4: i32):
- "omp.terminator"() : () -> ()
- }) {in_reductions = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, !llvm.ptr) -> ()
+ "omp.taskloop"(%testf32) ({
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }) {in_reductions = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0>} : (!llvm.ptr) -> ()
return
}
@@ -1657,9 +1660,10 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{if a reduction clause is present on the taskloop directive, the nogroup clause must not be specified}}
- omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) nogroup
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- omp.terminator
+ omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) nogroup {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
}
return
}
@@ -1681,9 +1685,10 @@ combiner {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
// expected-error @below {{the same list item cannot appear in both a reduction and an in_reduction clause}}
- omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr) in_reduction(@add_f32 -> %testf32 : !llvm.ptr)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- omp.terminator
+ omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr) in_reduction(@add_f32 -> %testf32 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
}
return
}
@@ -1693,8 +1698,20 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
// expected-error @below {{the grainsize clause and num_tasks clause are mutually exclusive and may not appear on the same taskloop directive}}
- omp.taskloop grain_size(%testi64: i64) num_tasks(%testi64: i64)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.taskloop grain_size(%testi64: i64) num_tasks(%testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ omp.yield
+ }
+ }
+ return
+}
+
+// -----
+
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+ // expected-error @below {{op must be a loop wrapper}}
+ omp.taskloop {
+ %0 = arith.constant 0 : i32
omp.terminator
}
return
@@ -1702,6 +1719,21 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+ // expected-error @below {{only supported nested wrapper is 'omp.simdloop'}}
+ omp.taskloop {
+ omp.distribute {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
func.func @omp_threadprivate() {
%1 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr
// expected-error @below {{op failed to verify that all of {sym_addr, tls_addr} have same type}}
@@ -1866,7 +1898,16 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// -----
-func.func @omp_distribute(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_schedule(%chunk_size : i32) -> () {
+ // expected-error @below {{op chunk size set without dist_schedule_static being present}}
+ "omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 1, 0, 0>}> ({
+ "omp.terminator"() : () -> ()
+ }) : (i32) -> ()
+}
+
+// -----
+
+func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
"omp.terminator"() : () -> ()
@@ -1875,6 +1916,29 @@ func.func @omp_distribute(%data_var : memref<i32>) -> () {
// -----
+func.func @omp_distribute_wrapper() -> () {
+ // expected-error @below {{op must be a loop wrapper}}
+ "omp.distribute"() ({
+ %0 = arith.constant 0 : i32
+ "omp.terminator"() : () -> ()
+ }) : () -> ()
+}
+
+// -----
+
+func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+ // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simdloop'}}
+ "omp.distribute"() ({
+ "omp.wsloop"() ({
+ %0 = arith.constant 0 : i32
+ "omp.terminator"() : () -> ()
+ }) : () -> ()
+ "omp.terminator"() : () -> ()
+ }) : () -> ()
+}
+
+// -----
+
omp.private {type = private} @x.privatizer : i32 alloc {
^bb0(%arg0: i32):
%0 = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 851d44ad984e..802e1795b3ff 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -171,6 +171,23 @@ func.func @omp_loop_nest(%lb : index, %ub : index, %step : index) -> () {
omp.yield
}
+ // TODO Remove induction variables from omp.wsloop.
+ omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.loop_nest
+ // CHECK-SAME: (%{{.*}}) : index =
+ // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+ "omp.loop_nest" (%lb, %ub, %step) ({
+ ^bb0(%iv2: index):
+ // CHECK: test.op1
+ "test.op1"(%lb) : (index) -> ()
+ // CHECK: test.op2
+ "test.op2"() : () -> ()
+ // CHECK: omp.yield
+ omp.yield
+ }) : (index, index, index) -> ()
+ omp.yield
+ }
+
return
}
@@ -209,6 +226,22 @@ func.func @omp_loop_nest_pretty(%lb : index, %ub : index, %step : index) -> () {
omp.yield
}
+ // TODO Remove induction variables from omp.wsloop.
+ omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.loop_nest
+ // CHECK-SAME: (%{{.*}}) : index =
+ // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+ omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: test.op1
+ "test.op1"(%lb) : (index) -> ()
+ // CHECK: test.op2
+ "test.op2"() : () -> ()
+ // CHECK: omp.yield
+ omp.yield
+ }
+ omp.yield
+ }
+
return
}
@@ -559,30 +592,54 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
}
// CHECK-LABEL: omp_distribute
-func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>) -> () {
+func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i32) -> () {
// CHECK: omp.distribute
"omp.distribute" () ({
- omp.terminator
+ "omp.loop_nest" (%arg0, %arg0, %arg0) ({
+ ^bb0(%iv: i32):
+ "omp.yield"() : () -> ()
+ }) : (i32, i32, i32) -> ()
+ "omp.terminator"() : () -> ()
}) {} : () -> ()
// CHECK: omp.distribute
omp.distribute {
- omp.terminator
+ omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
}
// CHECK: omp.distribute dist_schedule_static
omp.distribute dist_schedule_static {
- omp.terminator
+ omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
}
// CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
- omp.terminator
+ omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
}
// CHECK: omp.distribute order(concurrent)
omp.distribute order(concurrent) {
- omp.terminator
+ omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
}
// CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
- omp.terminator
+ omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
+ }
+ // CHECK: omp.distribute
+ omp.distribute {
+ // TODO Remove induction variables from omp.simdloop.
+ omp.simdloop for (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.loop_nest (%iv2) : i32 = (%arg0) to (%arg0) step (%arg0) {
+ omp.yield
+ }
+ omp.yield
+ }
}
return
}
@@ -2000,135 +2057,128 @@ func.func @omp_taskgroup_clauses() -> () {
// CHECK-LABEL: @omp_taskloop
func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
- // CHECK: omp.taskloop for (%{{.+}}) : i32 = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) {
- omp.taskloop for (%i) : i32 = (%lb) to (%ub) step (%step) {
- // CHECK: omp.terminator
- omp.terminator
- }
-
- // CHECK: omp.taskloop for (%{{.+}}) : i32 = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) {
- omp.taskloop for (%i) : i32 = (%lb) to (%ub) step (%step) {
- // CHECK: test.op1
- "test.op1"(%lb) : (i32) -> ()
- // CHECK: test.op2
- "test.op2"() : () -> ()
- // CHECK: omp.terminator
- omp.terminator
- }
-
- // CHECK: omp.taskloop for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
- }
-
- // CHECK: omp.taskloop for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) inclusive step (%{{.+}}, %{{.+}}) {
- omp.taskloop for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) inclusive step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop {
+ omp.taskloop {
+ omp.loop_nest (%i) : i32 = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
%testbool = "test.bool"() : () -> (i1)
- // CHECK: omp.taskloop if(%{{[^)]+}})
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop if(%testbool)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop if(%{{[^)]+}}) {
+ omp.taskloop if(%testbool) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop final(%{{[^)]+}})
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop final(%testbool)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop final(%{{[^)]+}}) {
+ omp.taskloop final(%testbool) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop untied
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop untied
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop untied {
+ omp.taskloop untied {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop mergeable
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop mergeable
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop mergeable {
+ omp.taskloop mergeable {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
%testf32 = "test.f32"() : () -> (!llvm.ptr)
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
- // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
+ omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
+ omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr) reduction(@add_f32 -> %testf32_2 : !llvm.ptr)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
+ omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr) reduction(@add_f32 -> %testf32_2 : !llvm.ptr) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
%testi32 = "test.i32"() : () -> (i32)
- // CHECK: omp.taskloop priority(%{{[^:]+}}: i32)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop priority(%testi32: i32)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop priority(%{{[^:]+}}: i32) {
+ omp.taskloop priority(%testi32: i32) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
%testmemref = "test.memref"() : () -> (memref<i32>)
- // CHECK: omp.taskloop allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
- omp.taskloop allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>) {
+ omp.taskloop allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
%testi64 = "test.i64"() : () -> (i64)
- // CHECK: omp.taskloop grain_size(%{{[^:]+}}: i64)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop grain_size(%testi64: i64)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop grain_size(%{{[^:]+}}: i64) {
+ omp.taskloop grain_size(%testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64)
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop num_tasks(%testi64: i64)
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
+ omp.taskloop num_tasks(%testi64: i64) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
}
- // CHECK: omp.taskloop nogroup
- // CHECK-SAME: for (%{{.+}}, %{{.+}}) : i32 = (%{{.+}}, %{{.+}}) to (%{{.+}}, %{{.+}}) step (%{{.+}}, %{{.+}}) {
- omp.taskloop nogroup
- for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
- // CHECK: omp.terminator
- omp.terminator
+ // CHECK: omp.taskloop nogroup {
+ omp.taskloop nogroup {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+
+ // CHECK: omp.taskloop {
+ omp.taskloop {
+ // TODO Remove induction variables from omp.simdloop.
+ omp.simdloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+ // CHECK: omp.yield
+ omp.yield
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
}
// CHECK: return
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 627ac54cf145..61a5f2a96e1c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1943,14 +1943,6 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
return %shuffle : vector<5xi32>
}
-// CHECK-LABEL: func @shuffle_nofold2
-// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
-// CHECK: return %[[V]]
-func.func @shuffle_nofold2(%v0 : vector<[4]xi32>, %v1 : vector<[2]xi32>) -> vector<4xi32> {
- %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
- return %shuffle : vector<4xi32>
-}
-
// -----
// CHECK-LABEL: func @transpose_scalar_broadcast1
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c16f1cb2876d..c9f7e9c6e2fb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -84,6 +84,13 @@ func.func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>
// -----
+func.func @shuffle_scalable_vec(%arg0: vector<[2]xf32>, %arg1: vector<[2]xf32>) {
+ // expected-error@+1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any type values}}
+ %1 = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[2]xf32>, vector<[2]xf32>
+}
+
+// -----
+
func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// expected-error@+1 {{'vector.shuffle' op invalid mask length}}
%1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir b/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir
new file mode 100644
index 000000000000..05a78e32b9e1
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir
@@ -0,0 +1,112 @@
+// DEFINE: %{tosa-to-linalg-pipeline} = -pass-pipeline="builtin.module(func.func(tosa-infer-shapes,tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))"
+
+// RUN: mlir-opt %s \
+// RUN: %{tosa-to-linalg-pipeline} \
+// RUN: | mlir-opt \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline \
+// RUN: -test-lower-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils \
+// RUN: | FileCheck %s
+
+// Validate that the TOSA lowering for tosa.max_pool2d produces the same results when
+// for fully static and fully dynamic inputs.
+
+!tensor_type = tensor<1x4x4x1xf32>
+!memref_type = memref<1x4x4x1xf32>
+
+// Utility functions
+func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func.func @max_pool_static(%arg0: !tensor_type) -> (!tensor_type) {
+ %0 = tosa.max_pool2d %arg0 {
+ pad = array<i64: 1, 1, 1, 1>,
+ kernel = array<i64: 3, 3>,
+ stride = array<i64: 1, 1>
+ } : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32>
+ return %0 : tensor<1x4x4x1xf32>
+}
+
+func.func @max_pool_dynamic(%arg0: tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>) {
+ %0 = tosa.max_pool2d %arg0 {
+ pad = array<i64: 1, 1, 1, 1>,
+ kernel = array<i64: 3, 3>,
+ stride = array<i64: 1, 1>
+ } : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// Test harness to compare the results of a fully statically shaped max_pool2d with
+// a fully dynamically shaped max_pool2d on the same inputs.
+func.func @main() {
+ %A = arith.constant dense<[[
+ [[0.0], [0.1], [0.2], [0.3]], // H = 0
+ [[1.0], [1.1], [1.2], [1.3]], // H = 1
+ [[2.0], [2.1], [2.2], [2.3]], // H = 2
+ [[3.0], [3.1], [3.2], [3.3]] // H = 3
+ ]]> : tensor<1x4x4x1xf32>
+
+ %A_dynamic = tensor.cast %A : !tensor_type to tensor<?x?x?x?xf32>
+
+ // Call both static and dynamically sized variants
+ %result_static = func.call @max_pool_static(%A) : (!tensor_type) -> !tensor_type
+ %result_dynamic = func.call @max_pool_dynamic(%A_dynamic) : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+ %static_buffer = bufferization.to_memref %result_static : !memref_type
+ %unranked_static_buffer = memref.cast %static_buffer : !memref_type to memref<*xf32>
+
+ // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 4, 4, 1] strides = [16, 4, 1, 1] data =
+
+ // CHECK-NEXT: 1.1
+ // CHECK-NEXT: 1.2
+ // CHECK-NEXT: 1.3
+ // CHECK-NEXT: 1.3
+
+ // CHECK-NEXT: 2.1
+ // CHECK-NEXT: 2.2
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2.3
+
+ // CHECK-NEXT: 3.1
+ // CHECK-NEXT: 3.2
+ // CHECK-NEXT: 3.3
+ // CHECK-NEXT: 3.3
+
+ // CHECK-NEXT: 3.1
+ // CHECK-NEXT: 3.2
+ // CHECK-NEXT: 3.3
+ // CHECK-NEXT: 3.3
+
+ func.call @printMemrefF32(%unranked_static_buffer) : (memref<*xf32>) -> ()
+
+ %dynamic_buffer = bufferization.to_memref %result_dynamic : memref<?x?x?x?xf32>
+ %unranked_dynamic_buffer = memref.cast %dynamic_buffer : memref<?x?x?x?xf32> to memref<*xf32>
+
+ // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 4, 4, 1] strides = [16, 4, 1, 1] data =
+ // CHECK-NEXT: 1.1
+ // CHECK-NEXT: 1.2
+ // CHECK-NEXT: 1.3
+ // CHECK-NEXT: 1.3
+
+ // CHECK-NEXT: 2.1
+ // CHECK-NEXT: 2.2
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2.3
+
+ // CHECK-NEXT: 3.1
+ // CHECK-NEXT: 3.2
+ // CHECK-NEXT: 3.3
+ // CHECK-NEXT: 3.3
+
+ // CHECK-NEXT: 3.1
+ // CHECK-NEXT: 3.2
+ // CHECK-NEXT: 3.3
+ // CHECK-NEXT: 3.3
+
+ func.call @printMemrefF32(%unranked_dynamic_buffer) : (memref<*xf32>) -> ()
+
+ return
+}
+
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 81a6eadbadd3..bf6847a32ff4 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -597,7 +597,7 @@ define void @ushl_sat_test(i32 %0, i32 %1, <8 x i32> %2, <8 x i32> %3) {
}
; CHECK-LABEL: llvm.func @va_intrinsics_test
-define void @va_intrinsics_test(ptr %0, ptr %1) {
+define void @va_intrinsics_test(ptr %0, ptr %1, ...) {
; CHECK: llvm.intr.vastart %{{.*}}
call void @llvm.va_start.p0(ptr %0)
; CHECK: llvm.intr.vacopy %{{.*}} to %{{.*}}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6730f9b292ad..b098a5a23fd3 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
FailureOr<OpFoldResult> reified = failure();
if (constant) {
auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
- boundType, value, dim, /*stopCondition=*/nullptr);
+ boundType, {value, dim}, /*stopCondition=*/nullptr);
if (succeeded(reifiedConst))
reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else if (scalable) {
@@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
rewriter, loc, reifiedScalable->map, vscaleOperand);
}
} else {
- if (dim) {
- if (useArithOps) {
- reified = arith::reifyShapedValueDimBound(
- rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
- } else {
- reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
- value, *dim, stopCondition);
- }
+ if (useArithOps) {
+ reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
} else {
- if (useArithOps) {
- reified = arith::reifyIndexValueBound(
- rewriter, op->getLoc(), boundType, value, stopCondition);
- } else {
- reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
- value, stopCondition);
- }
+ reified = reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
}
}
if (failed(reified)) {
@@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
}
auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
- return ValueBoundsConstraintSet::compare(
- /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
- /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+ return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
};
if (compare(cmpType)) {
op->emitRemark("true");
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190ca0ef..a23ed89c4b04 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
return success();
}
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+ if (getDim().has_value())
+ return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+ return ValueBoundsConstraintSet::Variable(getVar());
+}
+
::mlir::ValueBoundsConstraintSet::ComparisonOperator
CompareOp::getComparisonOperator() {
if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
llvm_unreachable("invalid comparison operator");
}
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+ if (!getLhsMap())
+ return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(0, getLhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+ int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ if (!getRhsMap())
+ return ValueBoundsConstraintSet::Variable(
+ getVarOperands()[rhsOperandsBegin]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
+}
+
+LogicalResult CompareOp::verify() {
+ if (getCompose() && (getLhsMap() || getRhsMap()))
+ return emitOpError(
+ "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+ int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+ if (getVarOperands().size() != size_t(expectedNumOperands))
+ return emitOpError("expected ")
+ << expectedNumOperands << " operands, but got "
+ << getVarOperands().size();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ebf158b8bb82..b641b3da719c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let extraClassDeclaration = [{
::mlir::presburger::BoundType getBoundType();
+ ::mlir::ValueBoundsConstraintSet::Variable getVariable();
}];
let hasVerifier = 1;
@@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> {
Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
specified comparison operator was proven to hold. The remark also indicates
whether the opposite comparison operator was proven to hold.
+
+ `var_operands` must have exactly two operands: one for the LHS operand and
+ one for the RHS operand. If `lhs_map` is specified, as many operands as
+ `lhs_map` has inputs are expected instead of the first operand. If `rhs_map`
+ is specified, as many operands as `rhs_map` has inputs are expected instead
+ of the second operand.
}];
- let arguments = (ins Index:$lhs,
- Index:$rhs,
+ let arguments = (ins Variadic<Index>:$var_operands,
DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+ OptionalAttr<AffineMapAttr>:$lhs_map,
+ OptionalAttr<AffineMapAttr>:$rhs_map,
UnitAttr:$compose);
let results = (outs);
let extraClassDeclaration = [{
::mlir::ValueBoundsConstraintSet::ComparisonOperator
getComparisonOperator();
+ ::mlir::ValueBoundsConstraintSet::Variable getLhs();
+ ::mlir::ValueBoundsConstraintSet::Variable getRhs();
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//