1//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Utils/StaticValueUtils.h"
11#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
12#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
13#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15
16using namespace mlir;
17using namespace mlir::vector;
18namespace {
19
20/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
21/// All-true masks can then be eliminated by simple folds.
22LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
23 vector::CreateMaskOp createMaskOp,
24 VscaleRange vscaleRange) {
25 auto maskType = createMaskOp.getVectorType();
26 auto maskTypeDimScalableFlags = maskType.getScalableDims();
27 auto maskTypeDimSizes = maskType.getShape();
28
29 struct UnknownMaskDim {
30 size_t position;
31 Value dimSize;
32 };
33
34 // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
35 // that are not obviously constant). If any constant dimension is not all-true
36 // bail out early (as this transform only trying to resolve all-true masks).
37 // This avoids doing value-bounds anaylis in cases like:
38 // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
39 // ...where it is known the mask is not all-true by looking at `%c2`.
40 SmallVector<UnknownMaskDim> unknownDims;
41 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
42 if (auto intSize = getConstantIntValue(dimSize)) {
43 // Mask not all-true for this dim.
44 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
45 return failure();
46 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
47 // Mask not all-true for this dim.
48 if (vscaleMultiplier < maskTypeDimSizes[i])
49 return failure();
50 } else {
51 // Unknown (without further analysis).
52 unknownDims.push_back(UnknownMaskDim{i, dimSize});
53 }
54 }
55
56 for (auto [i, dimSize] : unknownDims) {
57 // Compute the lower bound for the unknown dimension (i.e. the smallest
58 // value it could be).
59 FailureOr<ConstantOrScalableBound> dimLowerBound =
60 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
61 dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
62 presburger::BoundType::LB);
63 if (failed(Result: dimLowerBound))
64 return failure();
65 auto dimLowerBoundSize = dimLowerBound->getSize();
66 if (failed(dimLowerBoundSize))
67 return failure();
68 if (dimLowerBoundSize->scalable) {
69 // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
70 // this dim is not all-true.
71 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
72 return failure();
73 } else {
74 // 2. The lower bound, LB, is a constant.
75 // - If the mask dim size is scalable then this dim is not all-true.
76 if (maskTypeDimScalableFlags[i])
77 return failure();
78 // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
79 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
80 return failure();
81 }
82 }
83
84 // Replace createMaskOp with an all-true constant. This should result in the
85 // mask being removed in most cases (as xfer ops + vector.mask have folds to
86 // remove all-true masks).
87 auto allTrue = rewriter.create<vector::ConstantMaskOp>(
88 createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
89 rewriter.replaceAllUsesWith(createMaskOp, allTrue);
90 return success();
91}
92
93} // namespace
94
95namespace mlir::vector {
96
97void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
98 std::optional<VscaleRange> vscaleRange) {
99 // TODO: Support fixed-size case. This is less likely to be useful as for
100 // fixed-size code dimensions are all static so masks tend to fold away.
101 if (!vscaleRange)
102 return;
103
104 OpBuilder::InsertionGuard g(rewriter);
105
106 // Build worklist so we can safely insert new ops in
107 // `resolveAllTrueCreateMaskOp()`.
108 SmallVector<vector::CreateMaskOp> worklist;
109 function.walk([&](vector::CreateMaskOp createMaskOp) {
110 worklist.push_back(createMaskOp);
111 });
112
113 rewriter.setInsertionPointToStart(&function.front());
114 for (auto mask : worklist)
115 (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
116}
117
118} // namespace mlir::vector
119

source code of mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp