//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
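///
/// For example (an illustrative sketch; the exact printed IR may differ), a
/// cast such as
/// ```
/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
/// ```
/// would be rewritten along the lines of:
/// ```
/// %c = arith.constant dense<0> : vector<8xi32>
/// %0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32>
/// %1 = vector.insert_strided_slice %0, %c
///        {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
/// %2 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32>
/// %3 = vector.insert_strided_slice %2, %1
///        {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
/// ```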
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
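///
/// For example (an illustrative sketch; the exact printed IR may differ), a
/// cast such as
/// ```
/// %0 = vector.shape_cast %arg0 : vector<8xi32> to vector<2x4xi32>
/// ```
/// would be rewritten along the lines of:
/// ```
/// %c = arith.constant dense<0> : vector<2x4xi32>
/// %0 = vector.extract_strided_slice %arg0
///        {offsets = [0], sizes = [4], strides = [1]}
///        : vector<8xi32> to vector<4xi32>
/// %1 = vector.insert %0, %c [0] : vector<4xi32> into vector<2x4xi32>
/// %2 = vector.extract_strided_slice %arg0
///        {offsets = [4], sizes = [4], strides = [1]}
///        : vector<8xi32> to vector<4xi32>
/// %3 = vector.insert %2, %1 [1] : vector<4xi32> into vector<2x4xi32>
/// ```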
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};
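
/// Increments the index `idx` at position `dimIdx` by `initialStep`. When a
/// dimension wraps past its size in `tp`, it is reset to zero and a carry of
/// one is propagated to the next more-major dimension.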
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int step = initialStep;
  for (int d = dimIdx; d >= 0; d--) {
    idx[d] += step;
    if (idx[d] >= tp.getDimSize(d)) {
      idx[d] = 0;
      step = 1;
    } else {
      break;
    }
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
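//
// For example (an illustrative sketch; the exact printed IR may differ), a
// cast such as
//   %0 = vector.shape_cast %arg0 : vector<2x3xi32> to vector<3x2xi32>
// is unrolled into a chain of scalar extract/insert pairs along the lines of:
//   %c  = arith.constant dense<0> : vector<3x2xi32>
//   %e0 = vector.extract %arg0[0, 0] : i32 from vector<2x3xi32>
//   %r0 = vector.insert %e0, %c [0, 0] : i32 into vector<3x2xi32>
//   %e1 = vector.extract %arg0[0, 1] : i32 from vector<2x3xi32>
//   %r1 = vector.insert %e1, %r0 [0, 1] : i32 into vector<3x2xi32>
//   ... and so on for the remaining four elements.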
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case 2D / 1D lowerings with better implementations.
    // TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and result
    // types have a single trailing scalable dimension. This is because there
    // is no legal representation of other scalable types in LLVM (and likely
    // won't be anytime soon). There are also (currently) no operations that
    // can index or extract from >= 2-D scalable vectors or scalable vectors
    // of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
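
  /// Returns true if `type` has rank >= 1, its trailing dimension is scalable,
  /// and none of its other dimensions are scalable.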
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace
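
// Typical usage (an illustrative sketch; the surrounding pass boilerplate is
// assumed, not part of this file):
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));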

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}