diff options
Diffstat (limited to 'mlir/lib/Interfaces/ValueBoundsOpInterface.cpp')
-rw-r--r-- | mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 94 |
1 files changed, 44 insertions, 50 deletions
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 99598f2e89d9..0d362c7efa0a 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { return std::nullopt; } -ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx) - : builder(ctx) {} +ValueBoundsConstraintSet::ValueBoundsConstraintSet( + MLIRContext *ctx, StopConditionFn stopCondition) + : builder(ctx), stopCondition(stopCondition) { + assert(stopCondition && "expected non-null stop condition"); +} char ValueBoundsConstraintSet::ID = 0; @@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) { return value.getDefiningOp(); } -void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { +void ValueBoundsConstraintSet::processWorklist() { + LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); while (!worklist.empty()) { int64_t pos = worklist.front(); worklist.pop(); @@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { // Do not process any further if the stop condition is met. auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); - if (stopCondition(value, maybeDim)) + if (stopCondition(value, maybeDim, *this)) { + LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value + << " (dim: " << maybeDim << ")\n"); continue; + } // Query `ValueBoundsOpInterface` for constraints. New items may be added to // the worklist. auto valueBoundsOp = dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value)); + LLVM_DEBUG(llvm::dbgs() + << "Query value bounds for: " << value + << " (owner: " << getOwnerOfValue(value)->getName() << ")\n"); if (valueBoundsOp) { if (dim == kIndexValue) { valueBoundsOp.populateBoundsForIndexValue(value, *this); @@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { } continue; } + LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n"); // If the op does not implement `ValueBoundsOpInterface`, check if it // implements the `DestinationStyleOpInterface`. OpResults of such ops are @@ -278,32 +289,20 @@ LogicalResult ValueBoundsConstraintSet::computeBound( bool closedUB) { #ifndef NDEBUG assertValidValueDim(value, dim); - assert(!stopCondition(value, dim) && - "stop condition should not be satisfied for starting point"); #endif // NDEBUG int64_t ubAdjustment = closedUB ? 0 : 1; Builder b(value.getContext()); mapOperands.clear(); - if (stopCondition(value, dim)) { - // Special case: If the stop condition is satisfied for the input - // value/dimension, directly return it. - mapOperands.push_back(std::make_pair(value, dim)); - AffineExpr bound = b.getAffineDimExpr(0); - if (type == BoundType::UB) - bound = bound + ubAdjustment; - resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - b.getAffineDimExpr(0)); - return success(); - } - // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - ValueBoundsConstraintSet cstr(value.getContext()); + ValueBoundsConstraintSet cstr(value.getContext(), stopCondition); + assert(!stopCondition(value, dim, cstr) && + "stop condition should not be satisfied for starting point"); int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false); - cstr.processWorklist(stopCondition); + cstr.processWorklist(); // Project out all variables (apart from `valueDim`) that do not match the // stop condition. @@ -313,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound( return false; auto maybeDim = p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); - return !stopCondition(p.first, maybeDim); + return !stopCondition(p.first, maybeDim, cstr); }); // Compute lower and upper bounds for `valueDim`. @@ -419,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound( bool closedUB) { return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional<int64_t> d) { + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { return llvm::is_contained(dependencies, std::make_pair(v, d)); }, closedUB); @@ -455,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( // Reify bounds in terms of any independent values. return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional<int64_t> d) { return isIndependent(v); }, + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { + return isIndependent(v); + }, closedUB); } @@ -488,21 +489,19 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, AffineMap map, ValueDimList operands, StopConditionFn stopCondition, bool closedUB) { assert(map.getNumResults() == 1 && "expected affine map with one result"); - ValueBoundsConstraintSet cstr(map.getContext()); - int64_t pos = 0; - if (stopCondition) { - cstr.populateConstraintsSet(map, operands, stopCondition, &pos); - } else { - // No stop condition specified: Keep adding constraints until a bound could - // be computed. - cstr.populateConstraintsSet( - map, operands, - [&](Value v, std::optional<int64_t> dim) { - return cstr.cstr.getConstantBound64(type, pos).has_value(); - }, - &pos); - } + // Default stop condition if none was specified: Keep adding constraints until + // a bound could be computed. + int64_t pos; + auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + return cstr.cstr.getConstantBound64(type, pos).has_value(); + }; + + ValueBoundsConstraintSet cstr( + map.getContext(), stopCondition ? stopCondition : defaultStopCondition); + cstr.populateConstraintsSet(map, operands, &pos); + // Compute constant bound for `valueDim`. int64_t ubAdjustment = closedUB ? 0 : 1; if (auto bound = cstr.cstr.getConstantBound64(type, pos)) @@ -510,8 +509,9 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( return failure(); } -int64_t ValueBoundsConstraintSet::populateConstraintsSet( - Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) { +int64_t +ValueBoundsConstraintSet::populateConstraintsSet(Value value, + std::optional<int64_t> dim) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG @@ -519,12 +519,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet( AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, Builder(value.getContext()).getAffineDimExpr(0)); - return populateConstraintsSet(map, {{value, dim}}, stopCondition); + return populateConstraintsSet(map, {{value, dim}}); } -int64_t ValueBoundsConstraintSet::populateConstraintsSet( - AffineMap map, ValueDimList operands, StopConditionFn stopCondition, - int64_t *posOut) { +int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map, + ValueDimList operands, + int64_t *posOut) { assert(map.getNumResults() == 1 && "expected affine map with one result"); int64_t pos = insert(/*isSymbol=*/false); if (posOut) @@ -545,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet( // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. - if (stopCondition) { - processWorklist(stopCondition); - } else { - // No stop condition specified: Keep adding constraints until the worklist - // is empty. - processWorklist([](Value v, std::optional<int64_t> dim) { return false; }); - } + processWorklist(); return pos; } |