1 | //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.mask' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
17 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
18 | #include "mlir/Dialect/Vector/Transforms/Passes.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | |
22 | #define DEBUG_TYPE "lower-vector-mask" |
23 | |
24 | namespace mlir { |
25 | namespace vector { |
26 | #define GEN_PASS_DEF_LOWERVECTORMASKPASS |
27 | #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
28 | } // namespace vector |
29 | } // namespace mlir |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::vector; |
33 | |
34 | //===----------------------------------------------------------------------===// |
35 | // populateVectorMaskOpLoweringPatterns |
36 | //===----------------------------------------------------------------------===// |
37 | |
38 | namespace { |
39 | /// Progressive lowering of CreateMaskOp. |
40 | /// One: |
41 | /// %x = vector.create_mask %a, ... : vector<dx...> |
42 | /// is replaced by: |
43 | /// %l = vector.create_mask ... : vector<...> ; one lower rank |
44 | /// %0 = arith.cmpi "slt", %ci, %a | |
45 | /// %1 = select %0, %l, %zeroes | |
46 | /// %r = vector.insert %1, %pr [i] | d-times |
47 | /// %x = .... |
48 | /// until a one-dimensional vector is reached. |
49 | class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { |
50 | public: |
51 | using OpRewritePattern::OpRewritePattern; |
52 | |
53 | LogicalResult matchAndRewrite(vector::CreateMaskOp op, |
54 | PatternRewriter &rewriter) const override { |
55 | auto dstType = cast<VectorType>(op.getResult().getType()); |
56 | int64_t rank = dstType.getRank(); |
57 | if (rank <= 1) |
58 | return rewriter.notifyMatchFailure( |
59 | op, "0-D and 1-D vectors are handled separately" ); |
60 | |
61 | if (dstType.getScalableDims().front()) |
62 | return rewriter.notifyMatchFailure( |
63 | op, "Cannot unroll leading scalable dim in dstType" ); |
64 | |
65 | auto loc = op.getLoc(); |
66 | int64_t dim = dstType.getDimSize(0); |
67 | Value idx = op.getOperand(0); |
68 | |
69 | VectorType lowType = VectorType::Builder(dstType).dropDim(0); |
70 | Value trueVal = rewriter.create<vector::CreateMaskOp>( |
71 | loc, lowType, op.getOperands().drop_front()); |
72 | Value falseVal = rewriter.create<arith::ConstantOp>( |
73 | loc, lowType, rewriter.getZeroAttr(lowType)); |
74 | Value result = rewriter.create<arith::ConstantOp>( |
75 | loc, dstType, rewriter.getZeroAttr(dstType)); |
76 | for (int64_t d = 0; d < dim; d++) { |
77 | Value bnd = |
78 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); |
79 | Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, |
80 | bnd, idx); |
81 | Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); |
82 | result = rewriter.create<vector::InsertOp>(loc, sel, result, d); |
83 | } |
84 | rewriter.replaceOp(op, result); |
85 | return success(); |
86 | } |
87 | }; |
88 | |
/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    // dimSizes[i] is the count of leading "true" elements along dimension i.
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    // 0-D case: lower to a 0-D i1 constant. The single dim size acts as a
    // boolean: 1 means the mask element is set, 0 means it is unset.
    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector" );
      bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
                                    value));
      return success();
    }

    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();

    // 1-D case: emit the mask directly as a dense boolean constant.
    if (rank == 1) {
      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
        // Use constant splat for 'all set' or 'none set' dims.
        // This produces correct code for scalable dimensions (it will lower to
        // a constant splat).
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
      } else {
        // Express constant 1-D case in explicit vector form:
        //   [T,..,T,F,..,F].
        // Note: The verifier would reject this case for scalable vectors.
        SmallVector<bool> values(dstType.getDimSize(0), false);
        for (int64_t d = 0; d < trueDimSize; d++)
          values[d] = true;
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, dstType, rewriter.getBoolVectorAttr(values));
      }
      return success();
    }

    // Rank >= 2: unroll the leading dimension. This emits one insert per
    // leading position, so that dimension must be fixed-size.
    if (dstType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType" );

    // Build the (rank - 1) sub-mask once and insert it at positions
    // [0, trueDimSize); the remaining positions keep the all-false value from
    // the initial zero constant.
    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDimSize; d++)
      result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);

    rewriter.replaceOp(op, result);
    return success();
  }
};
160 | } // namespace |
161 | |
162 | void mlir::vector::populateVectorMaskOpLoweringPatterns( |
163 | RewritePatternSet &patterns, PatternBenefit benefit) { |
164 | patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( |
165 | arg: patterns.getContext(), args&: benefit); |
166 | } |
167 | |
168 | //===----------------------------------------------------------------------===// |
169 | // populateVectorMaskLoweringPatternsForSideEffectingOps |
170 | //===----------------------------------------------------------------------===// |
171 | |
172 | namespace { |
173 | |
174 | /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold |
175 | /// matching: |
176 | /// 1. It matches a `vector.mask` operation. |
177 | /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested |
178 | /// in the matched `vector.mask` operation. |
179 | /// |
180 | /// It is required that the replacement op in the pattern replaces the |
181 | /// `vector.mask` operation and not the nested `MaskableOpInterface`. This |
182 | /// approach allows having patterns that "stop" at every `vector.mask` operation |
183 | /// and actually match the traits of its the nested `MaskableOpInterface`. |
184 | template <class SourceOp> |
185 | struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { |
186 | using OpRewritePattern<MaskOp>::OpRewritePattern; |
187 | |
188 | private: |
189 | LogicalResult matchAndRewrite(MaskOp maskOp, |
190 | PatternRewriter &rewriter) const final { |
191 | auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp()); |
192 | if (!maskableOp) |
193 | return failure(); |
194 | SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation()); |
195 | if (!sourceOp) |
196 | return failure(); |
197 | |
198 | return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter); |
199 | } |
200 | |
201 | protected: |
202 | virtual LogicalResult |
203 | matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, |
204 | PatternRewriter &rewriter) const = 0; |
205 | }; |
206 | |
207 | /// Lowers a masked `vector.transfer_read` operation. |
208 | struct MaskedTransferReadOpPattern |
209 | : public MaskOpRewritePattern<TransferReadOp> { |
210 | public: |
211 | using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; |
212 | |
213 | LogicalResult |
214 | matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, |
215 | PatternRewriter &rewriter) const override { |
216 | // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' |
217 | // expects a scalar. We could only lower one to the other for cases where |
218 | // the passthru is a broadcast of a scalar. |
219 | if (maskingOp.hasPassthru()) |
220 | return rewriter.notifyMatchFailure( |
221 | maskingOp, "Can't lower passthru to vector.transfer_read" ); |
222 | |
223 | // Replace the `vector.mask` operation. |
224 | rewriter.replaceOpWithNewOp<TransferReadOp>( |
225 | maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(), |
226 | readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), |
227 | maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr())); |
228 | return success(); |
229 | } |
230 | }; |
231 | |
232 | /// Lowers a masked `vector.transfer_write` operation. |
233 | struct MaskedTransferWriteOpPattern |
234 | : public MaskOpRewritePattern<TransferWriteOp> { |
235 | public: |
236 | using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; |
237 | |
238 | LogicalResult |
239 | matchAndRewriteMaskableOp(TransferWriteOp writeOp, |
240 | MaskingOpInterface maskingOp, |
241 | PatternRewriter &rewriter) const override { |
242 | Type resultType = |
243 | writeOp.getResult() ? writeOp.getResult().getType() : Type(); |
244 | |
245 | // Replace the `vector.mask` operation. |
246 | rewriter.replaceOpWithNewOp<TransferWriteOp>( |
247 | maskingOp.getOperation(), resultType, writeOp.getVector(), |
248 | writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(), |
249 | maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr())); |
250 | return success(); |
251 | } |
252 | }; |
253 | |
254 | /// Lowers a masked `vector.gather` operation. |
255 | struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { |
256 | public: |
257 | using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; |
258 | |
259 | LogicalResult |
260 | matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, |
261 | PatternRewriter &rewriter) const override { |
262 | Value passthru = maskingOp.hasPassthru() |
263 | ? maskingOp.getPassthru() |
264 | : rewriter.create<arith::ConstantOp>( |
265 | gatherOp.getLoc(), |
266 | rewriter.getZeroAttr(gatherOp.getVectorType())); |
267 | |
268 | // Replace the `vector.mask` operation. |
269 | rewriter.replaceOpWithNewOp<GatherOp>( |
270 | maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), |
271 | gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), |
272 | passthru); |
273 | return success(); |
274 | } |
275 | }; |
276 | |
277 | struct LowerVectorMaskPass |
278 | : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { |
279 | using Base::Base; |
280 | |
281 | void runOnOperation() override { |
282 | Operation *op = getOperation(); |
283 | MLIRContext *context = op->getContext(); |
284 | |
285 | RewritePatternSet loweringPatterns(context); |
286 | populateVectorMaskLoweringPatternsForSideEffectingOps(patterns&: loweringPatterns); |
287 | MaskOp::getCanonicalizationPatterns(loweringPatterns, context); |
288 | |
289 | if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) |
290 | signalPassFailure(); |
291 | } |
292 | |
293 | void getDependentDialects(DialectRegistry ®istry) const override { |
294 | registry.insert<vector::VectorDialect>(); |
295 | } |
296 | }; |
297 | |
298 | } // namespace |
299 | |
300 | /// Populates instances of `MaskOpRewritePattern` to lower masked operations |
301 | /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and |
302 | /// not its nested `MaskableOpInterface`. |
303 | void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( |
304 | RewritePatternSet &patterns) { |
305 | patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, |
306 | MaskedGatherOpPattern>(arg: patterns.getContext()); |
307 | } |
308 | |
309 | std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { |
310 | return std::make_unique<LowerVectorMaskPass>(); |
311 | } |
312 | |