1 | //===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
11 | #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" |
12 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
13 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
14 | #include "mlir/Interfaces/FunctionInterfaces.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::vector; |
18 | namespace { |
19 | |
20 | /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask. |
21 | /// All-true masks can then be eliminated by simple folds. |
22 | LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter, |
23 | vector::CreateMaskOp createMaskOp, |
24 | VscaleRange vscaleRange) { |
25 | auto maskType = createMaskOp.getVectorType(); |
26 | auto maskTypeDimScalableFlags = maskType.getScalableDims(); |
27 | auto maskTypeDimSizes = maskType.getShape(); |
28 | |
29 | struct UnknownMaskDim { |
30 | size_t position; |
31 | Value dimSize; |
32 | }; |
33 | |
34 | // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims |
35 | // that are not obviously constant). If any constant dimension is not all-true |
36 | // bail out early (as this transform only trying to resolve all-true masks). |
37 | // This avoids doing value-bounds anaylis in cases like: |
38 | // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>` |
39 | // ...where it is known the mask is not all-true by looking at `%c2`. |
40 | SmallVector<UnknownMaskDim> unknownDims; |
41 | for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) { |
42 | if (auto intSize = getConstantIntValue(dimSize)) { |
43 | // Mask not all-true for this dim. |
44 | if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i]) |
45 | return failure(); |
46 | } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) { |
47 | // Mask not all-true for this dim. |
48 | if (vscaleMultiplier < maskTypeDimSizes[i]) |
49 | return failure(); |
50 | } else { |
51 | // Unknown (without further analysis). |
52 | unknownDims.push_back(UnknownMaskDim{i, dimSize}); |
53 | } |
54 | } |
55 | |
56 | for (auto [i, dimSize] : unknownDims) { |
57 | // Compute the lower bound for the unknown dimension (i.e. the smallest |
58 | // value it could be). |
59 | FailureOr<ConstantOrScalableBound> dimLowerBound = |
60 | vector::ScalableValueBoundsConstraintSet::computeScalableBound( |
61 | dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax, |
62 | presburger::BoundType::LB); |
63 | if (failed(Result: dimLowerBound)) |
64 | return failure(); |
65 | auto dimLowerBoundSize = dimLowerBound->getSize(); |
66 | if (failed(dimLowerBoundSize)) |
67 | return failure(); |
68 | if (dimLowerBoundSize->scalable) { |
69 | // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then |
70 | // this dim is not all-true. |
71 | if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) |
72 | return failure(); |
73 | } else { |
74 | // 2. The lower bound, LB, is a constant. |
75 | // - If the mask dim size is scalable then this dim is not all-true. |
76 | if (maskTypeDimScalableFlags[i]) |
77 | return failure(); |
78 | // - If LB < the _fixed-size_ mask dim size then this dim is not all-true. |
79 | if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) |
80 | return failure(); |
81 | } |
82 | } |
83 | |
84 | // Replace createMaskOp with an all-true constant. This should result in the |
85 | // mask being removed in most cases (as xfer ops + vector.mask have folds to |
86 | // remove all-true masks). |
87 | auto allTrue = rewriter.create<vector::ConstantMaskOp>( |
88 | createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue); |
89 | rewriter.replaceAllUsesWith(createMaskOp, allTrue); |
90 | return success(); |
91 | } |
92 | |
93 | } // namespace |
94 | |
95 | namespace mlir::vector { |
96 | |
97 | void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, |
98 | std::optional<VscaleRange> vscaleRange) { |
99 | // TODO: Support fixed-size case. This is less likely to be useful as for |
100 | // fixed-size code dimensions are all static so masks tend to fold away. |
101 | if (!vscaleRange) |
102 | return; |
103 | |
104 | OpBuilder::InsertionGuard g(rewriter); |
105 | |
106 | // Build worklist so we can safely insert new ops in |
107 | // `resolveAllTrueCreateMaskOp()`. |
108 | SmallVector<vector::CreateMaskOp> worklist; |
109 | function.walk([&](vector::CreateMaskOp createMaskOp) { |
110 | worklist.push_back(createMaskOp); |
111 | }); |
112 | |
113 | rewriter.setInsertionPointToStart(&function.front()); |
114 | for (auto mask : worklist) |
115 | (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange); |
116 | } |
117 | |
118 | } // namespace mlir::vector |
119 | |