1 | //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.shape_cast' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include <numeric>
36 | |
37 | #define DEBUG_TYPE "vector-shape-cast-lowering" |
38 | |
39 | using namespace mlir; |
40 | using namespace mlir::vector; |
41 | |
42 | namespace { |
43 | /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D |
44 | /// vectors progressively on the way to target llvm.matrix intrinsics. |
45 | /// This iterates over the most major dimension of the 2-D vector and performs |
46 | /// rewrites into: |
47 | /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D |
48 | class ShapeCastOp2DDownCastRewritePattern |
49 | : public OpRewritePattern<vector::ShapeCastOp> { |
50 | public: |
51 | using OpRewritePattern::OpRewritePattern; |
52 | |
53 | LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
54 | PatternRewriter &rewriter) const override { |
55 | auto sourceVectorType = op.getSourceVectorType(); |
56 | auto resultVectorType = op.getResultVectorType(); |
57 | |
58 | if (sourceVectorType.isScalable() || resultVectorType.isScalable()) |
59 | return failure(); |
60 | |
61 | if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) |
62 | return failure(); |
63 | |
64 | auto loc = op.getLoc(); |
65 | Value desc = rewriter.create<arith::ConstantOp>( |
66 | loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); |
67 | unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; |
68 | for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { |
69 | Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i); |
70 | desc = rewriter.create<vector::InsertStridedSliceOp>( |
71 | loc, vec, desc, |
72 | /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); |
73 | } |
74 | rewriter.replaceOp(op, desc); |
75 | return success(); |
76 | } |
77 | }; |
78 | |
79 | /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D |
80 | /// vectors progressively. |
81 | /// This iterates over the most major dimension of the 2-D vector and performs |
82 | /// rewrites into: |
83 | /// vector.extract_strided_slice from 1-D + vector.insert into 2-D |
84 | /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. |
85 | class ShapeCastOp2DUpCastRewritePattern |
86 | : public OpRewritePattern<vector::ShapeCastOp> { |
87 | public: |
88 | using OpRewritePattern::OpRewritePattern; |
89 | |
90 | LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
91 | PatternRewriter &rewriter) const override { |
92 | auto sourceVectorType = op.getSourceVectorType(); |
93 | auto resultVectorType = op.getResultVectorType(); |
94 | |
95 | if (sourceVectorType.isScalable() || resultVectorType.isScalable()) |
96 | return failure(); |
97 | |
98 | if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) |
99 | return failure(); |
100 | |
101 | auto loc = op.getLoc(); |
102 | Value desc = rewriter.create<arith::ConstantOp>( |
103 | loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); |
104 | unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; |
105 | for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { |
106 | Value vec = rewriter.create<vector::ExtractStridedSliceOp>( |
107 | loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, |
108 | /*sizes=*/mostMinorVectorSize, |
109 | /*strides=*/1); |
110 | desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i); |
111 | } |
112 | rewriter.replaceOp(op, desc); |
113 | return success(); |
114 | } |
115 | }; |
116 | |
117 | static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp, |
118 | int dimIdx, int initialStep = 1) { |
119 | int step = initialStep; |
120 | for (int d = dimIdx; d >= 0; d--) { |
121 | idx[d] += step; |
122 | if (idx[d] >= tp.getDimSize(d)) { |
123 | idx[d] = 0; |
124 | step = 1; |
125 | } else { |
126 | break; |
127 | } |
128 | } |
129 | } |
130 | |
131 | // We typically should not lower general shape cast operations into data |
132 | // movement instructions, since the assumption is that these casts are |
133 | // optimized away during progressive lowering. For completeness, however, |
134 | // we fall back to a reference implementation that moves all elements |
135 | // into the right place if we get here. |
136 | class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { |
137 | public: |
138 | using OpRewritePattern::OpRewritePattern; |
139 | |
140 | LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
141 | PatternRewriter &rewriter) const override { |
142 | Location loc = op.getLoc(); |
143 | auto sourceVectorType = op.getSourceVectorType(); |
144 | auto resultVectorType = op.getResultVectorType(); |
145 | |
146 | if (sourceVectorType.isScalable() || resultVectorType.isScalable()) |
147 | return failure(); |
148 | |
149 | // Special case 2D / 1D lowerings with better implementations. |
150 | // TODO: make is ND / 1D to allow generic ND -> 1D -> MD. |
151 | int64_t srcRank = sourceVectorType.getRank(); |
152 | int64_t resRank = resultVectorType.getRank(); |
153 | if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2)) |
154 | return failure(); |
155 | |
156 | // Generic ShapeCast lowering path goes all the way down to unrolled scalar |
157 | // extract/insert chains. |
158 | // TODO: consider evolving the semantics to only allow 1D source or dest and |
159 | // drop this potentially very expensive lowering. |
160 | // Compute number of elements involved in the reshape. |
161 | int64_t numElts = 1; |
162 | for (int64_t r = 0; r < srcRank; r++) |
163 | numElts *= sourceVectorType.getDimSize(r); |
164 | // Replace with data movement operations: |
165 | // x[0,0,0] = y[0,0] |
166 | // x[0,0,1] = y[0,1] |
167 | // x[0,1,0] = y[0,2] |
168 | // etc., incrementing the two index vectors "row-major" |
169 | // within the source and result shape. |
170 | SmallVector<int64_t> srcIdx(srcRank); |
171 | SmallVector<int64_t> resIdx(resRank); |
172 | Value result = rewriter.create<arith::ConstantOp>( |
173 | loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); |
174 | for (int64_t i = 0; i < numElts; i++) { |
175 | if (i != 0) { |
176 | incIdx(srcIdx, sourceVectorType, srcRank - 1); |
177 | incIdx(resIdx, resultVectorType, resRank - 1); |
178 | } |
179 | |
180 | Value ; |
181 | if (srcRank == 0) { |
182 | // 0-D vector special case |
183 | assert(srcIdx.empty() && "Unexpected indices for 0-D vector" ); |
184 | extract = rewriter.create<vector::ExtractElementOp>( |
185 | loc, op.getSourceVectorType().getElementType(), op.getSource()); |
186 | } else { |
187 | extract = |
188 | rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx); |
189 | } |
190 | |
191 | if (resRank == 0) { |
192 | // 0-D vector special case |
193 | assert(resIdx.empty() && "Unexpected indices for 0-D vector" ); |
194 | result = rewriter.create<vector::InsertElementOp>(loc, extract, result); |
195 | } else { |
196 | result = |
197 | rewriter.create<vector::InsertOp>(loc, extract, result, resIdx); |
198 | } |
199 | } |
200 | rewriter.replaceOp(op, result); |
201 | return success(); |
202 | } |
203 | }; |
204 | |
205 | /// A shape_cast lowering for scalable vectors with a single trailing scalable |
206 | /// dimension. This is similar to the general shape_cast lowering but makes use |
207 | /// of vector.scalable.insert and vector.scalable.extract to move elements a |
208 | /// subvector at a time. |
209 | /// |
210 | /// E.g.: |
211 | /// ``` |
212 | /// // Flatten scalable vector |
213 | /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> |
214 | /// ``` |
215 | /// is rewritten to: |
216 | /// ``` |
217 | /// // Flatten scalable vector |
218 | /// %c = arith.constant dense<0> : vector<[8]xi32> |
219 | /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> |
220 | /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32> |
221 | /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> |
222 | /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32> |
223 | /// ``` |
224 | /// or: |
225 | /// ``` |
226 | /// // Un-flatten scalable vector |
227 | /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> |
228 | /// ``` |
229 | /// is rewritten to: |
230 | /// ``` |
231 | /// // Un-flatten scalable vector |
232 | /// %c = arith.constant dense<0> : vector<2x1x[4]xi32> |
233 | /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32> |
234 | /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> |
235 | /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32> |
236 | /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> |
237 | /// ``` |
238 | class ScalableShapeCastOpRewritePattern |
239 | : public OpRewritePattern<vector::ShapeCastOp> { |
240 | public: |
241 | using OpRewritePattern::OpRewritePattern; |
242 | |
243 | LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
244 | PatternRewriter &rewriter) const override { |
245 | |
246 | Location loc = op.getLoc(); |
247 | auto sourceVectorType = op.getSourceVectorType(); |
248 | auto resultVectorType = op.getResultVectorType(); |
249 | auto srcRank = sourceVectorType.getRank(); |
250 | auto resRank = resultVectorType.getRank(); |
251 | |
252 | // This can only lower shape_casts where both the source and result types |
253 | // have a single trailing scalable dimension. This is because there are no |
254 | // legal representation of other scalable types in LLVM (and likely won't be |
255 | // soon). There are also (currently) no operations that can index or extract |
256 | // from >= 2D scalable vectors or scalable vectors of fixed vectors. |
257 | if (!isTrailingDimScalable(type: sourceVectorType) || |
258 | !isTrailingDimScalable(type: resultVectorType)) { |
259 | return failure(); |
260 | } |
261 | |
262 | // The sizes of the trailing dimension of the source and result vectors, the |
263 | // size of subvector to move, and the number of elements in the vectors. |
264 | // These are "min" sizes as they are the size when vscale == 1. |
265 | auto minSourceTrailingSize = sourceVectorType.getShape().back(); |
266 | auto minResultTrailingSize = resultVectorType.getShape().back(); |
267 | auto = |
268 | std::min(minSourceTrailingSize, minResultTrailingSize); |
269 | int64_t minNumElts = 1; |
270 | for (auto size : sourceVectorType.getShape()) |
271 | minNumElts *= size; |
272 | |
273 | // The subvector type to move from the source to the result. Note that this |
274 | // is a scalable vector. This rewrite will generate code in terms of the |
275 | // "min" size (vscale == 1 case), that scales to any vscale. |
276 | auto = VectorType::get( |
277 | {minExtractionSize}, sourceVectorType.getElementType(), {true}); |
278 | |
279 | Value result = rewriter.create<arith::ConstantOp>( |
280 | loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); |
281 | |
282 | SmallVector<int64_t> srcIdx(srcRank); |
283 | SmallVector<int64_t> resIdx(resRank); |
284 | |
285 | // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils) |
286 | // once D150000 lands. |
287 | Value currentResultScalableVector; |
288 | Value currentSourceScalableVector; |
289 | for (int64_t i = 0; i < minNumElts; i += minExtractionSize) { |
290 | // 1. Extract a scalable subvector from the source vector. |
291 | if (!currentSourceScalableVector) { |
292 | if (srcRank != 1) { |
293 | currentSourceScalableVector = rewriter.create<vector::ExtractOp>( |
294 | loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); |
295 | } else { |
296 | currentSourceScalableVector = op.getSource(); |
297 | } |
298 | } |
299 | Value sourceSubVector = currentSourceScalableVector; |
300 | if (minExtractionSize < minSourceTrailingSize) { |
301 | sourceSubVector = rewriter.create<vector::ScalableExtractOp>( |
302 | loc, extractionVectorType, sourceSubVector, srcIdx.back()); |
303 | } |
304 | |
305 | // 2. Insert the scalable subvector into the result vector. |
306 | if (!currentResultScalableVector) { |
307 | if (minExtractionSize == minResultTrailingSize) { |
308 | currentResultScalableVector = sourceSubVector; |
309 | } else if (resRank != 1) { |
310 | currentResultScalableVector = rewriter.create<vector::ExtractOp>( |
311 | loc, result, llvm::ArrayRef(resIdx).drop_back()); |
312 | } else { |
313 | currentResultScalableVector = result; |
314 | } |
315 | } |
316 | if (minExtractionSize < minResultTrailingSize) { |
317 | currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>( |
318 | loc, sourceSubVector, currentResultScalableVector, resIdx.back()); |
319 | } |
320 | |
321 | // 3. Update the source and result scalable vectors if needed. |
322 | if (resIdx.back() + minExtractionSize >= minResultTrailingSize && |
323 | currentResultScalableVector != result) { |
324 | // Finished row of result. Insert complete scalable vector into result |
325 | // (n-D) vector. |
326 | result = rewriter.create<vector::InsertOp>( |
327 | loc, currentResultScalableVector, result, |
328 | llvm::ArrayRef(resIdx).drop_back()); |
329 | currentResultScalableVector = {}; |
330 | } |
331 | if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) { |
332 | // Finished row of source. |
333 | currentSourceScalableVector = {}; |
334 | } |
335 | |
336 | // 4. Increment the insert/extract indices, stepping by minExtractionSize |
337 | // for the trailing dimensions. |
338 | incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize); |
339 | incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize); |
340 | } |
341 | |
342 | rewriter.replaceOp(op, result); |
343 | return success(); |
344 | } |
345 | |
346 | static bool isTrailingDimScalable(VectorType type) { |
347 | return type.getRank() >= 1 && type.getScalableDims().back() && |
348 | !llvm::is_contained(type.getScalableDims().drop_back(), true); |
349 | } |
350 | }; |
351 | |
352 | } // namespace |
353 | |
354 | void mlir::vector::populateVectorShapeCastLoweringPatterns( |
355 | RewritePatternSet &patterns, PatternBenefit benefit) { |
356 | patterns.add<ShapeCastOp2DDownCastRewritePattern, |
357 | ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern, |
358 | ScalableShapeCastOpRewritePattern>(arg: patterns.getContext(), |
359 | args&: benefit); |
360 | } |
361 | |