1//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.mask' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
18#include "mlir/Dialect/Vector/Transforms/Passes.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22#define DEBUG_TYPE "lower-vector-mask"
23
24namespace mlir {
25namespace vector {
26#define GEN_PASS_DEF_LOWERVECTORMASKPASS
27#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28} // namespace vector
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::vector;
33
34//===----------------------------------------------------------------------===//
35// populateVectorMaskOpLoweringPatterns
36//===----------------------------------------------------------------------===//
37
38namespace {
39/// Progressive lowering of CreateMaskOp.
40/// One:
41/// %x = vector.create_mask %a, ... : vector<dx...>
42/// is replaced by:
43/// %l = vector.create_mask ... : vector<...> ; one lower rank
44/// %0 = arith.cmpi "slt", %ci, %a |
45/// %1 = select %0, %l, %zeroes |
46/// %r = vector.insert %1, %pr [i] | d-times
47/// %x = ....
48/// until a one-dimensional vector is reached.
49class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50public:
51 using OpRewritePattern::OpRewritePattern;
52
53 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54 PatternRewriter &rewriter) const override {
55 auto dstType = cast<VectorType>(Val: op.getResult().getType());
56 int64_t rank = dstType.getRank();
57 if (rank <= 1)
58 return rewriter.notifyMatchFailure(
59 arg&: op, msg: "0-D and 1-D vectors are handled separately");
60
61 if (dstType.getScalableDims().front())
62 return rewriter.notifyMatchFailure(
63 arg&: op, msg: "Cannot unroll leading scalable dim in dstType");
64
65 auto loc = op.getLoc();
66 int64_t dim = dstType.getDimSize(idx: 0);
67 Value idx = op.getOperand(i: 0);
68
69 VectorType lowType = VectorType::Builder(dstType).dropDim(pos: 0);
70 Value trueVal = rewriter.create<vector::CreateMaskOp>(
71 location: loc, args&: lowType, args: op.getOperands().drop_front());
72 Value falseVal = rewriter.create<arith::ConstantOp>(
73 location: loc, args&: lowType, args: rewriter.getZeroAttr(type: lowType));
74 Value result = rewriter.create<arith::ConstantOp>(
75 location: loc, args&: dstType, args: rewriter.getZeroAttr(type: dstType));
76 for (int64_t d = 0; d < dim; d++) {
77 Value bnd =
78 rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getIndexAttr(value: d));
79 Value val = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt,
80 args&: bnd, args&: idx);
81 Value sel = rewriter.create<arith::SelectOp>(location: loc, args&: val, args&: trueVal, args&: falseVal);
82 result = rewriter.create<vector::InsertOp>(location: loc, args&: sel, args&: result, args&: d);
83 }
84 rewriter.replaceOp(op, newValues: result);
85 return success();
86 }
87};
88
89/// Progressive lowering of ConstantMaskOp.
90/// One:
91/// %x = vector.constant_mask [a,b]
92/// is replaced by:
93/// %z = zero-result
94/// %l = vector.constant_mask [b]
95/// %4 = vector.insert %l, %z[0]
96/// ..
97/// %x = vector.insert %l, %..[a-1]
98/// until a one-dimensional vector is reached. All these operations
99/// will be folded at LLVM IR level.
100class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
101public:
102 using OpRewritePattern::OpRewritePattern;
103
104 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
105 PatternRewriter &rewriter) const override {
106 auto loc = op.getLoc();
107 auto dstType = op.getType();
108 auto dimSizes = op.getMaskDimSizes();
109 int64_t rank = dstType.getRank();
110
111 if (rank == 0) {
112 assert(dimSizes.size() == 1 &&
113 "Expected exactly one dim size for a 0-D vector");
114 bool value = dimSizes.front() == 1;
115 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
116 op, args&: dstType,
117 args: DenseIntElementsAttr::get(type: VectorType::get(shape: {}, elementType: rewriter.getI1Type()),
118 arg&: value));
119 return success();
120 }
121
122 int64_t trueDimSize = dimSizes.front();
123
124 if (rank == 1) {
125 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(idx: 0)) {
126 // Use constant splat for 'all set' or 'none set' dims.
127 // This produces correct code for scalable dimensions (it will lower to
128 // a constant splat).
129 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
130 op, args: DenseElementsAttr::get(type: dstType, value: trueDimSize != 0));
131 } else {
132 // Express constant 1-D case in explicit vector form:
133 // [T,..,T,F,..,F].
134 // Note: The verifier would reject this case for scalable vectors.
135 SmallVector<bool> values(dstType.getDimSize(idx: 0), false);
136 for (int64_t d = 0; d < trueDimSize; d++)
137 values[d] = true;
138 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
139 op, args&: dstType, args: rewriter.getBoolVectorAttr(values));
140 }
141 return success();
142 }
143
144 if (dstType.getScalableDims().front())
145 return rewriter.notifyMatchFailure(
146 arg&: op, msg: "Cannot unroll leading scalable dim in dstType");
147
148 VectorType lowType = VectorType::Builder(dstType).dropDim(pos: 0);
149 Value trueVal = rewriter.create<vector::ConstantMaskOp>(
150 location: loc, args&: lowType, args: dimSizes.drop_front());
151 Value result = rewriter.create<arith::ConstantOp>(
152 location: loc, args&: dstType, args: rewriter.getZeroAttr(type: dstType));
153 for (int64_t d = 0; d < trueDimSize; d++)
154 result = rewriter.create<vector::InsertOp>(location: loc, args&: trueVal, args&: result, args&: d);
155
156 rewriter.replaceOp(op, newValues: result);
157 return success();
158 }
159};
160} // namespace
161
162void mlir::vector::populateVectorMaskOpLoweringPatterns(
163 RewritePatternSet &patterns, PatternBenefit benefit) {
164 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
165 arg: patterns.getContext(), args&: benefit);
166}
167
168//===----------------------------------------------------------------------===//
169// populateVectorMaskLoweringPatternsForSideEffectingOps
170//===----------------------------------------------------------------------===//
171
172namespace {
173
174/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
175/// matching:
176/// 1. It matches a `vector.mask` operation.
177/// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
178/// in the matched `vector.mask` operation.
179///
180/// It is required that the replacement op in the pattern replaces the
181/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
182/// approach allows having patterns that "stop" at every `vector.mask` operation
183/// and actually match the traits of its the nested `MaskableOpInterface`.
184template <class SourceOp>
185struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
186 using OpRewritePattern<MaskOp>::OpRewritePattern;
187
188private:
189 LogicalResult matchAndRewrite(MaskOp maskOp,
190 PatternRewriter &rewriter) const final {
191 auto maskableOp = cast_or_null<MaskableOpInterface>(Val: maskOp.getMaskableOp());
192 if (!maskableOp)
193 return failure();
194 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
195 if (!sourceOp)
196 return failure();
197
198 return matchAndRewriteMaskableOp(sourceOp, maskingOp: maskOp, rewriter);
199 }
200
201protected:
202 virtual LogicalResult
203 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
204 PatternRewriter &rewriter) const = 0;
205};
206
207/// Lowers a masked `vector.transfer_read` operation.
208struct MaskedTransferReadOpPattern
209 : public MaskOpRewritePattern<TransferReadOp> {
210public:
211 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
212
213 LogicalResult
214 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
215 PatternRewriter &rewriter) const override {
216 // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
217 // expects a scalar. We could only lower one to the other for cases where
218 // the passthru is a broadcast of a scalar.
219 if (maskingOp.hasPassthru())
220 return rewriter.notifyMatchFailure(
221 arg&: maskingOp, msg: "Can't lower passthru to vector.transfer_read");
222
223 // Replace the `vector.mask` operation.
224 rewriter.replaceOpWithNewOp<TransferReadOp>(
225 op: maskingOp.getOperation(), args: readOp.getVectorType(), args: readOp.getBase(),
226 args: readOp.getIndices(), args: readOp.getPermutationMap(), args: readOp.getPadding(),
227 args: maskingOp.getMask(), args: readOp.getInBounds());
228 return success();
229 }
230};
231
232/// Lowers a masked `vector.transfer_write` operation.
233struct MaskedTransferWriteOpPattern
234 : public MaskOpRewritePattern<TransferWriteOp> {
235public:
236 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
237
238 LogicalResult
239 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
240 MaskingOpInterface maskingOp,
241 PatternRewriter &rewriter) const override {
242 Type resultType =
243 writeOp.getResult() ? writeOp.getResult().getType() : Type();
244
245 // Replace the `vector.mask` operation.
246 rewriter.replaceOpWithNewOp<TransferWriteOp>(
247 op: maskingOp.getOperation(), args&: resultType, args: writeOp.getVector(),
248 args: writeOp.getBase(), args: writeOp.getIndices(), args: writeOp.getPermutationMap(),
249 args: maskingOp.getMask(), args: writeOp.getInBounds());
250 return success();
251 }
252};
253
254/// Lowers a masked `vector.gather` operation.
255struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
256public:
257 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
258
259 LogicalResult
260 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
261 PatternRewriter &rewriter) const override {
262 Value passthru = maskingOp.hasPassthru()
263 ? maskingOp.getPassthru()
264 : rewriter.create<arith::ConstantOp>(
265 location: gatherOp.getLoc(),
266 args: rewriter.getZeroAttr(type: gatherOp.getVectorType()));
267
268 // Replace the `vector.mask` operation.
269 rewriter.replaceOpWithNewOp<GatherOp>(
270 op: maskingOp.getOperation(), args: gatherOp.getVectorType(), args: gatherOp.getBase(),
271 args: gatherOp.getIndices(), args: gatherOp.getIndexVec(), args: maskingOp.getMask(),
272 args&: passthru);
273 return success();
274 }
275};
276
277struct LowerVectorMaskPass
278 : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
279 using Base::Base;
280
281 void runOnOperation() override {
282 Operation *op = getOperation();
283 MLIRContext *context = op->getContext();
284
285 RewritePatternSet loweringPatterns(context);
286 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns&: loweringPatterns);
287 MaskOp::getCanonicalizationPatterns(results&: loweringPatterns, context);
288
289 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(loweringPatterns))))
290 signalPassFailure();
291 }
292
293 void getDependentDialects(DialectRegistry &registry) const override {
294 registry.insert<vector::VectorDialect>();
295 }
296};
297
298} // namespace
299
300/// Populates instances of `MaskOpRewritePattern` to lower masked operations
301/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
302/// not its nested `MaskableOpInterface`.
303void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
304 RewritePatternSet &patterns) {
305 patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306 MaskedGatherOpPattern>(arg: patterns.getContext());
307}
308
309std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
310 return std::make_unique<LowerVectorMaskPass>();
311}
312

source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp