1//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/StaticValueUtils.h"
10#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
11#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
12#include "mlir/Interfaces/FunctionInterfaces.h"
13
14using namespace mlir;
15using namespace mlir::vector;
16namespace {
17
18/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
19/// All-true masks can then be eliminated by simple folds.
20LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
21 vector::CreateMaskOp createMaskOp,
22 VscaleRange vscaleRange) {
23 auto maskType = createMaskOp.getVectorType();
24 auto maskTypeDimScalableFlags = maskType.getScalableDims();
25 auto maskTypeDimSizes = maskType.getShape();
26
27 struct UnknownMaskDim {
28 size_t position;
29 Value dimSize;
30 };
31
32 // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
33 // that are not obviously constant). If any constant dimension is not all-true
34 // bail out early (as this transform only trying to resolve all-true masks).
35 // This avoids doing value-bounds anaylis in cases like:
36 // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
37 // ...where it is known the mask is not all-true by looking at `%c2`.
38 SmallVector<UnknownMaskDim> unknownDims;
39 for (auto [i, dimSize] : llvm::enumerate(First: createMaskOp.getOperands())) {
40 if (auto intSize = getConstantIntValue(ofr: dimSize)) {
41 // Mask not all-true for this dim.
42 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
43 return failure();
44 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(value: dimSize)) {
45 // Mask not all-true for this dim.
46 if (vscaleMultiplier < maskTypeDimSizes[i])
47 return failure();
48 } else {
49 // Unknown (without further analysis).
50 unknownDims.push_back(Elt: UnknownMaskDim{.position: i, .dimSize: dimSize});
51 }
52 }
53
54 for (auto [i, dimSize] : unknownDims) {
55 // Compute the lower bound for the unknown dimension (i.e. the smallest
56 // value it could be).
57 FailureOr<ConstantOrScalableBound> dimLowerBound =
58 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
59 value: dimSize, dim: {}, vscaleMin: vscaleRange.vscaleMin, vscaleMax: vscaleRange.vscaleMax,
60 boundType: presburger::BoundType::LB);
61 if (failed(Result: dimLowerBound))
62 return failure();
63 auto dimLowerBoundSize = dimLowerBound->getSize();
64 if (failed(Result: dimLowerBoundSize))
65 return failure();
66 if (dimLowerBoundSize->scalable) {
67 // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
68 // this dim is not all-true.
69 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
70 return failure();
71 } else {
72 // 2. The lower bound, LB, is a constant.
73 // - If the mask dim size is scalable then this dim is not all-true.
74 if (maskTypeDimScalableFlags[i])
75 return failure();
76 // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
77 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
78 return failure();
79 }
80 }
81
82 // Replace createMaskOp with an all-true constant. This should result in the
83 // mask being removed in most cases (as xfer ops + vector.mask have folds to
84 // remove all-true masks).
85 auto allTrue = rewriter.create<vector::ConstantMaskOp>(
86 location: createMaskOp.getLoc(), args&: maskType, args: ConstantMaskKind::AllTrue);
87 rewriter.replaceAllUsesWith(from: createMaskOp, to: allTrue);
88 return success();
89}
90
91} // namespace
92
93namespace mlir::vector {
94
95void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
96 std::optional<VscaleRange> vscaleRange) {
97 // TODO: Support fixed-size case. This is less likely to be useful as for
98 // fixed-size code dimensions are all static so masks tend to fold away.
99 if (!vscaleRange)
100 return;
101
102 OpBuilder::InsertionGuard g(rewriter);
103
104 // Build worklist so we can safely insert new ops in
105 // `resolveAllTrueCreateMaskOp()`.
106 SmallVector<vector::CreateMaskOp> worklist;
107 function.walk(callback: [&](vector::CreateMaskOp createMaskOp) {
108 worklist.push_back(Elt: createMaskOp);
109 });
110
111 rewriter.setInsertionPointToStart(&function.front());
112 for (auto mask : worklist)
113 (void)resolveAllTrueCreateMaskOp(rewriter, createMaskOp: mask, vscaleRange: *vscaleRange);
114}
115
116} // namespace mlir::vector
117

source code of mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp