1 | //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" |
10 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
11 | |
12 | namespace mlir::vector { |
13 | |
14 | FailureOr<ConstantOrScalableBound::BoundSize> |
15 | ConstantOrScalableBound::getSize() const { |
16 | if (map.isSingleConstant()) |
17 | return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false}; |
18 | if (map.getNumResults() != 1 || map.getNumInputs() != 1) |
19 | return failure(); |
20 | auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0)); |
21 | if (!binop || binop.getKind() != AffineExprKind::Mul) |
22 | return failure(); |
23 | auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { |
24 | if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
25 | constant = cst.getValue(); |
26 | return true; |
27 | } |
28 | return false; |
29 | }; |
30 | // Match `s0 * cst` or `cst * s0`: |
31 | int64_t cst = 0; |
32 | auto lhs = binop.getLHS(); |
33 | auto rhs = binop.getRHS(); |
34 | if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) || |
35 | (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) { |
36 | return BoundSize{.baseSize: cst, /*scalable=*/true}; |
37 | } |
38 | return failure(); |
39 | } |
40 | |
41 | char ScalableValueBoundsConstraintSet::ID = 0; |
42 | |
43 | FailureOr<ConstantOrScalableBound> |
44 | ScalableValueBoundsConstraintSet::computeScalableBound( |
45 | Value value, std::optional<int64_t> dim, unsigned vscaleMin, |
46 | unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, |
47 | StopConditionFn stopCondition) { |
48 | using namespace presburger; |
49 | assert(vscaleMin <= vscaleMax); |
50 | |
51 | // No stop condition specified: Keep adding constraints until the worklist |
52 | // is empty. |
53 | auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, |
54 | mlir::ValueBoundsConstraintSet &cstr) { |
55 | return false; |
56 | }; |
57 | |
58 | ScalableValueBoundsConstraintSet scalableCstr( |
59 | value.getContext(), stopCondition ? stopCondition : defaultStopCondition, |
60 | vscaleMin, vscaleMax); |
61 | int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); |
62 | scalableCstr.processWorklist(); |
63 | |
64 | // Check the resulting constraints set is valid. |
65 | if (scalableCstr.cstr.isEmpty()) { |
66 | return failure(); |
67 | } |
68 | |
69 | // Project out all columns apart from vscale and the starting point |
70 | // (value/dim). This should result in constraints in terms of vscale only. |
71 | auto projectOutFn = [&](ValueDim p) { |
72 | bool isStartingPoint = |
73 | p.first == value && |
74 | p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue); |
75 | return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; |
76 | }; |
77 | scalableCstr.projectOut(projectOutFn); |
78 | scalableCstr.projectOutAnonymous(/*except=*/pos); |
79 | // Also project out local variables (these are not tracked by the |
80 | // ValueBoundsConstraintSet). |
81 | for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) { |
82 | scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars()); |
83 | } |
84 | |
85 | assert(scalableCstr.cstr.getNumDimAndSymbolVars() == |
86 | scalableCstr.positionToValueDim.size() && |
87 | "inconsistent mapping state" ); |
88 | |
89 | // Check that the only columns left are vscale and the starting point. |
90 | for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { |
91 | if (i == pos) |
92 | continue; |
93 | if (scalableCstr.positionToValueDim[i] != |
94 | ValueDim(scalableCstr.getVscaleValue(), |
95 | ValueBoundsConstraintSet::kIndexValue)) { |
96 | return failure(); |
97 | } |
98 | } |
99 | |
100 | SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); |
101 | scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, |
102 | &upperBound, closedUB); |
103 | |
104 | auto invalidBound = [](auto &bound) { |
105 | return !bound[0] || bound[0].getNumResults() != 1; |
106 | }; |
107 | |
108 | AffineMap bound = [&] { |
109 | if (boundType == BoundType::EQ && !invalidBound(lowerBound) && |
110 | lowerBound[0] == upperBound[0]) { |
111 | return lowerBound[0]; |
112 | } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { |
113 | return lowerBound[0]; |
114 | } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { |
115 | return upperBound[0]; |
116 | } |
117 | return AffineMap{}; |
118 | }(); |
119 | |
120 | if (!bound) |
121 | return failure(); |
122 | |
123 | return ConstantOrScalableBound{.map: bound}; |
124 | } |
125 | |
126 | } // namespace mlir::vector |
127 | |