summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Interfaces/ValueBoundsOpInterface.cpp')
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp94
1 files changed, 44 insertions, 50 deletions
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 99598f2e89d9..0d362c7efa0a 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
- : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+ MLIRContext *ctx, StopConditionFn stopCondition)
+ : builder(ctx), stopCondition(stopCondition) {
+ assert(stopCondition && "expected non-null stop condition");
+}
char ValueBoundsConstraintSet::ID = 0;
@@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) {
return value.getDefiningOp();
}
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+ LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
int64_t pos = worklist.front();
worklist.pop();
@@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
// Do not process any further if the stop condition is met.
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
- if (stopCondition(value, maybeDim))
+ if (stopCondition(value, maybeDim, *this)) {
+ LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+ << " (dim: " << maybeDim << ")\n");
continue;
+ }
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
// the worklist.
auto valueBoundsOp =
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+ LLVM_DEBUG(llvm::dbgs()
+ << "Query value bounds for: " << value
+ << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
if (valueBoundsOp) {
if (dim == kIndexValue) {
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
}
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
// If the op does not implement `ValueBoundsOpInterface`, check if it
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -278,32 +289,20 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
bool closedUB) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
- assert(!stopCondition(value, dim) &&
- "stop condition should not be satisfied for starting point");
#endif // NDEBUG
int64_t ubAdjustment = closedUB ? 0 : 1;
Builder b(value.getContext());
mapOperands.clear();
- if (stopCondition(value, dim)) {
- // Special case: If the stop condition is satisfied for the input
- // value/dimension, directly return it.
- mapOperands.push_back(std::make_pair(value, dim));
- AffineExpr bound = b.getAffineDimExpr(0);
- if (type == BoundType::UB)
- bound = bound + ubAdjustment;
- resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- b.getAffineDimExpr(0));
- return success();
- }
-
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext());
+ ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+ assert(!stopCondition(value, dim, cstr) &&
+ "stop condition should not be satisfied for starting point");
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
- cstr.processWorklist(stopCondition);
+ cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
@@ -313,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
- return !stopCondition(p.first, maybeDim);
+ return !stopCondition(p.first, maybeDim, cstr);
});
// Compute lower and upper bounds for `valueDim`.
@@ -419,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
bool closedUB) {
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
closedUB);
@@ -455,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+ return isIndependent(v);
+ },
closedUB);
}
@@ -488,21 +489,19 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList operands,
StopConditionFn stopCondition, bool closedUB) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
- ValueBoundsConstraintSet cstr(map.getContext());
- int64_t pos = 0;
- if (stopCondition) {
- cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
- } else {
- // No stop condition specified: Keep adding constraints until a bound could
- // be computed.
- cstr.populateConstraintsSet(
- map, operands,
- [&](Value v, std::optional<int64_t> dim) {
- return cstr.cstr.getConstantBound64(type, pos).has_value();
- },
- &pos);
- }
+ // Default stop condition if none was specified: Keep adding constraints until
+ // a bound could be computed.
+ int64_t pos;
+ auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ };
+
+ ValueBoundsConstraintSet cstr(
+ map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ cstr.populateConstraintsSet(map, operands, &pos);
+
// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
if (auto bound = cstr.cstr.getConstantBound64(type, pos))
@@ -510,8 +509,9 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
- Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+int64_t
+ValueBoundsConstraintSet::populateConstraintsSet(Value value,
+ std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG
@@ -519,12 +519,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
AffineMap map =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
Builder(value.getContext()).getAffineDimExpr(0));
- return populateConstraintsSet(map, {{value, dim}}, stopCondition);
+ return populateConstraintsSet(map, {{value, dim}});
}
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
- AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
- int64_t *posOut) {
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
+ ValueDimList operands,
+ int64_t *posOut) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
if (posOut)
@@ -545,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
- if (stopCondition) {
- processWorklist(stopCondition);
- } else {
- // No stop condition specified: Keep adding constraints until the worklist
- // is empty.
- processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
- }
+ processWorklist();
return pos;
}