1//===- VectorLinearize.cpp - vector linearization transforms --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns and pass for linearizing ND vectors into 1D.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/UB/IR/UBOps.h"
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/IR/TypeUtilities.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/ArrayRef.h"
23#include <cstdint>
24#include <numeric>
25#include <optional>
26
27using namespace mlir;
28
29static FailureOr<Attribute>
30linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
31 VectorType resType, Attribute value) {
32
33 if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(Val&: value)) {
34 if (resType.isScalable() && !isa<SplatElementsAttr>(Val: value))
35 return rewriter.notifyMatchFailure(
36 arg&: loc,
37 msg: "Cannot linearize a constant scalable vector that's not a splat");
38
39 return dstElementsAttr.reshape(resType);
40 }
41
42 if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
43 return poisonAttr;
44
45 return rewriter.notifyMatchFailure(arg&: loc, msg: "unsupported attr type");
46}
47
48namespace {
49
50struct LinearizeConstantLike final
51 : OpTraitConversionPattern<OpTrait::ConstantLike> {
52 using OpTraitConversionPattern::OpTraitConversionPattern;
53
54 LinearizeConstantLike(const TypeConverter &typeConverter,
55 MLIRContext *context, PatternBenefit benefit = 1)
56 : OpTraitConversionPattern(typeConverter, context, benefit) {}
57 LogicalResult
58 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59 ConversionPatternRewriter &rewriter) const override {
60 Location loc = op->getLoc();
61 if (op->getNumResults() != 1)
62 return rewriter.notifyMatchFailure(arg&: loc, msg: "expected 1 result");
63
64 const TypeConverter &typeConverter = *getTypeConverter();
65 auto resType =
66 typeConverter.convertType<VectorType>(op->getResult(idx: 0).getType());
67 assert(resType && "expected 1-D vector type");
68
69 StringAttr attrName = rewriter.getStringAttr("value");
70 Attribute value = op->getAttr(attrName);
71 if (!value)
72 return rewriter.notifyMatchFailure(arg&: loc, msg: "no 'value' attr");
73
74 FailureOr<Attribute> newValue =
75 linearizeConstAttr(loc, rewriter, resType, value);
76 if (failed(Result: newValue))
77 return failure();
78
79 FailureOr<Operation *> convertResult =
80 convertOpResultTypes(op, /*operands=*/{}, converter: typeConverter, rewriter);
81 if (failed(Result: convertResult))
82 return failure();
83
84 Operation *newOp = *convertResult;
85 newOp->setAttr(attrName, *newValue);
86 rewriter.replaceOp(op, newOp);
87 return success();
88 }
89};
90
91struct LinearizeVectorizable final
92 : OpTraitConversionPattern<OpTrait::Vectorizable> {
93 using OpTraitConversionPattern::OpTraitConversionPattern;
94
95public:
96 LinearizeVectorizable(const TypeConverter &typeConverter,
97 MLIRContext *context, PatternBenefit benefit = 1)
98 : OpTraitConversionPattern(typeConverter, context, benefit) {}
99 LogicalResult
100 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101 ConversionPatternRewriter &rewriter) const override {
102 FailureOr<Operation *> newOp =
103 convertOpResultTypes(op, operands, converter: *getTypeConverter(), rewriter);
104 if (failed(Result: newOp))
105 return failure();
106
107 rewriter.replaceOp(op, newValues: (*newOp)->getResults());
108 return success();
109 }
110};
111
112template <typename TOp>
113static bool stridesAllOne(TOp op) {
114 static_assert(
115 std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116 std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117 "expected vector.extract_strided_slice or vector.insert_strided_slice");
118 ArrayAttr strides = op.getStrides();
119 return llvm::all_of(strides, isOneInteger);
120}
121
122/// Convert an array of attributes into a vector of integers, if possible.
123static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
124 if (!attrs)
125 return failure();
126 SmallVector<int64_t> ints;
127 ints.reserve(N: attrs.size());
128 for (auto attr : attrs) {
129 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130 ints.push_back(intAttr.getInt());
131 } else {
132 return failure();
133 }
134 }
135 return ints;
136}
137
138/// Consider inserting a vector of shape `small` into a vector of shape `large`,
139/// at position `offsets`: this function enumeratates all the indices in `large`
140/// that are written to. The enumeration is with row-major ordering.
141///
142/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143/// positions written to are (1,3) and (1,4), which have linearized indices 8
144/// and 9. So [8,9] is returned.
145///
146/// The length of the returned vector is equal to the number of elements in
147/// the shape `small` (i.e. the product of dimensions of `small`).
148SmallVector<int64_t> static getStridedSliceInsertionIndices(
149 ArrayRef<int64_t> small, ArrayRef<int64_t> large,
150 ArrayRef<int64_t> offsets) {
151
152 // Example of alignment between, `large`, `small` and `offsets`:
153 // large = 4, 5, 6, 7, 8
154 // small = 1, 6, 7, 8
155 // offsets = 2, 3, 0
156 //
157 // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158 assert((large.size() >= small.size()) &&
159 "rank of 'large' cannot be lower than rank of 'small'");
160 assert((large.size() >= offsets.size()) &&
161 "rank of 'large' cannot be lower than the number of offsets");
162 unsigned delta = large.size() - small.size();
163 unsigned nOffsets = offsets.size();
164 auto getSmall = [&](int64_t i) -> int64_t {
165 return i >= delta ? small[i - delta] : 1;
166 };
167 auto getOffset = [&](int64_t i) -> int64_t {
168 return i < nOffsets ? offsets[i] : 0;
169 };
170
171 // Using 2 vectors of indices, at each iteration populate the updated set of
172 // indices based on the old set of indices, and the size of the small vector
173 // in the current iteration.
174 SmallVector<int64_t> indices{0};
175 int64_t stride = 1;
176 for (int i = large.size() - 1; i >= 0; --i) {
177 int64_t currentSize = indices.size();
178 int64_t smallSize = getSmall(i);
179 int64_t nextSize = currentSize * smallSize;
180 SmallVector<int64_t> nextIndices(nextSize);
181 int64_t *base = nextIndices.begin();
182 int64_t offset = getOffset(i) * stride;
183 for (int j = 0; j < smallSize; ++j) {
184 for (int k = 0; k < currentSize; ++k) {
185 base[k] = indices[k] + offset;
186 }
187 offset += stride;
188 base += currentSize;
189 }
190 stride *= large[i];
191 indices = std::move(nextIndices);
192 }
193 return indices;
194}
195
196/// This pattern converts a vector.extract_strided_slice operation into a
197/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
198///
199/// For example, the following:
200///
201/// ```
202/// vector.extract_strided_slice %source
203/// { offsets = [..], strides = [..], sizes = [..] }
204/// ```
205///
206/// is converted to :
207/// ```
208/// %source_1d = vector.shape_cast %source
209/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
210/// %out_nd = vector.shape_cast %out_1d
211/// ```
212///
213/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
214/// vector.extract_strided_slice operation.
215struct LinearizeVectorExtractStridedSlice final
216 : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
217 using OpConversionPattern::OpConversionPattern;
218 LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
219 MLIRContext *context,
220 PatternBenefit benefit = 1)
221 : OpConversionPattern(typeConverter, context, benefit) {}
222
223 LogicalResult
224 matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
225 OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter) const override {
227
228 VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229 extractStridedSliceOp.getType());
230 assert(flatOutputType && "vector type expected");
231
232 // Expect a legalization failure if the strides are not all 1 (if ever the
233 // verifier for extract_strided_slice allows non-1 strides).
234 if (!stridesAllOne(extractStridedSliceOp)) {
235 return rewriter.notifyMatchFailure(
236 extractStridedSliceOp,
237 "extract_strided_slice with strides != 1 not supported");
238 }
239
240 FailureOr<SmallVector<int64_t>> offsets =
241 intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242 if (failed(Result: offsets)) {
243 return rewriter.notifyMatchFailure(extractStridedSliceOp,
244 "failed to get integer offsets");
245 }
246
247 ArrayRef<int64_t> inputShape =
248 extractStridedSliceOp.getSourceVectorType().getShape();
249
250 ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
251
252 SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
253 small: outputShape, large: inputShape, offsets: offsets.value());
254
255 Value srcVector = adaptor.getVector();
256 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
257 extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
258 return success();
259 }
260};
261
262/// This pattern converts a vector.insert_strided_slice operation into a
263/// vector.shuffle operation that has rank-1 (linearized) operands and result.
264///
265/// For example, the following:
266/// ```
267/// %0 = vector.insert_strided_slice %to_store, %into
268/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
269/// : vector<2x2xi8> into vector<2x1x3x2xi8>
270/// ```
271///
272/// is converted to
273/// ```
274/// %to_store_1d
275/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
276/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
277/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
278/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
279/// ```
280///
281/// where shuffle_indices_1d in this case is
282/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
283/// ^^^^^^^^^^^^^^
284/// to_store_1d
285///
286struct LinearizeVectorInsertStridedSlice final
287 : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
288 using OpConversionPattern::OpConversionPattern;
289 LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
290 MLIRContext *context,
291 PatternBenefit benefit = 1)
292 : OpConversionPattern(typeConverter, context, benefit) {}
293
294 LogicalResult
295 matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
296 OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298
299 // Expect a legalization failure if the strides are not all 1 (if ever the
300 // verifier for insert_strided_slice allows non-1 strides).
301 if (!stridesAllOne(insertStridedSliceOp)) {
302 return rewriter.notifyMatchFailure(
303 insertStridedSliceOp,
304 "insert_strided_slice with strides != 1 not supported");
305 }
306
307 VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
308 ArrayRef<int64_t> inputShape = inputType.getShape();
309
310 VectorType outputType = insertStridedSliceOp.getType();
311 ArrayRef<int64_t> outputShape = outputType.getShape();
312 int64_t nOutputElements = outputType.getNumElements();
313
314 FailureOr<SmallVector<int64_t>> offsets =
315 intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316 if (failed(Result: offsets)) {
317 return rewriter.notifyMatchFailure(insertStridedSliceOp,
318 "failed to get integer offsets");
319 }
320 SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
321 small: inputShape, large: outputShape, offsets: offsets.value());
322
323 SmallVector<int64_t> indices(nOutputElements);
324 std::iota(first: indices.begin(), last: indices.end(), value: 0);
325 for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
326 indices[sliceIndex] = index + nOutputElements;
327 }
328
329 Value flatToStore = adaptor.getValueToStore();
330 Value flatDest = adaptor.getDest();
331 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
332 flatDest.getType(), flatDest,
333 flatToStore, indices);
334 return success();
335 }
336};
337
338/// This pattern converts the ShuffleOp that works on nD (n > 1)
339/// vectors to a ShuffleOp that works on linearized vectors.
340/// Following,
341/// vector.shuffle %v1, %v2 [ shuffle_indices ]
342/// is converted to :
343/// %v1_1d = vector.shape_cast %v1
344/// %v2_1d = vector.shape_cast %v2
345/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
346/// %out_nd = vector.shape_cast %out_1d
347// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
348/// of the original shuffle operation.
349struct LinearizeVectorShuffle final
350 : public OpConversionPattern<vector::ShuffleOp> {
351 using OpConversionPattern::OpConversionPattern;
352 LinearizeVectorShuffle(const TypeConverter &typeConverter,
353 MLIRContext *context, PatternBenefit benefit = 1)
354 : OpConversionPattern(typeConverter, context, benefit) {}
355
356 LogicalResult
357 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 VectorType dstType =
360 getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
361 assert(dstType && "vector type destination expected.");
362
363 Value vec1 = adaptor.getV1();
364 Value vec2 = adaptor.getV2();
365 int shuffleSliceLen = 1;
366 int rank = shuffleOp.getV1().getType().getRank();
367
368 // If rank > 1, we need to do the shuffle in the granularity of slices
369 // instead of scalars. Size of the slice is equal to the rank-1 innermost
370 // dims. Mask of the shuffle op specifies which slice to take from the
371 // outermost dim.
372 if (rank > 1) {
373 llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
374 for (unsigned i = 1; i < shape.size(); ++i) {
375 shuffleSliceLen *= shape[i];
376 }
377 }
378
379 // For each value in the mask, we generate the indices of the source vectors
380 // that need to be shuffled to the destination vector. If shuffleSliceLen >
381 // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
382 // elements) instead of scalars.
383 ArrayRef<int64_t> mask = shuffleOp.getMask();
384 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
385 llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
386 for (auto [i, value] : llvm::enumerate(mask)) {
387 std::iota(indices.begin() + shuffleSliceLen * i,
388 indices.begin() + shuffleSliceLen * (i + 1),
389 shuffleSliceLen * value);
390 }
391
392 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
393 vec2, indices);
394 return success();
395 }
396};
397
398/// This pattern converts the ExtractOp to a ShuffleOp that works on a
399/// linearized vector.
400/// Following,
401/// vector.extract %source [ position ]
402/// is converted to :
403/// %source_1d = vector.shape_cast %source
404/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405/// %out_nd = vector.shape_cast %out_1d
406/// `shuffle_indices_1d` is computed using the position of the original extract.
407struct LinearizeVectorExtract final
408 : public OpConversionPattern<vector::ExtractOp> {
409 using OpConversionPattern::OpConversionPattern;
410 LinearizeVectorExtract(const TypeConverter &typeConverter,
411 MLIRContext *context, PatternBenefit benefit = 1)
412 : OpConversionPattern(typeConverter, context, benefit) {}
413 LogicalResult
414 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter) const override {
416 // Skip if result is not a vector type
417 if (!isa<VectorType>(extractOp.getType()))
418 return rewriter.notifyMatchFailure(extractOp,
419 "scalar extract not supported");
420 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
421 assert(dstTy && "expected 1-D vector type");
422
423 // Dynamic position is not supported.
424 if (extractOp.hasDynamicPosition())
425 return rewriter.notifyMatchFailure(extractOp,
426 "dynamic position is not supported.");
427
428 llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
429 int64_t size = extractOp.getVector().getType().getNumElements();
430
431 // Compute linearized offset.
432 int64_t linearizedOffset = 0;
433 llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
434 for (auto [i, off] : llvm::enumerate(offsets)) {
435 size /= shape[i];
436 linearizedOffset += offsets[i] * size;
437 }
438
439 llvm::SmallVector<int64_t, 2> indices(size);
440 std::iota(first: indices.begin(), last: indices.end(), value: linearizedOffset);
441 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
442 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
443
444 return success();
445 }
446};
447
448/// This pattern converts the InsertOp to a ShuffleOp that works on a
449/// linearized vector.
450/// Following,
451/// vector.insert %source %destination [ position ]
452/// is converted to :
453/// %source_1d = vector.shape_cast %source
454/// %destination_1d = vector.shape_cast %destination
455/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
456/// ] %out_nd = vector.shape_cast %out_1d
457/// `shuffle_indices_1d` is computed using the position of the original insert.
458struct LinearizeVectorInsert final
459 : public OpConversionPattern<vector::InsertOp> {
460 using OpConversionPattern::OpConversionPattern;
461 LinearizeVectorInsert(const TypeConverter &typeConverter,
462 MLIRContext *context, PatternBenefit benefit = 1)
463 : OpConversionPattern(typeConverter, context, benefit) {}
464 LogicalResult
465 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
466 ConversionPatternRewriter &rewriter) const override {
467 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
468 insertOp.getDestVectorType());
469 assert(dstTy && "vector type destination expected.");
470
471 // dynamic position is not supported
472 if (insertOp.hasDynamicPosition())
473 return rewriter.notifyMatchFailure(insertOp,
474 "dynamic position is not supported.");
475 auto srcTy = insertOp.getValueToStoreType();
476 auto srcAsVec = dyn_cast<VectorType>(srcTy);
477 uint64_t srcSize = 0;
478 if (srcAsVec) {
479 srcSize = srcAsVec.getNumElements();
480 } else {
481 return rewriter.notifyMatchFailure(insertOp,
482 "scalars are not supported.");
483 }
484
485 auto dstShape = insertOp.getDestVectorType().getShape();
486 const auto dstSize = insertOp.getDestVectorType().getNumElements();
487 auto dstSizeForOffsets = dstSize;
488
489 // compute linearized offset
490 int64_t linearizedOffset = 0;
491 auto offsetsNd = insertOp.getStaticPosition();
492 for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
493 dstSizeForOffsets /= dstShape[dim];
494 linearizedOffset += offset * dstSizeForOffsets;
495 }
496
497 llvm::SmallVector<int64_t, 2> indices(dstSize);
498 auto *origValsUntil = indices.begin();
499 std::advance(origValsUntil, linearizedOffset);
500 std::iota(indices.begin(), origValsUntil,
501 0); // original values that remain [0, offset)
502 auto *newValsUntil = origValsUntil;
503 std::advance(newValsUntil, srcSize);
504 std::iota(origValsUntil, newValsUntil,
505 dstSize); // new values [offset, offset+srcNumElements)
506 std::iota(newValsUntil, indices.end(),
507 linearizedOffset + srcSize); // the rest of original values
508 // [offset+srcNumElements, end)
509
510 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
511 insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
512
513 return success();
514 }
515};
516
517/// This pattern converts the BitCastOp that works on nD (n > 1)
518/// vectors to a BitCastOp that works on linearized vectors.
519/// Following,
520/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
521/// is converted to :
522/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
523/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
524/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
525struct LinearizeVectorBitCast final
526 : public OpConversionPattern<vector::BitCastOp> {
527 using OpConversionPattern::OpConversionPattern;
528 LinearizeVectorBitCast(const TypeConverter &typeConverter,
529 MLIRContext *context, PatternBenefit benefit = 1)
530 : OpConversionPattern(typeConverter, context, benefit) {}
531 LogicalResult
532 matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
533 ConversionPatternRewriter &rewriter) const override {
534 auto resType = getTypeConverter()->convertType(castOp.getType());
535 assert(resType && "expected 1-D vector type");
536 rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
537 adaptor.getSource());
538 return mlir::success();
539 }
540};
541
542/// This pattern converts the SplatOp to work on a linearized vector.
543/// Following,
544/// vector.splat %value : vector<4x4xf32>
545/// is converted to:
546/// %out_1d = vector.splat %value : vector<16xf32>
547/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
548struct LinearizeVectorSplat final
549 : public OpConversionPattern<vector::SplatOp> {
550 using OpConversionPattern::OpConversionPattern;
551
552 LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
553 PatternBenefit benefit = 1)
554 : OpConversionPattern(typeConverter, context, benefit) {}
555
556 LogicalResult
557 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter) const override {
559 auto dstTy = getTypeConverter()->convertType(splatOp.getType());
560 if (!dstTy)
561 return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
562 rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
563 dstTy);
564 return success();
565 }
566};
567
568/// This pattern converts the CreateMaskOp to work on a linearized vector.
569/// It currently supports only 2D masks with a unit outer dimension.
570/// Following,
571/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
572/// is converted to:
573/// %zero = arith.constant 0 : index
574/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
575/// %index = arith.index_cast %cmpi : i1 to index
576/// %mul = arith.andi %index, %arg1 : index
577/// %mask = vector.create_mask %mul : vector<4xi1>
578/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
579struct LinearizeVectorCreateMask final
580 : OpConversionPattern<vector::CreateMaskOp> {
581 using OpConversionPattern::OpConversionPattern;
582
583 LinearizeVectorCreateMask(const TypeConverter &typeConverter,
584 MLIRContext *context, PatternBenefit benefit = 1)
585 : OpConversionPattern(typeConverter, context, benefit) {}
586
587 LogicalResult
588 matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
589 ConversionPatternRewriter &rewriter) const override {
590 Location loc = createMaskOp.getLoc();
591 VectorType srcTy = createMaskOp.getType();
592 auto srcShape = srcTy.getShape();
593 if (srcShape.size() != 2)
594 return rewriter.notifyMatchFailure(createMaskOp,
595 "only 2D mask is supported.");
596
597 if (srcShape[0] != 1)
598 return rewriter.notifyMatchFailure(
599 createMaskOp, "only unit outer dimension is supported.");
600
601 auto dstTy = getTypeConverter()->convertType(srcTy);
602 if (!dstTy)
603 return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
604
605 // Compare the first operand with 0. If it is greater than 0, the
606 // corresponding mask element is set to true, otherwise false.
607 // The result of the comparison is then multiplied with
608 // the second operand of create_mask to get the 1D mask.
609 auto firstOperand = adaptor.getOperands().front();
610 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(location: loc, args: 0);
611 auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
612 loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
613 auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
614 loc, rewriter.getIndexType(), isNonZero);
615 auto secondOperand = adaptor.getOperands().back();
616 auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
617 loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
618
619 auto newMask =
620 rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
621 rewriter.replaceOp(createMaskOp, newMask);
622 return success();
623 }
624};
625
626} // namespace
627
628/// This method defines the set of operations that are linearizable, and hence
629/// that are considered illegal for the conversion target.
630static bool isLinearizable(Operation *op) {
631
632 // Only ops that are in the vector dialect, are ConstantLike, or
633 // are Vectorizable might be linearized currently.
634 StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
635 StringRef opDialect = op->getDialect()->getNamespace();
636 bool supported = (opDialect == vectorDialect) ||
637 op->hasTrait<OpTrait::ConstantLike>() ||
638 op->hasTrait<OpTrait::Vectorizable>();
639 if (!supported)
640 return false;
641
642 return TypeSwitch<Operation *, bool>(op)
643 // As type legalization is done with vector.shape_cast, shape_cast
644 // itself cannot be linearized (will create new shape_casts to linearize
645 // ad infinitum).
646 .Case<vector::ShapeCastOp>([&](auto) { return false; })
647 // The operations
648 // - vector.extract_strided_slice
649 // - vector.extract
650 // - vector.insert_strided_slice
651 // - vector.insert
652 // are linearized to a rank-1 vector.shuffle by the current patterns.
653 // vector.shuffle only supports fixed size vectors, so it is impossible to
654 // use this approach to linearize these ops if they operate on scalable
655 // vectors.
656 .Case<vector::ExtractStridedSliceOp>(
657 [&](vector::ExtractStridedSliceOp extractOp) {
658 return !extractOp.getType().isScalable();
659 })
660 .Case<vector::InsertStridedSliceOp>(
661 [&](vector::InsertStridedSliceOp insertOp) {
662 return !insertOp.getType().isScalable();
663 })
664 .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
665 return !insertOp.getType().isScalable();
666 })
667 .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
668 return !extractOp.getSourceVectorType().isScalable();
669 })
670 .Default([&](auto) { return true; });
671}
672
673void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
674 ConversionTarget &target) {
675
676 auto convertType = [](Type type) -> std::optional<Type> {
677 VectorType vectorType = dyn_cast<VectorType>(type);
678 if (!vectorType || !isLinearizableVector(vectorType))
679 return type;
680
681 VectorType linearizedType =
682 VectorType::get(vectorType.getNumElements(),
683 vectorType.getElementType(), vectorType.isScalable());
684 return linearizedType;
685 };
686 typeConverter.addConversion(callback&: convertType);
687
688 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
689 Location loc) -> Value {
690 if (inputs.size() != 1)
691 return nullptr;
692
693 Value value = inputs.front();
694 if (!isa<VectorType>(Val: type) || !isa<VectorType>(Val: value.getType()))
695 return nullptr;
696
697 return builder.create<vector::ShapeCastOp>(loc, type, value);
698 };
699 typeConverter.addSourceMaterialization(callback&: materializeCast);
700 typeConverter.addTargetMaterialization(callback&: materializeCast);
701
702 target.markUnknownOpDynamicallyLegal(
703 fn: [=](Operation *op) -> std::optional<bool> {
704 if (!isLinearizable(op))
705 return true;
706 // This will return true if, for all operand and result types `t`,
707 // convertType(t) = t. This is true if there are no rank>=2 vectors.
708 return typeConverter.isLegal(op);
709 });
710}
711
712void mlir::vector::populateVectorLinearizeBasePatterns(
713 const TypeConverter &typeConverter, const ConversionTarget &target,
714 RewritePatternSet &patterns) {
715 patterns
716 .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
717 LinearizeVectorSplat, LinearizeVectorCreateMask>(
718 arg: typeConverter, args: patterns.getContext());
719}
720
721void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
722 const TypeConverter &typeConverter, const ConversionTarget &target,
723 RewritePatternSet &patterns) {
724 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
725 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
726 LinearizeVectorInsertStridedSlice>(arg: typeConverter,
727 args: patterns.getContext());
728}
729

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp