summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-04-16 10:59:02 +0200
committerGitHub <noreply@github.com>2024-04-16 10:59:02 +0200
commit40dd3aa91d3f73184e34e45e597b84bec059c572 (patch)
treecbf224521b9828bdbee603f03e295564544d4ebe
parentd6d84b5d1448e4f2e24b467a0abcf42fe9d543e9 (diff)
[mlir][Interfaces] `Variable` abstraction for `ValueBoundsOpInterface` (#87980)
This commit generalizes and cleans up the `ValueBoundsConstraintSet` API. The API used to provide function overloads for comparing/computing bounds of: - index-typed SSA value - dimension of shaped value - affine map + operands This commit removes all overloads. There is now a single entry point for each `compare` variant and each `computeBound` variant. These functions now take a `Variable`, which is internally represented as an affine map and map operands. This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.
-rw-r--r--mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h11
-rw-r--r--mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h11
-rw-r--r--mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h119
-rw-r--r--mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp8
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp2
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Padding.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp6
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp5
-rw-r--r--mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp17
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Utils/Utils.cpp4
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp338
-rw-r--r--mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir24
-rw-r--r--mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp37
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td16
20 files changed, 361 insertions, 311 deletions
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 8e840e744064..1ea737522081 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition,
+ bool closedUB = false);
+
/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
index 970a52a06a11..bbc7e5d3e0dd 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -24,6 +24,17 @@ enum class BoundType;
namespace arith {
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition,
+ bool closedUB = false);
+
/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961c..ac17ace5a976 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include <queue>
@@ -111,6 +112,39 @@ protected:
public:
static char ID;
+ /// A variable that can be added to the constraint set as a "column". The
+ /// value bounds infrastructure can compute bounds for variables and compare
+ /// two variables.
+ ///
+ /// Internally, a variable is represented as an affine map and operands.
+ class Variable {
+ public:
+ /// Construct a variable for an index-typed attribute or SSA value.
+ Variable(OpFoldResult ofr);
+
+ /// Construct a variable for an index-typed SSA value.
+ Variable(Value indexValue);
+
+ /// Construct a variable for a dimension of a shaped value.
+ Variable(Value shapedValue, int64_t dim);
+
+ /// Construct a variable for an index-typed attribute/SSA value or for a
+ /// dimension of a shaped value. A non-null dimension must be provided if
+ /// and only if `ofr` is a shaped value.
+ Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+ /// Construct a variable for a map and its operands.
+ Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+ Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+ MLIRContext *getContext() const { return map.getContext(); }
+
+ private:
+ friend class ValueBoundsConstraintSet;
+ AffineMap map;
+ ValueDimList mapOperands;
+ };
+
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ public:
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
- /// Compute a bound for the given index-typed value or shape dimension size.
- /// The computed bound is stored in `resultMap`. The operands of the bound are
- /// stored in `mapOperands`. An operand is either an index-type SSA value
- /// or a shaped value and a dimension.
+ /// Compute a bound for the given variable. The computed bound is stored in
+ /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+ /// operand is either an index-type SSA value or a shaped value and a
+ /// dimension.
///
- /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
- /// is computed in terms of values/dimensions for which `stopCondition`
- /// evaluates to "true". To that end, the backward slice (reverse use-def
- /// chain) of the given value is visited in a worklist-driven manner and the
- /// constraint set is populated according to `ValueBoundsOpInterface` for each
- /// visited value.
+ /// The bound is computed in terms of values/dimensions for which
+ /// `stopCondition` evaluates to "true". To that end, the backward slice
+ /// (reverse use-def chain) of the given value is visited in a worklist-driven
+ /// manner and the constraint set is populated according to
+ /// `ValueBoundsOpInterface` for each visited value.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
- static LogicalResult computeBound(AffineMap &resultMap,
- ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- StopConditionFn stopCondition,
- bool closedUB = false);
+ static LogicalResult
+ computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+ presburger::BoundType type, const Variable &var,
+ StopConditionFn stopCondition, bool closedUB = false);
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueDimList dependencies, bool closedUB = false);
/// Compute a bound in that is independent of all values in `independencies`.
///
@@ -161,13 +191,10 @@ public:
/// appear in the computed bound.
static LogicalResult
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueRange independencies, bool closedUB = false);
- /// Compute a constant bound for the given affine map, where dims and symbols
- /// are bound to the given operands. The affine map must have exactly one
- /// result.
+ /// Compute a constant bound for the given variable.
///
/// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ public:
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
- computeConstantBound(presburger::BoundType type, Value value,
- std::optional<int64_t> dim = std::nullopt,
+ computeConstantBound(presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
@@ -221,9 +241,8 @@ public:
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
+ bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
+ const Variable &rhs);
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
@@ -233,24 +252,12 @@ public:
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
- static bool compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands);
- static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands);
-
- /// Compute whether the given values/dimensions are equal. Return "failure" if
+ static bool compare(const Variable &lhs, ComparisonOperator cmp,
+ const Variable &rhs);
+
+ /// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
- ///
- /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
- /// index-typed.
- static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1 = std::nullopt,
- std::optional<int64_t> dim2 = std::nullopt);
+ static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +324,6 @@ protected:
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
- bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +378,7 @@ protected:
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+ int64_t insert(const Variable &var, bool isSymbol = true);
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
@@ -381,6 +386,8 @@ protected:
/// Project out all columns for which the condition holds.
void projectOut(function_ref<bool(ValueDim)> condition);
+ void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
/// Mapping of columns to values/shape dimensions.
SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f7..82a9fb0d4908 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
- ValueDimList valueDims;
- for (Value v : mapOperands)
- valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::EQ, map, valueDims);
+ presburger::BoundType::EQ,
+ ValueBoundsConstraintSet::Variable(map, mapOperands));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701a..1a266b72d1f8 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -16,16 +16,15 @@
using namespace mlir;
using namespace mlir::affine;
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d43808bc45..7cfcc4180539 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
- if (cstr.compare(trueValue, dim,
+ if (cstr.compare(/*lhs=*/{trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- falseValue, dim)) {
+ /*rhs=*/{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
- if (cstr.compare(falseValue, dim,
+ if (cstr.compare(/*lhs=*/{falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- trueValue, dim)) {
+ /*rhs=*/{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e9..f87f3d6350c0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f19..5fb7953f9370 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
return buildExpr(map.getResult(0));
}
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489..518d2e138c02 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, opOperand->get(),
- /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB,
+ {opOperand->get(),
+ /*dim=*/i},
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d0..71eb59d40836 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
- Value materializedSize =
- getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
- ? materializedSize
+ ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7..1f06318cbd60 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueRange independencies) {
if (ofr.is<Attribute>())
return ofr;
- Value value = ofr.get<Value>();
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
- boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+ /*closedUB=*/true)))
return failure();
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a83..17a1c016ea16 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
if (cstr.populateAndCompare(
- yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
- iterArg, dim)) {
+ /*lhs=*/{yieldedValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ /*rhs=*/{iterArg, dim})) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
- cstr.bound(value) == initArg;
+ cstr.bound(value) == cstr.getExpr(initArg);
}
}
}
@@ -113,8 +114,9 @@ struct IfOpInterface
// * result <= elseValue
// * result >= thenValue
if (cstr.populateAndCompare(
- thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ /*lhs=*/{thenValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{elseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
// * result <= thenValue
// * result >= elseValue
if (cstr.populateAndCompare(
- elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ /*lhs=*/{elseValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{thenValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c..d25efcf50ec5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
info.isAlignedToInnerTileSize = false;
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB,
- getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+ presburger::BoundType::UB, tileSize,
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49..a89ce20048df 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ independencies,
+ /*closedUB=*/true)))
return failure();
return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a17..15381ec520e2 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), srcDim, resultDim);
+ {op.getSource(), srcDim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), dim, resultDim);
+ {op.getSource(), dim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++resultDim;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ffa4c0b55cad..87937591e60a 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir
+static Operation *getOwnerOfValue(Value value) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
+ return bbArg.getOwner()->getParentOp();
+ return value.getDefiningOp();
+}
+
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+ : Variable(ofr, std::nullopt) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+ : Variable(static_cast<OpFoldResult>(indexValue)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+ : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+ std::optional<int64_t> dim) {
+ Builder b(ofr.getContext());
+ if (auto constInt = ::getConstantIntValue(ofr)) {
+ assert(!dim && "expected no dim for index-typed values");
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+ b.getAffineConstantExpr(*constInt));
+ return;
+ }
+ Value value = cast<Value>(ofr);
+#ifndef NDEBUG
+ if (dim) {
+ assert(isa<ShapedType>(value.getType()) && "expected shaped type");
+ } else {
+ assert(value.getType().isIndex() && "expected index type");
+ }
+#endif // NDEBUG
+ map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+ b.getAffineSymbolExpr(0));
+ mapOperands.emplace_back(value, dim);
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Variable> mapOperands) {
+ assert(map.getNumResults() == 1 && "expected single result");
+
+ // Turn all dims into symbols.
+ Builder b(map.getContext());
+ SmallVector<AffineExpr> dimReplacements, symReplacements;
+ for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
+ dimReplacements.push_back(b.getAffineSymbolExpr(i));
+ for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
+ symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+ AffineMap tmpMap = map.replaceDimsAndSymbols(
+ dimReplacements, symReplacements, /*numResultDims=*/0,
+ /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+ // Inline operands.
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ for (auto [index, var] : llvm::enumerate(mapOperands)) {
+ assert(var.map.getNumResults() == 1 && "expected single result");
+ assert(var.map.getNumDims() == 0 && "expected only symbols");
+ SmallVector<AffineExpr> symReplacements;
+ for (auto valueDim : var.mapOperands) {
+ auto it = llvm::find(this->mapOperands, valueDim);
+ if (it != this->mapOperands.end()) {
+ // There is already a symbol for this operand.
+ symReplacements.push_back(b.getAffineSymbolExpr(
+ std::distance(this->mapOperands.begin(), it)));
+ } else {
+ // This is a new operand: add a new symbol.
+ symReplacements.push_back(
+ b.getAffineSymbolExpr(this->mapOperands.size()));
+ this->mapOperands.push_back(valueDim);
+ }
+ }
+ replacements[b.getAffineSymbolExpr(index)] =
+ var.map.getResult(0).replaceSymbols(symReplacements);
+ }
+ this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+ /*numResultSyms=*/this->mapOperands.size());
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+ ArrayRef<Value> mapOperands)
+ : Variable(map, llvm::map_to_vector(mapOperands,
+ [](Value v) { return Variable(v); })) {}
+
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
MLIRContext *ctx, StopConditionFn stopCondition)
: builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
assert(!valueDimToPosition.contains(valueDim) && "already mapped");
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+ << " for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
: cstr.appendVar(VarKind::SetDim);
+ LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+ << "\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
return pos;
}
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+ return insert(var.map, var.mapOperands, isSymbol);
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
"unstructured control flow is not supported");
#endif // NDEBUG
-
+ LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+ << " (dim: " << dim.value_or(kIndexValue)
+ << ", owner: " << getOwnerOfValue(value)->getName()
+ << ")\n");
auto it =
valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
return it != valueDimToPosition.end();
}
-static Operation *getOwnerOfValue(Value value) {
- if (auto bbArg = dyn_cast<BlockArgument>(value))
- return bbArg.getOwner()->getParentOp();
- return value.getDefiningOp();
-}
-
void ValueBoundsConstraintSet::processWorklist() {
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
}
}
+void ValueBoundsConstraintSet::projectOutAnonymous(
+ std::optional<int64_t> except) {
+ int64_t nextPos = 0;
+ while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+ if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+ ++nextPos;
+ } else {
+ projectOut(nextPos);
+ // The column was projected out so another column is now at that position.
+ // Do not increase the counter.
+ }
+ }
+}
+
LogicalResult ValueBoundsConstraintSet::computeBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
- bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+ const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+ MLIRContext *ctx = var.getContext();
int64_t ubAdjustment = closedUB ? 0 : 1;
- Builder b(value.getContext());
+ Builder b(ctx);
mapOperands.clear();
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
- ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
- assert(!stopCondition(value, dim, cstr) &&
- "stop condition should not be satisfied for starting point");
- int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+ ValueBoundsConstraintSet cstr(ctx, stopCondition);
+ int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+ assert(pos == 0 && "expected first column");
cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
cstr.projectOut([&](ValueDim p) {
- // Do not project out `valueDim`.
- if (valueDim == p)
- return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
return !stopCondition(p.first, maybeDim, cstr);
});
+ cstr.projectOutAnonymous(/*except=*/pos);
// Compute lower and upper bounds for `valueDim`.
SmallVector<AffineMap> lb(1), ub(1);
- cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+ cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
/*closedUB=*/true);
// Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
LogicalResult ValueBoundsConstraintSet::computeDependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB) {
+ const Variable &var, ValueDimList dependencies, bool closedUB) {
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
- Value value, std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB) {
+ const Variable &var, ValueRange independencies, bool closedUB) {
// Return "true" if the given value is independent of all values in
// `independencies`. I.e., neither the value itself nor any value in the
// backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
- resultMap, mapOperands, type, value, dim,
+ resultMap, mapOperands, type, var,
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return isIndependent(v);
},
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
}
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, Value value, std::optional<int64_t> dim,
- StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
- assertValidValueDim(value, dim);
-#endif // NDEBUG
-
- AffineMap map =
- AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- Builder(value.getContext()).getAffineDimExpr(0));
- return computeConstantBound(type, map, {{value, dim}}, stopCondition,
- closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+ presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition, bool closedUB) {
- ValueDimList valueDims;
- for (Value v : operands) {
- assert(v.getType().isIndex() && "expected index type");
- valueDims.emplace_back(v, std::nullopt);
- }
- return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList operands,
- StopConditionFn stopCondition, bool closedUB) {
- assert(map.getNumResults() == 1 && "expected affine map with one result");
-
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
};
ValueBoundsConstraintSet cstr(
- map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
- pos = cstr.populateConstraints(map, operands);
+ var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ pos = cstr.populateConstraints(var.map, var.mapOperands);
assert(pos == 0 && "expected `map` is the first column");
// Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
Builder b(value1.getContext());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
- return computeConstantBound(presburger::BoundType::EQ, map,
- {{value1, dim1}, {value2, dim2}});
+ return computeConstantBound(presburger::BoundType::EQ,
+ Variable(map, {{value1, dim1}, {value2, dim2}}));
}
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp,
- OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+ ComparisonOperator cmp,
+ int64_t rhsPos) {
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -642,50 +701,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
// EQ can be expressed as LE and GE.
if (cmp == EQ)
- return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
- compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
- // Construct inequality. For the above example: lhs > rhs.
- // `IntegerRelation` inequalities are expressed in the "flattened" form and
- // with ">= 0". I.e., lhs - rhs - 1 >= 0.
- SmallVector<int64_t> eq(cstr.getNumCols(), 0);
- auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
- int64_t factor) {
- if (auto constVal = ::getConstantIntValue(ofr)) {
- eq[cstr.getNumCols() - 1] += *constVal * factor;
- } else {
- eq[getPos(cast<Value>(ofr), dim)] += factor;
- }
- };
- if (cmp == LT || cmp == LE) {
- addToEq(lhs, lhsDim, 1);
- addToEq(rhs, rhsDim, -1);
- } else if (cmp == GT || cmp == GE) {
- addToEq(lhs, lhsDim, -1);
- addToEq(rhs, rhsDim, 1);
- } else {
- llvm_unreachable("unsupported comparison operator");
- }
- if (cmp == LE || cmp == GE)
- eq[cstr.getNumCols() - 1] -= 1;
-
- // Add inequality to the constraint set and check if it made the constraint
- // set empty.
- int64_t ineqPos = cstr.getNumInequalities();
- cstr.addInequality(eq);
- bool isEmpty = cstr.isEmpty();
- cstr.removeInequality(ineqPos);
- return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
- ComparisonOperator cmp,
- int64_t rhsPos) {
- // This function returns "true" if "lhs CMP rhs" is proven to hold. For
- // detailed documentation, see `compareValueDims`.
-
- // EQ can be expressed as LE and GE.
- if (cmp == EQ)
return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
@@ -712,48 +727,17 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
-bool ValueBoundsConstraintSet::populateAndCompare(
- OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
- OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
- if (auto lhsVal = dyn_cast<Value>(lhs))
- assertValidValueDim(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
- if (auto lhsVal = dyn_cast<Value>(lhs))
- populateConstraints(lhsVal, lhsDim);
- if (auto rhsVal = dyn_cast<Value>(rhs))
- populateConstraints(rhsVal, rhsDim);
-
- return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs) {
+ int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
+ int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
+ return comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
- std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim) {
- auto stopCondition = [&](Value v, std::optional<int64_t> dim,
- ValueBoundsConstraintSet &cstr) {
- // Keep processing as long as lhs/rhs are not mapped.
- if (auto lhsVal = dyn_cast<Value>(lhs))
- if (!cstr.isMapped(lhsVal, dim))
- return false;
- if (auto rhsVal = dyn_cast<Value>(rhs))
- if (!cstr.isMapped(rhsVal, dim))
- return false;
- // Keep processing as long as the relation cannot be proven.
- return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
- };
-
- ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs) {
int64_t lhsPos = -1, rhsPos = -1;
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
@@ -765,39 +749,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
return cstr.comparePos(lhsPos, cmp, rhsPos);
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
- lhsPos = cstr.insert(lhs, lhsOperands);
- rhsPos = cstr.insert(rhs, rhsOperands);
- cstr.processWorklist();
+ lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+ rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
- ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands) {
- ValueDimList lhsValueDimOperands =
- llvm::map_to_vector(lhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- ValueDimList rhsValueDimOperands =
- llvm::map_to_vector(rhsOperands, [](Value v) {
- return std::make_pair(v, std::optional<int64_t>());
- });
- return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
- rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1,
- std::optional<int64_t> dim2) {
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
- value2, dim2))
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
+ const Variable &var2) {
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
return true;
- if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
- value2, dim2) ||
- ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
- value2, dim2))
+ if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+ ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
return false;
return failure();
}
@@ -833,7 +795,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
@@ -850,7 +812,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
FailureOr<int64_t> constBound = computeConstantBound(
- presburger::BoundType::EQ, foldedMap, valueOperands);
+ presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
foundUnknownBound |= failed(constBound);
if (succeeded(constBound) && *constBound <= 0)
return false;
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 23c6872dcebe..935c08aceff5 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -131,3 +131,27 @@ func.func @compare_affine_min(%a: index, %b: index) {
"test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
return
}
+
+// -----
+
+func.func @compare_const_map() {
+ %c5 = arith.constant 5 : index
+ // expected-remark @below{{true}}
+ "test.compare"(%c5) {cmp = "GT", rhs_map = affine_map<() -> (4)>}
+ : (index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%c5) {cmp = "LT", lhs_map = affine_map<() -> (4)>}
+ : (index) -> ()
+ return
+}
+
+// -----
+
+func.func @compare_maps(%a: index, %b: index) {
+ // expected-remark @below{{true}}
+ "test.compare"(%a, %b, %b, %a)
+ {cmp = "GT", lhs_map = affine_map<(d0, d1) -> (1 + d0 + d1)>,
+ rhs_map = affine_map<(d0, d1) -> (d0 + d1)>}
+ : (index, index, index, index) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6730f9b292ad..b098a5a23fd3 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
FailureOr<OpFoldResult> reified = failure();
if (constant) {
auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
- boundType, value, dim, /*stopCondition=*/nullptr);
+ boundType, {value, dim}, /*stopCondition=*/nullptr);
if (succeeded(reifiedConst))
reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else if (scalable) {
@@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
rewriter, loc, reifiedScalable->map, vscaleOperand);
}
} else {
- if (dim) {
- if (useArithOps) {
- reified = arith::reifyShapedValueDimBound(
- rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
- } else {
- reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
- value, *dim, stopCondition);
- }
+ if (useArithOps) {
+ reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
} else {
- if (useArithOps) {
- reified = arith::reifyIndexValueBound(
- rewriter, op->getLoc(), boundType, value, stopCondition);
- } else {
- reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
- value, stopCondition);
- }
+ reified = reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
}
}
if (failed(reified)) {
@@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
}
auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
- return ValueBoundsConstraintSet::compare(
- /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
- /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+ return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
};
if (compare(cmpType)) {
op->emitRemark("true");
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190ca0ef..36d7606fe134 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
return success();
}
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+ if (getDim().has_value())
+ return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+ return ValueBoundsConstraintSet::Variable(getVar());
+}
+
::mlir::ValueBoundsConstraintSet::ComparisonOperator
CompareOp::getComparisonOperator() {
if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
llvm_unreachable("invalid comparison operator");
}
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+ if (!getLhsMap())
+ return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(0, getLhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+ int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ if (!getRhsMap())
+ return ValueBoundsConstraintSet::Variable(
+ getVarOperands()[rhsOperandsBegin]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
+}
+
+LogicalResult CompareOp::verify() {
+ if (getCompose() && (getLhsMap() || getRhsMap()))
+ return emitOpError(
+ "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+ int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+ if (getVarOperands().size() != expectedNumOperands)
+ return emitOpError("expected ")
+ << expectedNumOperands << " operands, but got "
+ << getVarOperands().size();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ebf158b8bb82..b641b3da719c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let extraClassDeclaration = [{
::mlir::presburger::BoundType getBoundType();
+ ::mlir::ValueBoundsConstraintSet::Variable getVariable();
}];
let hasVerifier = 1;
@@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> {
Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
specified comparison operator was proven to hold. The remark also indicates
whether the opposite comparison operator was proven to hold.
+
+ `var_operands` must have exactly two operands: one for the LHS operand and
+ one for the RHS operand. If `lhs_map` is specified, as many operands as
+ `lhs_map` has inputs are expected instead of the first operand. If `rhs_map`
+ is specified, as many operands as `rhs_map` has inputs are expected instead
+ of the second operand.
}];
- let arguments = (ins Index:$lhs,
- Index:$rhs,
+ let arguments = (ins Variadic<Index>:$var_operands,
DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+ OptionalAttr<AffineMapAttr>:$lhs_map,
+ OptionalAttr<AffineMapAttr>:$rhs_map,
UnitAttr:$compose);
let results = (outs);
let extraClassDeclaration = [{
::mlir::ValueBoundsConstraintSet::ComparisonOperator
getComparisonOperator();
+ ::mlir::ValueBoundsConstraintSet::Variable getLhs();
+ ::mlir::ValueBoundsConstraintSet::Variable getRhs();
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//