//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include <numeric>

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;

/// Perform the in-place update
///    rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in the mixed base `base`, with the most
/// significant dimensions on the left. For example, if `rhs` is {a,b,c} and
/// `base` is {5,3,2}, then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1  : rhs not in base {5,3,2}
///
/// Overflow is not handled: the carry out of the most significant dimension
/// is silently dropped, so the value wraps modulo the base product. E.g.:
///   rhs = {4,2,1}, lhs = 2  --> rhs = {0,0,1} (29 + 2 = 31 wraps to 1 mod 30)
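///
/// As a worked check of the third example above: 25 = 4*(3*2) + 0*2 + 1, so
/// adding 25 to {0,0,0} in base {5,3,2} yields the digits {4,0,1}.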
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value exceeds the dimension base, we must spill to
    // the next most significant dimension and repeat (we might need to spill
    // to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///  %0 = vector.shape_cast %arg0 :
///                  vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of
/// 7x11. This means we can always decompose the shape_cast into extract,
/// insert, and their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the first dimensions preceding the
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
/// on vectors with shapes that are `multiples` of (what we define as) the
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
///   vector<2x2x3x4x7x11xi8> to
///            vector<8x6x7x11xi8>
///                     | ||||
///                     | ++++------------> common suffix of 7x11
///                     +-----------------> gcd(4,6) is 2 | |
///                                                      | | |
///                                                      v v v
///                                 atomic shape <----- 2x7x11
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///   (1) Extract vectors from the suffix of the source.
///       In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///   (2) Do extract_strided_slice down to the atomic shape.
///       In our example this is 4x7x11 -> 2x7x11.
///
///   (3) Do insert_strided_slice to the suffix of the result.
///       In our example this is 2x7x11 -> 6x7x11.
///
///   (4) Insert these vectors into the result vector.
///       In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///   (1) occurs 12 times (once per 4x7x11 slice of the source),
///   (2) and (3) occur 24 times (once per atomic 2x7x11 slice), and
///   (4) occurs 8 times (once per 6x7x11 slice of the result).
///
/// Two special cases are handled independently in this pattern:
///   (i)  a shape_cast that just does leading 1 insertion/removal, and
///   (ii) a shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the
/// generic algorithm described above.
///
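/// As an illustrative sketch (the shape of the decomposition, not necessarily
/// the verbatim output of this pattern), a small cast such as
/// ```
///   %0 = vector.shape_cast %arg0 : vector<2x4xi8> to vector<4x2xi8>
/// ```
/// with atomic shape 2 (gcd(4,2) = 2) decomposes along these lines:
/// ```
///   %p  = ub.poison : vector<4x2xi8>
///   %s0 = vector.extract %arg0[0] : vector<4xi8> from vector<2x4xi8>
///   %e0 = vector.extract_strided_slice %s0
///           {offsets = [0], sizes = [2], strides = [1]}
///           : vector<4xi8> to vector<2xi8>
///   %r0 = vector.insert %e0, %p [0] : vector<2xi8> into vector<4x2xi8>
///   %e1 = vector.extract_strided_slice %s0
///           {offsets = [2], sizes = [2], strides = [1]}
///           : vector<4xi8> to vector<2xi8>
///   %r1 = vector.insert %e1, %r0 [1] : vector<2xi8> into vector<4x2xi8>
///   // ... and likewise for the second row of %arg0 (result rows 2 and 3).
/// ```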
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of the description.
  // Assumes source and result shapes are identical up to some leading ones.
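  // For example, vector<1x1x4xi8> to vector<4xi8> becomes a single
  // vector.extract at indices [0, 0] (yielding vector<4xi8>), inserted into
  // the poison result with an empty index list (i.e. replacing it entirely).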
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value extracted = rewriter.create<vector::ExtractOp>(
        loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result = rewriter.create<vector::InsertOp>(
        loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }

  // Case (ii) of the description.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim`
  // are identical.
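  // For example, vector<2x3xi8> to vector<3x2xi8> (with sourceDim = 2 and
  // resultDim = 2) becomes 6 scalar extract/insert pairs: the extract index
  // counts up in mixed base {2,3} while the insert index counts in {3,2}.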
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices =
        std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim,
                        1, std::multiplies<int64_t>());

    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          rewriter.create<vector::ExtractOp>(loc, source, extractIndex);

      result = rewriter.create<vector::InsertOp>(loc, extracted, result,
                                                 insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }

public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimensions (counting from the end) in the source and
    // result respectively where the dimension sizes differ. Using the
    // running example:
    //
    //   dimensions: [0 1 2 3 4 5 ]    [0 1 2 3 ]
    //   shapes:     (2,2,3,4,7,11) -> (8,6,7,11)
    //                      ^             ^
    //                      |             |
    //   sourceSuffixStartDim is 3        |
    //                                    |
    //            resultSuffixStartDim is 1
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is case (i), where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;
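    // In the running example: greatestCommonDivisor is gcd(4,6) = 2, so
    // extractPeriod is 4/2 = 2 and insertPeriod is 6/2 = 3.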

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;
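    // In the running example the atomic shape is 2x7x11, so numAtomicElms is
    // 2*7*11 = 154 and nAtomicSlices is 3696/154 = 24.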

    // This is case (ii), where the gcd is 1. The atomic slices then have a
    // leading dimension of size 1, and more compact IR is generated if we
    // just extract and insert the elements directly. In other words, we
    // don't use extract_strided_slice and insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type.
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> strides(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value partResult =
        rewriter.create<ub::PoisonOp>(loc, insertStridedType);

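    // In the running example the loop below runs nAtomicSlices = 24 times:
    // extractStridedPhase cycles 0,1,0,1,... (period 2) and insertStridedPhase
    // cycles 0,1,2,0,... (period 3), so a fresh vector.extract is created on
    // 12 of the iterations and a vector.insert completes on 8 of them.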
    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, extracted, extractOffsets, atomicShape, strides);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extractedStrided, insertedSlice, insertOffsets, strides);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
                                                   insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes
/// use of vector.scalable.insert and vector.scalable.extract to move elements
/// a subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and
    // result types have a single trailing scalable dimension. This is
    // because there are no legal representations of other scalable types in
    // LLVM (and likely won't be soon). There are also (currently) no
    // operations that can index or extract from >= 2-D scalable vectors or
    // scalable vectors of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;
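    // E.g. for vector<2x1x[4]xi32> to vector<[8]xi32>: minExtractionSize is
    // min(4, 8) = 4 and minNumElts is 2*1*4 = 8, so the loop below runs
    // twice, moving one vector<[4]xi32> subvector per iteration.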

    // The subvector type to move from the source to the result. Note that
    // this is a scalable vector. This rewrite will generate code in terms of
    // the "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimension.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
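
// Typical usage from a pass (a sketch only; the greedy driver entry point
// name varies across MLIR versions):
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     return signalPassFailure();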