1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
namespace mlir::vector {
/// Interprets the bound's affine map as a size.
///
/// - If the map folds to a single constant, that constant is returned as a
///   fixed (non-scalable) size.
/// - Otherwise the map must have exactly one input and exactly one result,
///   and that result must be a product of a constant and a symbol (in either
///   operand order). The constant factor is then returned with
///   `scalable = true`. The symbol is presumably vscale (the only column
///   computeScalableBound keeps) — NOTE(review): confirm for other producers.
/// - Any other shape yields failure.
FailureOr<ConstantOrScalableBound::BoundSize>
ConstantOrScalableBound::getSize() const {
  // Fixed-size bound: nothing to decompose.
  if (map.isSingleConstant())
    return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};

  // A scalable size needs a single result over a single input.
  if (map.getNumResults() != 1 || map.getNumInputs() != 1)
    return failure();

  auto product = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
  if (!product || product.getKind() != AffineExprKind::Mul)
    return failure();

  // Accept `cst * s0` first, then `s0 * cst`: the constant factor may sit on
  // either side as long as the other operand is a symbol.
  AffineExpr operands[2] = {product.getLHS(), product.getRHS()};
  for (int side = 0; side < 2; ++side) {
    auto factor = dyn_cast<AffineConstantExpr>(operands[side]);
    if (factor && isa<AffineSymbolExpr>(operands[1 - side]))
      return BoundSize{factor.getValue(), /*scalable=*/true};
  }
  return failure();
}
// Out-of-line definition of the class's static `ID` member.
// NOTE(review): presumably used as an identity/RTTI anchor — confirm in header.
char ScalableValueBoundsConstraintSet::ID = 0;

/// Computes a bound of kind `boundType` for `value` (or for dimension `dim` of
/// `value`, when given), expressed purely in terms of vscale.
///
/// The constraint set is populated starting from `value`/`dim`. Population
/// honors `stopCondition`; when none is given, it walks until the worklist is
/// empty. Every column other than vscale and the starting point is then
/// projected out, and the bound of the starting point is read back as an
/// affine map.
///
/// `vscaleMin` / `vscaleMax` are forwarded to the constraint set, and
/// `vscaleMin <= vscaleMax` is asserted.
///
/// `closedUB` is passed to `getSliceBounds` to select a closed upper bound.
/// For `EQ`, a closed upper bound is always requested: only the lower bound
/// is returned in that case, and the equality check must compare closed
/// bounds.
///
/// Returns failure if:
/// - a column other than vscale (and the starting point) survives projection,
/// - no single-result bound of the requested kind is available, or
/// - for `EQ`, the lower and upper bounds are not the same map.
FailureOr<ConstantOrScalableBound>
ScalableValueBoundsConstraintSet::computeScalableBound(
    Value value, std::optional<int64_t> dim, unsigned vscaleMin,
    unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
    StopConditionFn stopCondition) {
  using namespace presburger;
  assert(vscaleMin <= vscaleMax);

  // No stop condition specified: Keep adding constraints until the worklist
  // is empty.
  auto defaultStopCondition = [](Value, std::optional<int64_t>,
                                 mlir::ValueBoundsConstraintSet &) {
    return false;
  };

  ScalableValueBoundsConstraintSet scalableCstr(
      value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
      vscaleMin, vscaleMax);
  int64_t pos = scalableCstr.populateConstraintsSet(value, dim);

  // Project out all variables apart from vscale and the starting point
  // (`value`/`dim`). The starting point must survive: its column `pos` is
  // skipped by the check below and is the column whose bounds are queried.
  // The remaining constraints should then be in terms of vscale only.
  ValueDim startingPoint(value,
                         dim.value_or(ValueBoundsConstraintSet::kIndexValue));
  auto projectOutFn = [&](ValueDim p) {
    return p != startingPoint && p.first != scalableCstr.getVscaleValue();
  };
  scalableCstr.projectOut(projectOutFn);

  assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
             scalableCstr.positionToValueDim.size() &&
         "inconsistent mapping state");

  // Check that the only columns left, apart from the starting point, are
  // vscale.
  for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
    if (i == pos)
      continue;
    if (scalableCstr.positionToValueDim[i] !=
        ValueDim(scalableCstr.getVscaleValue(),
                 ValueBoundsConstraintSet::kIndexValue)) {
      return failure();
    }
  }

  // For EQ, only the lower bound is returned. The upper bound is used solely
  // to verify that both coincide, which only makes sense for a closed upper
  // bound: an open (exclusive) one is presumably one past the value and would
  // never match. So request a closed upper bound for EQ regardless of
  // `closedUB`.
  bool useClosedUB = closedUB || boundType == BoundType::EQ;
  SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
  scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
                                   &upperBound, useClosedUB);

  // A bound is unusable if it could not be computed or has multiple results.
  auto invalidBound = [](auto &bound) {
    return !bound[0] || bound[0].getNumResults() != 1;
  };

  AffineMap bound = [&] {
    if (boundType == BoundType::EQ) {
      // EQ: the lower and (closed) upper bounds must be the same map.
      // (Previously the lower bound was compared with itself, which always
      // succeeded.)
      if (!invalidBound(lowerBound) && !invalidBound(upperBound) &&
          lowerBound[0] == upperBound[0])
        return lowerBound[0];
      return AffineMap{};
    }
    if (boundType == BoundType::LB && !invalidBound(lowerBound))
      return lowerBound[0];
    if (boundType == BoundType::UB && !invalidBound(upperBound))
      return upperBound[0];
    return AffineMap{};
  }();

  if (!bound)
    return failure();

  return ConstantOrScalableBound{bound};
}
} // namespace mlir::vector
|