1//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10
11#include "mlir/Dialect/Vector/IR/VectorOps.h"
12
13namespace mlir::vector {
14
15FailureOr<ConstantOrScalableBound::BoundSize>
16ConstantOrScalableBound::getSize() const {
17 if (map.isSingleConstant())
18 return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false};
19 if (map.getNumResults() != 1 || map.getNumInputs() != 1)
20 return failure();
21 auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0));
22 if (!binop || binop.getKind() != AffineExprKind::Mul)
23 return failure();
24 auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
25 if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) {
26 constant = cst.getValue();
27 return true;
28 }
29 return false;
30 };
31 // Match `s0 * cst` or `cst * s0`:
32 int64_t cst = 0;
33 auto lhs = binop.getLHS();
34 auto rhs = binop.getRHS();
35 if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) ||
36 (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) {
37 return BoundSize{.baseSize: cst, /*scalable=*/true};
38 }
39 return failure();
40}
41
42char ScalableValueBoundsConstraintSet::ID = 0;
43
44FailureOr<ConstantOrScalableBound>
45ScalableValueBoundsConstraintSet::computeScalableBound(
46 Value value, std::optional<int64_t> dim, unsigned vscaleMin,
47 unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
48 StopConditionFn stopCondition) {
49 using namespace presburger;
50 assert(vscaleMin <= vscaleMax);
51
52 // No stop condition specified: Keep adding constraints until the worklist
53 // is empty.
54 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
55 mlir::ValueBoundsConstraintSet &cstr) {
56 return false;
57 };
58
59 ScalableValueBoundsConstraintSet scalableCstr(
60 value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
61 vscaleMin, vscaleMax);
62 int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
63 scalableCstr.processWorklist();
64
65 // Project out all columns apart from vscale and the starting point
66 // (value/dim). This should result in constraints in terms of vscale only.
67 auto projectOutFn = [&](ValueDim p) {
68 bool isStartingPoint =
69 p.first == value &&
70 p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue);
71 return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
72 };
73 scalableCstr.projectOut(projectOutFn);
74
75 assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
76 scalableCstr.positionToValueDim.size() &&
77 "inconsistent mapping state");
78
79 // Check that the only columns left are vscale and the starting point.
80 for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
81 if (i == pos)
82 continue;
83 if (scalableCstr.positionToValueDim[i] !=
84 ValueDim(scalableCstr.getVscaleValue(),
85 ValueBoundsConstraintSet::kIndexValue)) {
86 return failure();
87 }
88 }
89
90 SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
91 scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
92 &upperBound, closedUB);
93
94 auto invalidBound = [](auto &bound) {
95 return !bound[0] || bound[0].getNumResults() != 1;
96 };
97
98 AffineMap bound = [&] {
99 if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
100 lowerBound[0] == lowerBound[0]) {
101 return lowerBound[0];
102 } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
103 return lowerBound[0];
104 } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
105 return upperBound[0];
106 }
107 return AffineMap{};
108 }();
109
110 if (!bound)
111 return failure();
112
113 return ConstantOrScalableBound{.map: bound};
114}
115
116} // namespace mlir::vector
117

// Source code of mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp