//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10#include "mlir/Dialect/Vector/IR/VectorOps.h"
11
12namespace mlir::vector {
13
14FailureOr<ConstantOrScalableBound::BoundSize>
15ConstantOrScalableBound::getSize() const {
16 if (map.isSingleConstant())
17 return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false};
18 if (map.getNumResults() != 1 || map.getNumInputs() != 1)
19 return failure();
20 auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0));
21 if (!binop || binop.getKind() != AffineExprKind::Mul)
22 return failure();
23 auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
24 if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) {
25 constant = cst.getValue();
26 return true;
27 }
28 return false;
29 };
30 // Match `s0 * cst` or `cst * s0`:
31 int64_t cst = 0;
32 auto lhs = binop.getLHS();
33 auto rhs = binop.getRHS();
34 if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) ||
35 (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) {
36 return BoundSize{.baseSize: cst, /*scalable=*/true};
37 }
38 return failure();
39}
40
41char ScalableValueBoundsConstraintSet::ID = 0;
42
43FailureOr<ConstantOrScalableBound>
44ScalableValueBoundsConstraintSet::computeScalableBound(
45 Value value, std::optional<int64_t> dim, unsigned vscaleMin,
46 unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
47 StopConditionFn stopCondition) {
48 using namespace presburger;
49 assert(vscaleMin <= vscaleMax);
50
51 // No stop condition specified: Keep adding constraints until the worklist
52 // is empty.
53 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
54 mlir::ValueBoundsConstraintSet &cstr) {
55 return false;
56 };
57
58 ScalableValueBoundsConstraintSet scalableCstr(
59 value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
60 vscaleMin, vscaleMax);
61 int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
62 scalableCstr.processWorklist();
63
64 // Check the resulting constraints set is valid.
65 if (scalableCstr.cstr.isEmpty()) {
66 return failure();
67 }
68
69 // Project out all columns apart from vscale and the starting point
70 // (value/dim). This should result in constraints in terms of vscale only.
71 auto projectOutFn = [&](ValueDim p) {
72 bool isStartingPoint =
73 p.first == value &&
74 p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue);
75 return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
76 };
77 scalableCstr.projectOut(projectOutFn);
78 scalableCstr.projectOutAnonymous(/*except=*/pos);
79 // Also project out local variables (these are not tracked by the
80 // ValueBoundsConstraintSet).
81 for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
82 scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars());
83 }
84
85 assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
86 scalableCstr.positionToValueDim.size() &&
87 "inconsistent mapping state");
88
89 // Check that the only columns left are vscale and the starting point.
90 for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
91 if (i == pos)
92 continue;
93 if (scalableCstr.positionToValueDim[i] !=
94 ValueDim(scalableCstr.getVscaleValue(),
95 ValueBoundsConstraintSet::kIndexValue)) {
96 return failure();
97 }
98 }
99
100 SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
101 scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
102 &upperBound, closedUB);
103
104 auto invalidBound = [](auto &bound) {
105 return !bound[0] || bound[0].getNumResults() != 1;
106 };
107
108 AffineMap bound = [&] {
109 if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
110 lowerBound[0] == upperBound[0]) {
111 return lowerBound[0];
112 } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
113 return lowerBound[0];
114 } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
115 return upperBound[0];
116 }
117 return AffineMap{};
118 }();
119
120 if (!bound)
121 return failure();
122
123 return ConstantOrScalableBound{.map: bound};
124}
125
126} // namespace mlir::vector
// Source: mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp