1//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10
11namespace mlir::vector {
12
13FailureOr<ConstantOrScalableBound::BoundSize>
14ConstantOrScalableBound::getSize() const {
15 if (map.isSingleConstant())
16 return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false};
17 if (map.getNumResults() != 1 || map.getNumInputs() != 1)
18 return failure();
19 auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0));
20 if (!binop || binop.getKind() != AffineExprKind::Mul)
21 return failure();
22 auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
23 if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) {
24 constant = cst.getValue();
25 return true;
26 }
27 return false;
28 };
29 // Match `s0 * cst` or `cst * s0`:
30 int64_t cst = 0;
31 auto lhs = binop.getLHS();
32 auto rhs = binop.getRHS();
33 if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) ||
34 (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) {
35 return BoundSize{.baseSize: cst, /*scalable=*/true};
36 }
37 return failure();
38}
39
40char ScalableValueBoundsConstraintSet::ID = 0;
41
42FailureOr<ConstantOrScalableBound>
43ScalableValueBoundsConstraintSet::computeScalableBound(
44 Value value, std::optional<int64_t> dim, unsigned vscaleMin,
45 unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
46 StopConditionFn stopCondition) {
47 using namespace presburger;
48 assert(vscaleMin <= vscaleMax);
49
50 // No stop condition specified: Keep adding constraints until the worklist
51 // is empty.
52 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
53 mlir::ValueBoundsConstraintSet &cstr) {
54 return false;
55 };
56
57 ScalableValueBoundsConstraintSet scalableCstr(
58 value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
59 vscaleMin, vscaleMax);
60 int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
61 scalableCstr.processWorklist();
62
63 // Check the resulting constraints set is valid.
64 if (scalableCstr.cstr.isEmpty()) {
65 return failure();
66 }
67
68 // Project out all columns apart from vscale and the starting point
69 // (value/dim). This should result in constraints in terms of vscale only.
70 auto projectOutFn = [&](ValueDim p) {
71 bool isStartingPoint =
72 p.first == value &&
73 p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue);
74 return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
75 };
76 scalableCstr.projectOut(condition: projectOutFn);
77 scalableCstr.projectOutAnonymous(/*except=*/pos);
78 // Also project out local variables (these are not tracked by the
79 // ValueBoundsConstraintSet).
80 for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
81 scalableCstr.cstr.projectOut(pos: scalableCstr.cstr.getNumDimAndSymbolVars());
82 }
83
84 assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
85 scalableCstr.positionToValueDim.size() &&
86 "inconsistent mapping state");
87
88 // Check that the only columns left are vscale and the starting point.
89 for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
90 if (i == pos)
91 continue;
92 if (scalableCstr.positionToValueDim[i] !=
93 ValueDim(scalableCstr.getVscaleValue(),
94 ValueBoundsConstraintSet::kIndexValue)) {
95 return failure();
96 }
97 }
98
99 SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
100 scalableCstr.cstr.getSliceBounds(offset: pos, num: 1, context: value.getContext(), lbMaps: &lowerBound,
101 ubMaps: &upperBound, closedUB);
102
103 auto invalidBound = [](auto &bound) {
104 return !bound[0] || bound[0].getNumResults() != 1;
105 };
106
107 AffineMap bound = [&] {
108 if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
109 lowerBound[0] == upperBound[0]) {
110 return lowerBound[0];
111 } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
112 return lowerBound[0];
113 } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
114 return upperBound[0];
115 }
116 return AffineMap{};
117 }();
118
119 if (!bound)
120 return failure();
121
122 return ConstantOrScalableBound{.map: bound};
123}
124
125} // namespace mlir::vector
126

source code of mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp