1//===- VectorLinearize.cpp - vector linearization transforms --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns and pass for linearizing ND vectors into 1D.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstdint>
#include <numeric>
#include <optional>
26
27using namespace mlir;
28
29static FailureOr<Attribute>
30linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
31 VectorType resType, Attribute value) {
32
33 if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(Val&: value)) {
34 if (resType.isScalable() && !isa<SplatElementsAttr>(Val: value))
35 return rewriter.notifyMatchFailure(
36 arg&: loc,
37 msg: "Cannot linearize a constant scalable vector that's not a splat");
38
39 return dstElementsAttr.reshape(newType: resType);
40 }
41
42 if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(Val&: value))
43 return poisonAttr;
44
45 return rewriter.notifyMatchFailure(arg&: loc, msg: "unsupported attr type");
46}
47
48namespace {
49
50struct LinearizeConstantLike final
51 : OpTraitConversionPattern<OpTrait::ConstantLike> {
52 using OpTraitConversionPattern::OpTraitConversionPattern;
53
54 LinearizeConstantLike(const TypeConverter &typeConverter,
55 MLIRContext *context, PatternBenefit benefit = 1)
56 : OpTraitConversionPattern(typeConverter, context, benefit) {}
57 LogicalResult
58 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59 ConversionPatternRewriter &rewriter) const override {
60 Location loc = op->getLoc();
61 if (op->getNumResults() != 1)
62 return rewriter.notifyMatchFailure(arg&: loc, msg: "expected 1 result");
63
64 const TypeConverter &typeConverter = *getTypeConverter();
65 auto resType =
66 typeConverter.convertType<VectorType>(t: op->getResult(idx: 0).getType());
67 assert(resType && "expected 1-D vector type");
68
69 StringAttr attrName = rewriter.getStringAttr(bytes: "value");
70 Attribute value = op->getAttr(name: attrName);
71 if (!value)
72 return rewriter.notifyMatchFailure(arg&: loc, msg: "no 'value' attr");
73
74 FailureOr<Attribute> newValue =
75 linearizeConstAttr(loc, rewriter, resType, value);
76 if (failed(Result: newValue))
77 return failure();
78
79 FailureOr<Operation *> convertResult =
80 convertOpResultTypes(op, /*operands=*/{}, converter: typeConverter, rewriter);
81 if (failed(Result: convertResult))
82 return failure();
83
84 Operation *newOp = *convertResult;
85 newOp->setAttr(name: attrName, value: *newValue);
86 rewriter.replaceOp(op, newOp);
87 return success();
88 }
89};
90
91struct LinearizeVectorizable final
92 : OpTraitConversionPattern<OpTrait::Vectorizable> {
93 using OpTraitConversionPattern::OpTraitConversionPattern;
94
95public:
96 LinearizeVectorizable(const TypeConverter &typeConverter,
97 MLIRContext *context, PatternBenefit benefit = 1)
98 : OpTraitConversionPattern(typeConverter, context, benefit) {}
99 LogicalResult
100 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101 ConversionPatternRewriter &rewriter) const override {
102 FailureOr<Operation *> newOp =
103 convertOpResultTypes(op, operands, converter: *getTypeConverter(), rewriter);
104 if (failed(Result: newOp))
105 return failure();
106
107 rewriter.replaceOp(op, newValues: (*newOp)->getResults());
108 return success();
109 }
110};
111
112template <typename TOp>
113static bool stridesAllOne(TOp op) {
114 static_assert(
115 std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116 std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117 "expected vector.extract_strided_slice or vector.insert_strided_slice");
118 ArrayAttr strides = op.getStrides();
119 return llvm::all_of(Range&: strides, P: isOneInteger);
120}
121
122/// Convert an array of attributes into a vector of integers, if possible.
123static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
124 if (!attrs)
125 return failure();
126 SmallVector<int64_t> ints;
127 ints.reserve(N: attrs.size());
128 for (auto attr : attrs) {
129 if (auto intAttr = dyn_cast<IntegerAttr>(Val&: attr)) {
130 ints.push_back(Elt: intAttr.getInt());
131 } else {
132 return failure();
133 }
134 }
135 return ints;
136}
137
138/// Consider inserting a vector of shape `small` into a vector of shape `large`,
139/// at position `offsets`: this function enumeratates all the indices in `large`
140/// that are written to. The enumeration is with row-major ordering.
141///
142/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143/// positions written to are (1,3) and (1,4), which have linearized indices 8
144/// and 9. So [8,9] is returned.
145///
146/// The length of the returned vector is equal to the number of elements in
147/// the shape `small` (i.e. the product of dimensions of `small`).
148SmallVector<int64_t> static getStridedSliceInsertionIndices(
149 ArrayRef<int64_t> small, ArrayRef<int64_t> large,
150 ArrayRef<int64_t> offsets) {
151
152 // Example of alignment between, `large`, `small` and `offsets`:
153 // large = 4, 5, 6, 7, 8
154 // small = 1, 6, 7, 8
155 // offsets = 2, 3, 0
156 //
157 // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158 assert((large.size() >= small.size()) &&
159 "rank of 'large' cannot be lower than rank of 'small'");
160 assert((large.size() >= offsets.size()) &&
161 "rank of 'large' cannot be lower than the number of offsets");
162 unsigned delta = large.size() - small.size();
163 unsigned nOffsets = offsets.size();
164 auto getSmall = [&](int64_t i) -> int64_t {
165 return i >= delta ? small[i - delta] : 1;
166 };
167 auto getOffset = [&](int64_t i) -> int64_t {
168 return i < nOffsets ? offsets[i] : 0;
169 };
170
171 // Using 2 vectors of indices, at each iteration populate the updated set of
172 // indices based on the old set of indices, and the size of the small vector
173 // in the current iteration.
174 SmallVector<int64_t> indices{0};
175 int64_t stride = 1;
176 for (int i = large.size() - 1; i >= 0; --i) {
177 int64_t currentSize = indices.size();
178 int64_t smallSize = getSmall(i);
179 int64_t nextSize = currentSize * smallSize;
180 SmallVector<int64_t> nextIndices(nextSize);
181 int64_t *base = nextIndices.begin();
182 int64_t offset = getOffset(i) * stride;
183 for (int j = 0; j < smallSize; ++j) {
184 for (int k = 0; k < currentSize; ++k) {
185 base[k] = indices[k] + offset;
186 }
187 offset += stride;
188 base += currentSize;
189 }
190 stride *= large[i];
191 indices = std::move(nextIndices);
192 }
193 return indices;
194}
195
196/// This pattern converts a vector.extract_strided_slice operation into a
197/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
198///
199/// For example, the following:
200///
201/// ```
202/// vector.extract_strided_slice %source
203/// { offsets = [..], strides = [..], sizes = [..] }
204/// ```
205///
206/// is converted to :
207/// ```
208/// %source_1d = vector.shape_cast %source
209/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
210/// %out_nd = vector.shape_cast %out_1d
211/// ```
212///
213/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
214/// vector.extract_strided_slice operation.
215struct LinearizeVectorExtractStridedSlice final
216 : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
217 using OpConversionPattern::OpConversionPattern;
218 LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
219 MLIRContext *context,
220 PatternBenefit benefit = 1)
221 : OpConversionPattern(typeConverter, context, benefit) {}
222
223 LogicalResult
224 matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
225 OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter) const override {
227
228 VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229 t: extractStridedSliceOp.getType());
230 assert(flatOutputType && "vector type expected");
231
232 // Expect a legalization failure if the strides are not all 1 (if ever the
233 // verifier for extract_strided_slice allows non-1 strides).
234 if (!stridesAllOne(op: extractStridedSliceOp)) {
235 return rewriter.notifyMatchFailure(
236 arg&: extractStridedSliceOp,
237 msg: "extract_strided_slice with strides != 1 not supported");
238 }
239
240 FailureOr<SmallVector<int64_t>> offsets =
241 intsFromArrayAttr(attrs: extractStridedSliceOp.getOffsets());
242 if (failed(Result: offsets)) {
243 return rewriter.notifyMatchFailure(arg&: extractStridedSliceOp,
244 msg: "failed to get integer offsets");
245 }
246
247 ArrayRef<int64_t> inputShape =
248 extractStridedSliceOp.getSourceVectorType().getShape();
249
250 ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
251
252 SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
253 small: outputShape, large: inputShape, offsets: offsets.value());
254
255 Value srcVector = adaptor.getVector();
256 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
257 op: extractStridedSliceOp, args&: flatOutputType, args&: srcVector, args&: srcVector, args&: indices);
258 return success();
259 }
260};
261
262/// This pattern converts a vector.insert_strided_slice operation into a
263/// vector.shuffle operation that has rank-1 (linearized) operands and result.
264///
265/// For example, the following:
266/// ```
267/// %0 = vector.insert_strided_slice %to_store, %into
268/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
269/// : vector<2x2xi8> into vector<2x1x3x2xi8>
270/// ```
271///
272/// is converted to
273/// ```
274/// %to_store_1d
275/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
276/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
277/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
278/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
279/// ```
280///
281/// where shuffle_indices_1d in this case is
282/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
283/// ^^^^^^^^^^^^^^
284/// to_store_1d
285///
286struct LinearizeVectorInsertStridedSlice final
287 : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
288 using OpConversionPattern::OpConversionPattern;
289 LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
290 MLIRContext *context,
291 PatternBenefit benefit = 1)
292 : OpConversionPattern(typeConverter, context, benefit) {}
293
294 LogicalResult
295 matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
296 OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298
299 // Expect a legalization failure if the strides are not all 1 (if ever the
300 // verifier for insert_strided_slice allows non-1 strides).
301 if (!stridesAllOne(op: insertStridedSliceOp)) {
302 return rewriter.notifyMatchFailure(
303 arg&: insertStridedSliceOp,
304 msg: "insert_strided_slice with strides != 1 not supported");
305 }
306
307 VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
308 ArrayRef<int64_t> inputShape = inputType.getShape();
309
310 VectorType outputType = insertStridedSliceOp.getType();
311 ArrayRef<int64_t> outputShape = outputType.getShape();
312 int64_t nOutputElements = outputType.getNumElements();
313
314 FailureOr<SmallVector<int64_t>> offsets =
315 intsFromArrayAttr(attrs: insertStridedSliceOp.getOffsets());
316 if (failed(Result: offsets)) {
317 return rewriter.notifyMatchFailure(arg&: insertStridedSliceOp,
318 msg: "failed to get integer offsets");
319 }
320 SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
321 small: inputShape, large: outputShape, offsets: offsets.value());
322
323 SmallVector<int64_t> indices(nOutputElements);
324 std::iota(first: indices.begin(), last: indices.end(), value: 0);
325 for (auto [index, sliceIndex] : llvm::enumerate(First&: sliceIndices)) {
326 indices[sliceIndex] = index + nOutputElements;
327 }
328
329 Value flatToStore = adaptor.getValueToStore();
330 Value flatDest = adaptor.getDest();
331 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(op: insertStridedSliceOp,
332 args: flatDest.getType(), args&: flatDest,
333 args&: flatToStore, args&: indices);
334 return success();
335 }
336};
337
338/// This pattern converts the ShuffleOp that works on nD (n > 1)
339/// vectors to a ShuffleOp that works on linearized vectors.
340/// Following,
341/// vector.shuffle %v1, %v2 [ shuffle_indices ]
342/// is converted to :
343/// %v1_1d = vector.shape_cast %v1
344/// %v2_1d = vector.shape_cast %v2
345/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
346/// %out_nd = vector.shape_cast %out_1d
347// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
348/// of the original shuffle operation.
349struct LinearizeVectorShuffle final
350 : public OpConversionPattern<vector::ShuffleOp> {
351 using OpConversionPattern::OpConversionPattern;
352 LinearizeVectorShuffle(const TypeConverter &typeConverter,
353 MLIRContext *context, PatternBenefit benefit = 1)
354 : OpConversionPattern(typeConverter, context, benefit) {}
355
356 LogicalResult
357 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 VectorType dstType =
360 getTypeConverter()->convertType<VectorType>(t: shuffleOp.getType());
361 assert(dstType && "vector type destination expected.");
362
363 Value vec1 = adaptor.getV1();
364 Value vec2 = adaptor.getV2();
365 int shuffleSliceLen = 1;
366 int rank = shuffleOp.getV1().getType().getRank();
367
368 // If rank > 1, we need to do the shuffle in the granularity of slices
369 // instead of scalars. Size of the slice is equal to the rank-1 innermost
370 // dims. Mask of the shuffle op specifies which slice to take from the
371 // outermost dim.
372 if (rank > 1) {
373 llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
374 for (unsigned i = 1; i < shape.size(); ++i) {
375 shuffleSliceLen *= shape[i];
376 }
377 }
378
379 // For each value in the mask, we generate the indices of the source vectors
380 // that need to be shuffled to the destination vector. If shuffleSliceLen >
381 // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
382 // elements) instead of scalars.
383 ArrayRef<int64_t> mask = shuffleOp.getMask();
384 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
385 llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
386 for (auto [i, value] : llvm::enumerate(First&: mask)) {
387 std::iota(first: indices.begin() + shuffleSliceLen * i,
388 last: indices.begin() + shuffleSliceLen * (i + 1),
389 value: shuffleSliceLen * value);
390 }
391
392 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(op: shuffleOp, args&: dstType, args&: vec1,
393 args&: vec2, args&: indices);
394 return success();
395 }
396};
397
398/// This pattern linearizes `vector.extract` operations. It generates a 1-D
399/// version of the `vector.extract` operation when extracting a scalar from a
400/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
401/// subvector from a larger vector.
402///
403/// Example #1:
404///
405/// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
406///
407/// is converted to:
408///
409/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411/// 24, 25, 26, 27, 28, 29, 30, 31] :
412/// vector<32xf32>, vector<32xf32>
413/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414///
415/// Example #2:
416///
417/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418///
419/// is converted to:
420///
421/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422/// %1 = vector.extract %0[6] : i32 from vector<8xi32>
423///
424struct LinearizeVectorExtract final
425 : public OpConversionPattern<vector::ExtractOp> {
426 using OpConversionPattern::OpConversionPattern;
427 LinearizeVectorExtract(const TypeConverter &typeConverter,
428 MLIRContext *context, PatternBenefit benefit = 1)
429 : OpConversionPattern(typeConverter, context, benefit) {}
430 LogicalResult
431 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter) const override {
433 Type dstTy = getTypeConverter()->convertType(t: extractOp.getType());
434 assert(dstTy && "expected 1-D vector type");
435
436 // Dynamic position is not supported.
437 if (extractOp.hasDynamicPosition())
438 return rewriter.notifyMatchFailure(arg&: extractOp,
439 msg: "dynamic position is not supported.");
440
441 llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
442 int64_t size = extractOp.getVector().getType().getNumElements();
443
444 // Compute linearized offset.
445 int64_t linearizedOffset = 0;
446 llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
447 for (auto [i, off] : llvm::enumerate(First&: offsets)) {
448 size /= shape[i];
449 linearizedOffset += offsets[i] * size;
450 }
451
452 Value srcVector = adaptor.getVector();
453 if (!isa<VectorType>(Val: extractOp.getType())) {
454 // Scalar case: generate a 1-D extract.
455 Value result = rewriter.createOrFold<vector::ExtractOp>(
456 location: extractOp.getLoc(), args&: srcVector, args&: linearizedOffset);
457 rewriter.replaceOp(op: extractOp, newValues: result);
458 return success();
459 }
460
461 // Vector case: generate a shuffle.
462
463 llvm::SmallVector<int64_t, 2> indices(size);
464 std::iota(first: indices.begin(), last: indices.end(), value: linearizedOffset);
465 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(op: extractOp, args&: dstTy, args&: srcVector,
466 args&: srcVector, args&: indices);
467
468 return success();
469 }
470};
471
472/// This pattern linearizes `vector.insert` operations. It generates a 1-D
473/// version of the `vector.insert` operation when inserting a scalar into a
474/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
475/// vector into another vector.
476///
477/// Example #1:
478///
479/// %0 = vector.insert %source, %destination[0] :
480/// vector<2x4xf32> into vector<2x2x4xf32>
481///
482/// is converted to:
483///
484/// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
485/// %1 = vector.shape_cast %destination :
486/// vector<2x2x4xf32> to vector<16xf32>
487/// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23
488/// 8, 9, 10, 11, 12, 13, 14, 15] :
489/// vector<16xf32>, vector<8xf32>
490/// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
491///
492/// Example #2:
493///
494/// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
495///
496/// is converted to:
497///
498/// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
499/// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
500/// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
501///
502struct LinearizeVectorInsert final
503 : public OpConversionPattern<vector::InsertOp> {
504 using OpConversionPattern::OpConversionPattern;
505 LinearizeVectorInsert(const TypeConverter &typeConverter,
506 MLIRContext *context, PatternBenefit benefit = 1)
507 : OpConversionPattern(typeConverter, context, benefit) {}
508 LogicalResult
509 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter) const override {
511 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
512 t: insertOp.getDestVectorType());
513 assert(dstTy && "vector type destination expected.");
514
515 // Dynamic position is not supported.
516 if (insertOp.hasDynamicPosition())
517 return rewriter.notifyMatchFailure(arg&: insertOp,
518 msg: "dynamic position is not supported.");
519 auto srcTy = insertOp.getValueToStoreType();
520 auto srcAsVec = dyn_cast<VectorType>(Val&: srcTy);
521 uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
522
523 auto dstShape = insertOp.getDestVectorType().getShape();
524 const auto dstSize = insertOp.getDestVectorType().getNumElements();
525 auto dstSizeForOffsets = dstSize;
526
527 // Compute linearized offset.
528 int64_t linearizedOffset = 0;
529 auto offsetsNd = insertOp.getStaticPosition();
530 for (auto [dim, offset] : llvm::enumerate(First&: offsetsNd)) {
531 dstSizeForOffsets /= dstShape[dim];
532 linearizedOffset += offset * dstSizeForOffsets;
533 }
534
535 Location loc = insertOp.getLoc();
536 Value valueToStore = adaptor.getValueToStore();
537
538 if (!isa<VectorType>(Val: valueToStore.getType())) {
539 // Scalar case: generate a 1-D insert.
540 Value result = rewriter.createOrFold<vector::InsertOp>(
541 location: loc, args&: valueToStore, args: adaptor.getDest(), args&: linearizedOffset);
542 rewriter.replaceOp(op: insertOp, newValues: result);
543 return success();
544 }
545
546 // Vector case: generate a shuffle.
547 llvm::SmallVector<int64_t, 2> indices(dstSize);
548 auto *origValsUntil = indices.begin();
549 std::advance(i&: origValsUntil, n: linearizedOffset);
550
551 // Original values that remain [0, offset).
552 std::iota(first: indices.begin(), last: origValsUntil, value: 0);
553 auto *newValsUntil = origValsUntil;
554 std::advance(i&: newValsUntil, n: srcSize);
555 // New values [offset, offset+srcNumElements).
556 std::iota(first: origValsUntil, last: newValsUntil, value: dstSize);
557 // The rest of original values [offset+srcNumElements, end);
558 std::iota(first: newValsUntil, last: indices.end(), value: linearizedOffset + srcSize);
559
560 Value result = rewriter.createOrFold<vector::ShuffleOp>(
561 location: loc, args&: dstTy, args: adaptor.getDest(), args&: valueToStore, args&: indices);
562
563 rewriter.replaceOp(op: insertOp, newValues: result);
564 return success();
565 }
566};
567
568/// This pattern converts the BitCastOp that works on nD (n > 1)
569/// vectors to a BitCastOp that works on linearized vectors.
570/// Following,
571/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
572/// is converted to :
573/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
574/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
575/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
576struct LinearizeVectorBitCast final
577 : public OpConversionPattern<vector::BitCastOp> {
578 using OpConversionPattern::OpConversionPattern;
579 LinearizeVectorBitCast(const TypeConverter &typeConverter,
580 MLIRContext *context, PatternBenefit benefit = 1)
581 : OpConversionPattern(typeConverter, context, benefit) {}
582 LogicalResult
583 matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
584 ConversionPatternRewriter &rewriter) const override {
585 auto resType = getTypeConverter()->convertType(t: castOp.getType());
586 assert(resType && "expected 1-D vector type");
587 rewriter.replaceOpWithNewOp<vector::BitCastOp>(op: castOp, args&: resType,
588 args: adaptor.getSource());
589 return mlir::success();
590 }
591};
592
593/// This pattern converts the SplatOp to work on a linearized vector.
594/// Following,
595/// vector.splat %value : vector<4x4xf32>
596/// is converted to:
597/// %out_1d = vector.splat %value : vector<16xf32>
598/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
599struct LinearizeVectorSplat final
600 : public OpConversionPattern<vector::SplatOp> {
601 using OpConversionPattern::OpConversionPattern;
602
603 LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
604 PatternBenefit benefit = 1)
605 : OpConversionPattern(typeConverter, context, benefit) {}
606
607 LogicalResult
608 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const override {
610 auto dstTy = getTypeConverter()->convertType(t: splatOp.getType());
611 if (!dstTy)
612 return rewriter.notifyMatchFailure(arg&: splatOp, msg: "cannot convert type.");
613 rewriter.replaceOpWithNewOp<vector::SplatOp>(op: splatOp, args: adaptor.getInput(),
614 args&: dstTy);
615 return success();
616 }
617};
618
619/// This pattern converts the CreateMaskOp to work on a linearized vector.
620/// It currently supports only 2D masks with a unit outer dimension.
621/// Following,
622/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
623/// is converted to:
624/// %zero = arith.constant 0 : index
625/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
626/// %index = arith.index_cast %cmpi : i1 to index
627/// %mul = arith.andi %index, %arg1 : index
628/// %mask = vector.create_mask %mul : vector<4xi1>
629/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
630struct LinearizeVectorCreateMask final
631 : OpConversionPattern<vector::CreateMaskOp> {
632 using OpConversionPattern::OpConversionPattern;
633
634 LinearizeVectorCreateMask(const TypeConverter &typeConverter,
635 MLIRContext *context, PatternBenefit benefit = 1)
636 : OpConversionPattern(typeConverter, context, benefit) {}
637
638 LogicalResult
639 matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter) const override {
641 Location loc = createMaskOp.getLoc();
642 VectorType srcTy = createMaskOp.getType();
643 auto srcShape = srcTy.getShape();
644 if (srcShape.size() != 2)
645 return rewriter.notifyMatchFailure(arg&: createMaskOp,
646 msg: "only 2D mask is supported.");
647
648 if (srcShape[0] != 1)
649 return rewriter.notifyMatchFailure(
650 arg&: createMaskOp, msg: "only unit outer dimension is supported.");
651
652 auto dstTy = getTypeConverter()->convertType(t: srcTy);
653 if (!dstTy)
654 return rewriter.notifyMatchFailure(arg&: createMaskOp, msg: "cannot convert type.");
655
656 // Compare the first operand with 0. If it is greater than 0, the
657 // corresponding mask element is set to true, otherwise false.
658 // The result of the comparison is then multiplied with
659 // the second operand of create_mask to get the 1D mask.
660 auto firstOperand = adaptor.getOperands().front();
661 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(location: loc, args: 0);
662 auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
663 location: loc, args: mlir::arith::CmpIPredicate::sgt, args&: firstOperand, args&: zero);
664 auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
665 location: loc, args: rewriter.getIndexType(), args&: isNonZero);
666 auto secondOperand = adaptor.getOperands().back();
667 auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
668 location: loc, args: rewriter.getIndexType(), args&: isNonZeroIndex, args&: secondOperand);
669
670 auto newMask =
671 rewriter.create<mlir::vector::CreateMaskOp>(location: loc, args&: dstTy, args&: maskSize);
672 rewriter.replaceOp(op: createMaskOp, newOp: newMask);
673 return success();
674 }
675};
676
677/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
678/// It currently supports linearization where all but the last dimension are 1
679/// The following,
680/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
681/// is converted to:
682/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
683/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
684/// For generic cases, the vector unroll pass should be used to unroll the load
685/// to vector<1x1x...xN> form and then linearized
686struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
687 using OpConversionPattern::OpConversionPattern;
688 LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
689 PatternBenefit benefit = 1)
690 : OpConversionPattern(typeConverter, context, benefit) {}
691
692 LogicalResult
693 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
694 ConversionPatternRewriter &rewriter) const override {
695 VectorType vecTy = loadOp.getType();
696 if (!vecTy)
697 return rewriter.notifyMatchFailure(arg&: loadOp, msg: "expected vector type");
698
699 auto shape = vecTy.getShape();
700 auto scalableDims = vecTy.getScalableDims();
701 // All but the last dim must be 1, and only the last dim may be scalable (if
702 // any).
703 if (!llvm::all_of(Range: shape.drop_back(N: 1), P: [](auto d) { return d == 1; }))
704 return rewriter.notifyMatchFailure(arg&: loadOp,
705 msg: "only vector<1x1x...xN> supported");
706
707 if (llvm::any_of(Range: scalableDims.drop_back(N: 1), P: [](bool s) { return s; }))
708 return rewriter.notifyMatchFailure(arg&: loadOp,
709 msg: "only innermost dim may be scalable");
710
711 auto linearTy = typeConverter->convertType<VectorType>(t: vecTy);
712
713 auto newLoad = rewriter.create<vector::LoadOp>(
714 location: loadOp.getLoc(), args&: linearTy, args: adaptor.getBase(), args: adaptor.getIndices());
715 rewriter.replaceOp(op: loadOp, newValues: newLoad.getResult());
716 return success();
717 }
718};
719
720/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
721/// It currently supports linearization where all but the last dimension are 1
722/// The following,
723/// vector.store %arg0, %arg1[%c0, %c0]s
724/// : vector<1x4xf32>, memref<1x4xf32>
725/// is converted to:
726/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
727/// vector.store %arg0, %arg1[%c0, %c0]
728/// : vector<4xf32>, memref<1x4xf32>
729/// For generic cases, the vector unroll pass should be used to unroll the store
730/// to vector<1x1x...xN> form and then linearized
731struct LinearizeVectorStore final
732 : public OpConversionPattern<vector::StoreOp> {
733 using OpConversionPattern::OpConversionPattern;
734 LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
735 PatternBenefit benefit = 1)
736 : OpConversionPattern(typeConverter, context, benefit) {}
737
738 LogicalResult
739 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
740 ConversionPatternRewriter &rewriter) const override {
741 VectorType vecTy = storeOp.getValueToStore().getType();
742 if (!vecTy)
743 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "expected vector type");
744
745 auto shape = vecTy.getShape();
746 auto scalableDims = vecTy.getScalableDims();
747 // All but the last dim must be 1, and only the last dim may be scalable (if
748 // any).
749 if (!llvm::all_of(Range: shape.drop_back(N: 1), P: [](auto d) { return d == 1; }))
750 return rewriter.notifyMatchFailure(arg&: storeOp,
751 msg: "only vector<1x1x...xN> supported");
752
753 if (llvm::any_of(Range: scalableDims.drop_back(N: 1), P: [](bool s) { return s; }))
754 return rewriter.notifyMatchFailure(arg&: storeOp,
755 msg: "only innermost dim may be scalable");
756
757 rewriter.replaceOpWithNewOp<vector::StoreOp>(
758 op: storeOp, args: adaptor.getValueToStore(), args: adaptor.getBase(),
759 args: adaptor.getIndices());
760 return success();
761 }
762};
763
764} // namespace
765
766/// This method defines the set of operations that are linearizable, and hence
767/// that are considered illegal for the conversion target.
768static bool isLinearizable(Operation *op) {
769
770 // Only ops that are in the vector dialect, are ConstantLike, or
771 // are Vectorizable might be linearized currently.
772 StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
773 StringRef opDialect = op->getDialect()->getNamespace();
774 bool supported = (opDialect == vectorDialect) ||
775 op->hasTrait<OpTrait::ConstantLike>() ||
776 op->hasTrait<OpTrait::Vectorizable>();
777 if (!supported)
778 return false;
779
780 return TypeSwitch<Operation *, bool>(op)
781 // As type legalization is done with vector.shape_cast, shape_cast
782 // itself cannot be linearized (will create new shape_casts to linearize
783 // ad infinitum).
784 .Case<vector::ShapeCastOp>(caseFn: [&](auto) { return false; })
785 // The operations
786 // - vector.extract_strided_slice
787 // - vector.extract
788 // - vector.insert_strided_slice
789 // - vector.insert
790 // are linearized to a rank-1 vector.shuffle by the current patterns.
791 // vector.shuffle only supports fixed size vectors, so it is impossible to
792 // use this approach to linearize these ops if they operate on scalable
793 // vectors.
794 .Case<vector::ExtractStridedSliceOp>(
795 caseFn: [&](vector::ExtractStridedSliceOp extractOp) {
796 return !extractOp.getType().isScalable();
797 })
798 .Case<vector::InsertStridedSliceOp>(
799 caseFn: [&](vector::InsertStridedSliceOp insertOp) {
800 return !insertOp.getType().isScalable();
801 })
802 .Case<vector::InsertOp>(caseFn: [&](vector::InsertOp insertOp) {
803 return !insertOp.getType().isScalable();
804 })
805 .Case<vector::ExtractOp>(caseFn: [&](vector::ExtractOp extractOp) {
806 return !extractOp.getSourceVectorType().isScalable();
807 })
808 .Default(defaultFn: [&](auto) { return true; });
809}
810
811void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
812 ConversionTarget &target) {
813
814 auto convertType = [](Type type) -> std::optional<Type> {
815 VectorType vectorType = dyn_cast<VectorType>(Val&: type);
816 if (!vectorType || !isLinearizableVector(type: vectorType))
817 return type;
818
819 VectorType linearizedType =
820 VectorType::get(shape: vectorType.getNumElements(),
821 elementType: vectorType.getElementType(), scalableDims: vectorType.isScalable());
822 return linearizedType;
823 };
824 typeConverter.addConversion(callback&: convertType);
825
826 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
827 Location loc) -> Value {
828 if (inputs.size() != 1)
829 return nullptr;
830
831 Value value = inputs.front();
832 if (!isa<VectorType>(Val: type) || !isa<VectorType>(Val: value.getType()))
833 return nullptr;
834
835 return builder.create<vector::ShapeCastOp>(location: loc, args&: type, args&: value);
836 };
837 typeConverter.addSourceMaterialization(callback&: materializeCast);
838 typeConverter.addTargetMaterialization(callback&: materializeCast);
839
840 target.markUnknownOpDynamicallyLegal(
841 fn: [=](Operation *op) -> std::optional<bool> {
842 if (!isLinearizable(op))
843 return true;
844 // This will return true if, for all operand and result types `t`,
845 // convertType(t) = t. This is true if there are no rank>=2 vectors.
846 return typeConverter.isLegal(op);
847 });
848}
849
850void mlir::vector::populateVectorLinearizeBasePatterns(
851 const TypeConverter &typeConverter, const ConversionTarget &target,
852 RewritePatternSet &patterns) {
853 patterns
854 .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
855 LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
856 LinearizeVectorStore>(arg: typeConverter, args: patterns.getContext());
857}
858
859void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
860 const TypeConverter &typeConverter, const ConversionTarget &target,
861 RewritePatternSet &patterns) {
862 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
863 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
864 LinearizeVectorInsertStridedSlice>(arg: typeConverter,
865 args: patterns.getContext());
866}
867

// Source: mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp