1//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.mask' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
18#include "mlir/Dialect/Vector/Transforms/Passes.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22#define DEBUG_TYPE "lower-vector-mask"
23
24namespace mlir {
25namespace vector {
26#define GEN_PASS_DEF_LOWERVECTORMASKPASS
27#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28} // namespace vector
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::vector;
33
34//===----------------------------------------------------------------------===//
35// populateVectorMaskOpLoweringPatterns
36//===----------------------------------------------------------------------===//
37
38namespace {
39/// Progressive lowering of CreateMaskOp.
40/// One:
41/// %x = vector.create_mask %a, ... : vector<dx...>
42/// is replaced by:
43/// %l = vector.create_mask ... : vector<...> ; one lower rank
44/// %0 = arith.cmpi "slt", %ci, %a |
45/// %1 = select %0, %l, %zeroes |
46/// %r = vector.insert %1, %pr [i] | d-times
47/// %x = ....
48/// until a one-dimensional vector is reached.
49class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50public:
51 using OpRewritePattern::OpRewritePattern;
52
53 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54 PatternRewriter &rewriter) const override {
55 auto dstType = cast<VectorType>(op.getResult().getType());
56 int64_t rank = dstType.getRank();
57 if (rank <= 1)
58 return rewriter.notifyMatchFailure(
59 op, "0-D and 1-D vectors are handled separately");
60
61 if (dstType.getScalableDims().front())
62 return rewriter.notifyMatchFailure(
63 op, "Cannot unroll leading scalable dim in dstType");
64
65 auto loc = op.getLoc();
66 int64_t dim = dstType.getDimSize(0);
67 Value idx = op.getOperand(0);
68
69 VectorType lowType = VectorType::Builder(dstType).dropDim(0);
70 Value trueVal = rewriter.create<vector::CreateMaskOp>(
71 loc, lowType, op.getOperands().drop_front());
72 Value falseVal = rewriter.create<arith::ConstantOp>(
73 loc, lowType, rewriter.getZeroAttr(lowType));
74 Value result = rewriter.create<arith::ConstantOp>(
75 loc, dstType, rewriter.getZeroAttr(dstType));
76 for (int64_t d = 0; d < dim; d++) {
77 Value bnd =
78 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
79 Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
80 bnd, idx);
81 Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
82 result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
83 }
84 rewriter.replaceOp(op, result);
85 return success();
86 }
87};
88
89/// Progressive lowering of ConstantMaskOp.
90/// One:
91/// %x = vector.constant_mask [a,b]
92/// is replaced by:
93/// %z = zero-result
94/// %l = vector.constant_mask [b]
95/// %4 = vector.insert %l, %z[0]
96/// ..
97/// %x = vector.insert %l, %..[a-1]
98/// until a one-dimensional vector is reached. All these operations
99/// will be folded at LLVM IR level.
/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    // 0-D case: the mask is a single i1; materialize it directly as a 0-D
    // dense constant (true iff the sole mask dim size is 1).
    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
                                    value));
      return success();
    }

    // Number of leading "true" positions along dimension 0.
    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();

    if (rank == 1) {
      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
        // Use constant splat for 'all set' or 'none set' dims.
        // This produces correct code for scalable dimensions (it will lower to
        // a constant splat).
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
      } else {
        // Express constant 1-D case in explicit vector form:
        //   [T,..,T,F,..,F].
        // Note: The verifier would reject this case for scalable vectors.
        SmallVector<bool> values(dstType.getDimSize(0), false);
        for (int64_t d = 0; d < trueDimSize; d++)
          values[d] = true;
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, dstType, rewriter.getBoolVectorAttr(values));
      }
      return success();
    }

    // Rank >= 2: unrolling the leading dim requires a static extent.
    if (dstType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    // Build the rank-reduced constant mask for the trailing dims and insert
    // it into the first `trueDimSize` slots of a zero (all-false) vector;
    // the remaining slots stay false.
    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDimSize; d++)
      result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);

    rewriter.replaceOp(op, result);
    return success();
  }
};
160} // namespace
161
162void mlir::vector::populateVectorMaskOpLoweringPatterns(
163 RewritePatternSet &patterns, PatternBenefit benefit) {
164 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
165 arg: patterns.getContext(), args&: benefit);
166}
167
168//===----------------------------------------------------------------------===//
169// populateVectorMaskLoweringPatternsForSideEffectingOps
170//===----------------------------------------------------------------------===//
171
172namespace {
173
174/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
175/// matching:
176/// 1. It matches a `vector.mask` operation.
177/// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
178/// in the matched `vector.mask` operation.
179///
180/// It is required that the replacement op in the pattern replaces the
181/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
182/// approach allows having patterns that "stop" at every `vector.mask` operation
183/// and actually match the traits of its the nested `MaskableOpInterface`.
184template <class SourceOp>
185struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
186 using OpRewritePattern<MaskOp>::OpRewritePattern;
187
188private:
189 LogicalResult matchAndRewrite(MaskOp maskOp,
190 PatternRewriter &rewriter) const final {
191 auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
192 if (!maskableOp)
193 return failure();
194 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
195 if (!sourceOp)
196 return failure();
197
198 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
199 }
200
201protected:
202 virtual LogicalResult
203 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
204 PatternRewriter &rewriter) const = 0;
205};
206
207/// Lowers a masked `vector.transfer_read` operation.
208struct MaskedTransferReadOpPattern
209 : public MaskOpRewritePattern<TransferReadOp> {
210public:
211 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
212
213 LogicalResult
214 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
215 PatternRewriter &rewriter) const override {
216 // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
217 // expects a scalar. We could only lower one to the other for cases where
218 // the passthru is a broadcast of a scalar.
219 if (maskingOp.hasPassthru())
220 return rewriter.notifyMatchFailure(
221 maskingOp, "Can't lower passthru to vector.transfer_read");
222
223 // Replace the `vector.mask` operation.
224 rewriter.replaceOpWithNewOp<TransferReadOp>(
225 maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
226 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
227 maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
228 return success();
229 }
230};
231
232/// Lowers a masked `vector.transfer_write` operation.
233struct MaskedTransferWriteOpPattern
234 : public MaskOpRewritePattern<TransferWriteOp> {
235public:
236 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
237
238 LogicalResult
239 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
240 MaskingOpInterface maskingOp,
241 PatternRewriter &rewriter) const override {
242 Type resultType =
243 writeOp.getResult() ? writeOp.getResult().getType() : Type();
244
245 // Replace the `vector.mask` operation.
246 rewriter.replaceOpWithNewOp<TransferWriteOp>(
247 maskingOp.getOperation(), resultType, writeOp.getVector(),
248 writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
249 maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
250 return success();
251 }
252};
253
254/// Lowers a masked `vector.gather` operation.
255struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
256public:
257 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
258
259 LogicalResult
260 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
261 PatternRewriter &rewriter) const override {
262 Value passthru = maskingOp.hasPassthru()
263 ? maskingOp.getPassthru()
264 : rewriter.create<arith::ConstantOp>(
265 gatherOp.getLoc(),
266 rewriter.getZeroAttr(gatherOp.getVectorType()));
267
268 // Replace the `vector.mask` operation.
269 rewriter.replaceOpWithNewOp<GatherOp>(
270 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
271 gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
272 passthru);
273 return success();
274 }
275};
276
277struct LowerVectorMaskPass
278 : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
279 using Base::Base;
280
281 void runOnOperation() override {
282 Operation *op = getOperation();
283 MLIRContext *context = op->getContext();
284
285 RewritePatternSet loweringPatterns(context);
286 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns&: loweringPatterns);
287 MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
288
289 if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
290 signalPassFailure();
291 }
292
293 void getDependentDialects(DialectRegistry &registry) const override {
294 registry.insert<vector::VectorDialect>();
295 }
296};
297
298} // namespace
299
300/// Populates instances of `MaskOpRewritePattern` to lower masked operations
301/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
302/// not its nested `MaskableOpInterface`.
303void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
304 RewritePatternSet &patterns) {
305 patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306 MaskedGatherOpPattern>(arg: patterns.getContext());
307}
308
/// Factory for the `LowerVectorMaskPass` declared in the Vector transforms
/// pass registry; returns an owning pointer to a fresh pass instance.
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
  return std::make_unique<LowerVectorMaskPass>();
}
312

// source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp