| 1 | //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" |
| 10 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 11 | |
| 12 | namespace mlir::vector { |
| 13 | |
| 14 | FailureOr<ConstantOrScalableBound::BoundSize> |
| 15 | ConstantOrScalableBound::getSize() const { |
| 16 | if (map.isSingleConstant()) |
| 17 | return BoundSize{.baseSize: map.getSingleConstantResult(), /*scalable=*/false}; |
| 18 | if (map.getNumResults() != 1 || map.getNumInputs() != 1) |
| 19 | return failure(); |
| 20 | auto binop = dyn_cast<AffineBinaryOpExpr>(Val: map.getResult(idx: 0)); |
| 21 | if (!binop || binop.getKind() != AffineExprKind::Mul) |
| 22 | return failure(); |
| 23 | auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { |
| 24 | if (auto cst = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
| 25 | constant = cst.getValue(); |
| 26 | return true; |
| 27 | } |
| 28 | return false; |
| 29 | }; |
| 30 | // Match `s0 * cst` or `cst * s0`: |
| 31 | int64_t cst = 0; |
| 32 | auto lhs = binop.getLHS(); |
| 33 | auto rhs = binop.getRHS(); |
| 34 | if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(Val: rhs)) || |
| 35 | (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(Val: lhs))) { |
| 36 | return BoundSize{.baseSize: cst, /*scalable=*/true}; |
| 37 | } |
| 38 | return failure(); |
| 39 | } |
| 40 | |
| 41 | char ScalableValueBoundsConstraintSet::ID = 0; |
| 42 | |
| 43 | FailureOr<ConstantOrScalableBound> |
| 44 | ScalableValueBoundsConstraintSet::computeScalableBound( |
| 45 | Value value, std::optional<int64_t> dim, unsigned vscaleMin, |
| 46 | unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, |
| 47 | StopConditionFn stopCondition) { |
| 48 | using namespace presburger; |
| 49 | assert(vscaleMin <= vscaleMax); |
| 50 | |
| 51 | // No stop condition specified: Keep adding constraints until the worklist |
| 52 | // is empty. |
| 53 | auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, |
| 54 | mlir::ValueBoundsConstraintSet &cstr) { |
| 55 | return false; |
| 56 | }; |
| 57 | |
| 58 | ScalableValueBoundsConstraintSet scalableCstr( |
| 59 | value.getContext(), stopCondition ? stopCondition : defaultStopCondition, |
| 60 | vscaleMin, vscaleMax); |
| 61 | int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); |
| 62 | scalableCstr.processWorklist(); |
| 63 | |
| 64 | // Check the resulting constraints set is valid. |
| 65 | if (scalableCstr.cstr.isEmpty()) { |
| 66 | return failure(); |
| 67 | } |
| 68 | |
| 69 | // Project out all columns apart from vscale and the starting point |
| 70 | // (value/dim). This should result in constraints in terms of vscale only. |
| 71 | auto projectOutFn = [&](ValueDim p) { |
| 72 | bool isStartingPoint = |
| 73 | p.first == value && |
| 74 | p.second == dim.value_or(u: ValueBoundsConstraintSet::kIndexValue); |
| 75 | return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; |
| 76 | }; |
| 77 | scalableCstr.projectOut(projectOutFn); |
| 78 | scalableCstr.projectOutAnonymous(/*except=*/pos); |
| 79 | // Also project out local variables (these are not tracked by the |
| 80 | // ValueBoundsConstraintSet). |
| 81 | for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) { |
| 82 | scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars()); |
| 83 | } |
| 84 | |
| 85 | assert(scalableCstr.cstr.getNumDimAndSymbolVars() == |
| 86 | scalableCstr.positionToValueDim.size() && |
| 87 | "inconsistent mapping state" ); |
| 88 | |
| 89 | // Check that the only columns left are vscale and the starting point. |
| 90 | for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { |
| 91 | if (i == pos) |
| 92 | continue; |
| 93 | if (scalableCstr.positionToValueDim[i] != |
| 94 | ValueDim(scalableCstr.getVscaleValue(), |
| 95 | ValueBoundsConstraintSet::kIndexValue)) { |
| 96 | return failure(); |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); |
| 101 | scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, |
| 102 | &upperBound, closedUB); |
| 103 | |
| 104 | auto invalidBound = [](auto &bound) { |
| 105 | return !bound[0] || bound[0].getNumResults() != 1; |
| 106 | }; |
| 107 | |
| 108 | AffineMap bound = [&] { |
| 109 | if (boundType == BoundType::EQ && !invalidBound(lowerBound) && |
| 110 | lowerBound[0] == upperBound[0]) { |
| 111 | return lowerBound[0]; |
| 112 | } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { |
| 113 | return lowerBound[0]; |
| 114 | } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { |
| 115 | return upperBound[0]; |
| 116 | } |
| 117 | return AffineMap{}; |
| 118 | }(); |
| 119 | |
| 120 | if (!bound) |
| 121 | return failure(); |
| 122 | |
| 123 | return ConstantOrScalableBound{.map: bound}; |
| 124 | } |
| 125 | |
| 126 | } // namespace mlir::vector |
| 127 | |