1 | //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" |
10 | |
11 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
12 | |
13 | namespace mlir::vector { |
14 | |
15 | FailureOr<ConstantOrScalableBound::BoundSize> |
16 | ConstantOrScalableBound::getSize() const { |
17 | if (map.isSingleConstant()) |
18 | return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false}; |
19 | if (map.getNumResults() != 1 || map.getNumInputs() != 1) |
20 | return failure(); |
21 | auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0)); |
22 | if (!binop || binop.getKind() != AffineExprKind::Mul) |
23 | return failure(); |
24 | auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { |
25 | if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
26 | constant = cst.getValue(); |
27 | return true; |
28 | } |
29 | return false; |
30 | }; |
31 | // Match `s0 * cst` or `cst * s0`: |
32 | int64_t cst = 0; |
33 | auto lhs = binop.getLHS(); |
34 | auto rhs = binop.getRHS(); |
35 | if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) || |
36 | (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) { |
37 | return BoundSize{.baseSize: cst, /*scalable=*/true}; |
38 | } |
39 | return failure(); |
40 | } |
41 | |
42 | char ScalableValueBoundsConstraintSet::ID = 0; |
43 | |
44 | FailureOr<ConstantOrScalableBound> |
45 | ScalableValueBoundsConstraintSet::computeScalableBound( |
46 | Value value, std::optional<int64_t> dim, unsigned vscaleMin, |
47 | unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, |
48 | StopConditionFn stopCondition) { |
49 | using namespace presburger; |
50 | assert(vscaleMin <= vscaleMax); |
51 | |
52 | // No stop condition specified: Keep adding constraints until the worklist |
53 | // is empty. |
54 | auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, |
55 | mlir::ValueBoundsConstraintSet &cstr) { |
56 | return false; |
57 | }; |
58 | |
59 | ScalableValueBoundsConstraintSet scalableCstr( |
60 | value.getContext(), stopCondition ? stopCondition : defaultStopCondition, |
61 | vscaleMin, vscaleMax); |
62 | int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); |
63 | scalableCstr.processWorklist(); |
64 | |
65 | // Project out all columns apart from vscale and the starting point |
66 | // (value/dim). This should result in constraints in terms of vscale only. |
67 | auto projectOutFn = [&](ValueDim p) { |
68 | bool isStartingPoint = |
69 | p.first == value && |
70 | p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue); |
71 | return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; |
72 | }; |
73 | scalableCstr.projectOut(projectOutFn); |
74 | |
75 | assert(scalableCstr.cstr.getNumDimAndSymbolVars() == |
76 | scalableCstr.positionToValueDim.size() && |
77 | "inconsistent mapping state" ); |
78 | |
79 | // Check that the only columns left are vscale and the starting point. |
80 | for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { |
81 | if (i == pos) |
82 | continue; |
83 | if (scalableCstr.positionToValueDim[i] != |
84 | ValueDim(scalableCstr.getVscaleValue(), |
85 | ValueBoundsConstraintSet::kIndexValue)) { |
86 | return failure(); |
87 | } |
88 | } |
89 | |
90 | SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); |
91 | scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, |
92 | &upperBound, closedUB); |
93 | |
94 | auto invalidBound = [](auto &bound) { |
95 | return !bound[0] || bound[0].getNumResults() != 1; |
96 | }; |
97 | |
98 | AffineMap bound = [&] { |
99 | if (boundType == BoundType::EQ && !invalidBound(lowerBound) && |
100 | lowerBound[0] == lowerBound[0]) { |
101 | return lowerBound[0]; |
102 | } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { |
103 | return lowerBound[0]; |
104 | } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { |
105 | return upperBound[0]; |
106 | } |
107 | return AffineMap{}; |
108 | }(); |
109 | |
110 | if (!bound) |
111 | return failure(); |
112 | |
113 | return ConstantOrScalableBound{.map: bound}; |
114 | } |
115 | |
116 | } // namespace mlir::vector |
117 | |