summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
blob: 1e13e60068ee7f6249a3c28d60041f43e5c8b8dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using presburger::BoundType;

namespace mlir {
namespace scf {
namespace {

struct ForOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {

  /// Add an EQ constraint between `value` and the matching init arg of
  /// `forOp`, where `value` is either an iter_arg or an OpResult of the loop.
  ///
  /// If `dim` is set, the constraint is placed on that dimension size of
  /// `value`; otherwise it is placed on `value` itself. The constraint is only
  /// added when the yielded value is the iter_arg itself, or when the EQ bound
  /// computed for the yielded value is exactly the iter_arg (with the same
  /// `dim`). In every other case `cstr` is left untouched.
  static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                    std::optional<int64_t> dim,
                                    ValueBoundsConstraintSet &cstr) {
    // Translate `value` into an index into the iter_args. Block arguments are
    // shifted by the number of induction variables; for an OpResult the result
    // number is used as is.
    int64_t iterArgIdx;
    if (auto blockArg = llvm::dyn_cast<BlockArgument>(value))
      iterArgIdx = blockArg.getArgNumber() - forOp.getNumInductionVars();
    else
      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();

    // Gather the three values that belong to this iter_arg slot.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(iterArgIdx);
    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
    Value initArg = forOp.getInitArgs()[iterArgIdx];

    // Emits "`value` (or its dimension `dim`) == init arg" into `cstr`.
    auto constrainToInit = [&]() {
      if (!dim.has_value()) {
        cstr.bound(value) == initArg;
        return;
      }
      cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
    };

    // Trivial case: the iter_arg is yielded directly.
    if (yieldedValue == iterArg) {
      constrainToInit();
      return;
    }

    // Traversal stop condition used while computing the bound of the yielded
    // value.
    auto stopAtLoopBoundary = [&](Value v, std::optional<int64_t> /*d*/,
                                  ValueBoundsConstraintSet & /*cstr*/) {
      // For block arguments: stop only if the owning block belongs directly to
      // this loop.
      if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
        return bbArg.getOwner()->getParentOp() == forOp;
      // For op results: stop once the defining op is no longer nested in the
      // loop region. An iter_arg cannot be reached from there.
      Operation *definingOp = v.getDefiningOp();
      return forOp.getRegion().findAncestorOpInRegion(*definingOp) == nullptr;
    };

    // Try to express the yielded value (dimension size) as an EQ bound.
    AffineMap eqMap;
    ValueDimList mapOperands;
    if (failed(ValueBoundsConstraintSet::computeBound(
            eqMap, mapOperands, BoundType::EQ, yieldedValue, dim,
            stopAtLoopBoundary)))
      return;
    if (eqMap.getNumResults() != 1)
      return;

    // The bound is only useful if it is a single dim or symbol, i.e., it names
    // exactly one operand. Symbol positions are offset by the number of dims
    // when indexing into the operand list.
    AffineExpr resultExpr = eqMap.getResult(0);
    std::optional<int64_t> operandIdx;
    if (auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr))
      operandIdx = dimExpr.getPosition();
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(resultExpr))
      operandIdx = symExpr.getPosition() + eqMap.getNumDims();
    if (!operandIdx.has_value())
      return;

    // Add the constraint only if that single operand is the iter_arg, queried
    // at the same dimension.
    const auto &operand = mapOperands[*operandIdx];
    if (operand.first == iterArg && operand.second == dim)
      constrainToInit();
  }

  /// Populate bounds for an index-typed `value` associated with the scf.for
  /// op `op`: either the induction variable, or an iter_arg / OpResult.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);

    // Everything except the induction variable is an iter_arg or an OpResult.
    if (value != forOp.getInductionVar()) {
      populateIterArgBounds(forOp, value, std::nullopt, cstr);
      return;
    }

    // Induction variable: lower bound <= iv < upper bound.
    // TODO: Take into account step size.
    cstr.bound(value) >= forOp.getLowerBound();
    cstr.bound(value) < forOp.getUpperBound();
  }

  /// Populate bounds for dimension `dim` of a shaped `value` associated with
  /// the scf.for op `op`. The value is handled as an iter_arg or an OpResult.
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateIterArgBounds(cast<ForOp>(op), value, dim, cstr);
  }
};

} // namespace
} // namespace scf
} // namespace mlir

/// Add an extension to `registry` whose callback attaches the
/// `ForOpInterface` external model to `scf::ForOp`.
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // Captureless lambda converted to a plain function pointer; the dialect
  // argument is not needed by the callback.
  auto attachModels = +[](MLIRContext *context, scf::SCFDialect * /*dialect*/) {
    scf::ForOp::attachInterface<scf::ForOpInterface>(*context);
  };
  registry.addExtension(attachModels);
}