diff options
Diffstat (limited to 'mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h')
-rw-r--r-- | mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index bdfd689c7ac4..83107a3f5f94 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -117,8 +117,9 @@ public: /// /// The first parameter of the function is the shaped value/index-typed /// value. The second parameter is the dimension in case of a shaped value. - using StopConditionFn = - function_ref<bool(Value, std::optional<int64_t> /*dim*/)>; + /// The third parameter is this constraint set. + using StopConditionFn = std::function<bool( + Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>; /// Compute a bound for the given index-typed value or shape dimension size. /// The computed bound is stored in `resultMap`. The operands of the bound are @@ -271,22 +272,20 @@ protected: /// An index-typed value or the dimension of a shaped-type value. using ValueDim = std::pair<Value, int64_t>; - ValueBoundsConstraintSet(MLIRContext *ctx); + ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); /// Populates the constraint set for a value/map without actually computing /// the bound. Returns the position for the value/map (via the return value /// and `posOut` output parameter). int64_t populateConstraintsSet(Value value, - std::optional<int64_t> dim = std::nullopt, - StopConditionFn stopCondition = nullptr); + std::optional<int64_t> dim = std::nullopt); int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands, - StopConditionFn stopCondition = nullptr, int64_t *posOut = nullptr); /// Iteratively process all elements on the worklist until an index-typed /// value or shaped value meets `stopCondition`. Such values are not processed /// any further. - void processWorklist(StopConditionFn stopCondition); + void processWorklist(); /// Bound the given column in the underlying constraint set by the given /// expression. @@ -333,6 +332,9 @@ protected: /// Builder for constructing affine expressions. Builder builder; + + /// The current stop condition function. + StopConditionFn stopCondition = nullptr; }; } // namespace mlir |