summaryrefslogtreecommitdiffstats
path: root/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
blob: 31e19ff1ad39f7f1b9bd9edb83ff7612de3158e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

namespace mlir::vector {

namespace detail {

/// Intermediate base class used as the parent of the value bounds
/// `RTTIExtends`.
///
/// `::mlir::ValueBoundsConstraintSet` is inherited `protected`ly, so none of
/// its members are reachable through a subclass unless that subclass
/// explicitly re-exports them with a `using` declaration. This is deliberate:
/// some of those members do not go through `ScalableValueBoundsConstraintSet`
/// and could therefore produce unexpected results if called directly.
struct ValueBoundsConstraintSet
    : protected ::mlir::ValueBoundsConstraintSet {
  // Forward every constructor of the wrapped class unchanged.
  using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet;
};

} // namespace detail

/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds,
/// i.e. bounds that may be expressed in terms of vscale, given the range of
/// values vscale can take (`[vscaleMin, vscaleMax]`).
struct ScalableValueBoundsConstraintSet
    : public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
                               detail::ValueBoundsConstraintSet> {
  /// Constructs a constraint set for `context`, where vscale is known to lie
  /// in the range `[vscaleMin, vscaleMax]`.
  ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
                                   unsigned vscaleMax)
      : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax) {}

  // The base class is inherited protected (see `detail::ValueBoundsConstraintSet`);
  // selectively re-export only the members that are safe to use here.
  using RTTIExtends::bound;
  using RTTIExtends::StopConditionFn;

  /// A thin wrapper over an `AffineMap` which can represent a constant bound,
  /// or a scalable bound (in terms of vscale). The `AffineMap` will always
  /// take at most one parameter, vscale, and return a single result, which is
  /// the bound of the value.
  struct ConstantOrScalableBound {
    AffineMap map;

    /// The size of a bound: a base quantity plus a flag saying whether it is
    /// scalable.
    /// NOTE(review): presumably the effective size is `baseSize * vscale` when
    /// `scalable` is true and `baseSize` otherwise — confirm against the
    /// implementation of `getSize()`.
    struct BoundSize {
      int64_t baseSize{0};
      bool scalable{false};
    };

    /// Get the (possibly) scalable size of the bound. Returns failure if the
    /// bound cannot be represented as a single quantity.
    FailureOr<BoundSize> getSize() const;
  };

  /// Computes a (possibly) scalable bound for a given value. This is similar
  /// to `ValueBoundsConstraintSet::computeConstantBound()`, but uses knowledge
  /// of the range of vscale to compute either a constant bound, an expression
  /// in terms of vscale, or failure if no bound can be computed.
  ///
  /// NOTE(review): `dim`, `boundType`, `closedUB` and `stopCondition` are
  /// presumably interpreted as in `computeConstantBound()` — confirm.
  ///
  /// The resulting `AffineMap` will always take at most one parameter,
  /// vscale, and return a single result, which is the bound of `value`.
  ///
  /// Note: `vscaleMin` must be less than or equal to `vscaleMax`. If
  /// `vscaleMin` == `vscaleMax`, the resulting bound (if found) will be
  /// constant.
  static FailureOr<ConstantOrScalableBound>
  computeScalableBound(Value value, std::optional<int64_t> dim,
                       unsigned vscaleMin, unsigned vscaleMax,
                       presburger::BoundType boundType, bool closedUB = true,
                       StopConditionFn stopCondition = nullptr);

  /// Get the value of vscale. Returns a null `Value` if vscale has not been
  /// encountered (i.e. `setVscale()` has not been called) yet.
  Value getVscaleValue() const { return vscale; }

  /// Sets the value of vscale. Asserts if vscale has already been set.
  void setVscale(vector::VectorScaleOp vscaleOp) {
    assert(!vscale && "expected vscale to be unset");
    vscale = vscaleOp.getResult();
  }

  /// The minimum possible value of vscale.
  unsigned getVscaleMin() const { return vscaleMin; }

  /// The maximum possible value of vscale.
  unsigned getVscaleMax() const { return vscaleMax; }

  static char ID;

private:
  /// Inclusive range of values vscale may take, fixed at construction.
  const unsigned vscaleMin;
  const unsigned vscaleMax;

  // This will be set when the first `vector.vscale` operation is found within
  // the `ValueBoundsOpInterface` implementation then reused from there on.
  Value vscale = nullptr;
};

/// Shorthand for `ScalableValueBoundsConstraintSet::ConstantOrScalableBound`.
using ConstantOrScalableBound =
    ScalableValueBoundsConstraintSet::ConstantOrScalableBound;

} // namespace mlir::vector

#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H