1 | //===- VectorLinearize.cpp - vector linearization transforms --------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns and pass for linearizing ND vectors into 1D. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/UB/IR/UBOps.h" |
14 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
15 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
16 | #include "mlir/IR/Attributes.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/Operation.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/IR/TypeUtilities.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | #include "llvm/ADT/ArrayRef.h" |
23 | #include <cstdint> |
24 | #include <numeric> |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | |
29 | static FailureOr<Attribute> |
30 | linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, |
31 | VectorType resType, Attribute value) { |
32 | |
33 | if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(Val&: value)) { |
34 | if (resType.isScalable() && !isa<SplatElementsAttr>(Val: value)) |
35 | return rewriter.notifyMatchFailure( |
36 | arg&: loc, |
37 | msg: "Cannot linearize a constant scalable vector that's not a splat"); |
38 | |
39 | return dstElementsAttr.reshape(resType); |
40 | } |
41 | |
42 | if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value)) |
43 | return poisonAttr; |
44 | |
45 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unsupported attr type"); |
46 | } |
47 | |
48 | namespace { |
49 | |
50 | struct LinearizeConstantLike final |
51 | : OpTraitConversionPattern<OpTrait::ConstantLike> { |
52 | using OpTraitConversionPattern::OpTraitConversionPattern; |
53 | |
54 | LinearizeConstantLike(const TypeConverter &typeConverter, |
55 | MLIRContext *context, PatternBenefit benefit = 1) |
56 | : OpTraitConversionPattern(typeConverter, context, benefit) {} |
57 | LogicalResult |
58 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
59 | ConversionPatternRewriter &rewriter) const override { |
60 | Location loc = op->getLoc(); |
61 | if (op->getNumResults() != 1) |
62 | return rewriter.notifyMatchFailure(arg&: loc, msg: "expected 1 result"); |
63 | |
64 | const TypeConverter &typeConverter = *getTypeConverter(); |
65 | auto resType = |
66 | typeConverter.convertType<VectorType>(op->getResult(idx: 0).getType()); |
67 | assert(resType && "expected 1-D vector type"); |
68 | |
69 | StringAttr attrName = rewriter.getStringAttr("value"); |
70 | Attribute value = op->getAttr(attrName); |
71 | if (!value) |
72 | return rewriter.notifyMatchFailure(arg&: loc, msg: "no 'value' attr"); |
73 | |
74 | FailureOr<Attribute> newValue = |
75 | linearizeConstAttr(loc, rewriter, resType, value); |
76 | if (failed(Result: newValue)) |
77 | return failure(); |
78 | |
79 | FailureOr<Operation *> convertResult = |
80 | convertOpResultTypes(op, /*operands=*/{}, converter: typeConverter, rewriter); |
81 | if (failed(Result: convertResult)) |
82 | return failure(); |
83 | |
84 | Operation *newOp = *convertResult; |
85 | newOp->setAttr(attrName, *newValue); |
86 | rewriter.replaceOp(op, newOp); |
87 | return success(); |
88 | } |
89 | }; |
90 | |
91 | struct LinearizeVectorizable final |
92 | : OpTraitConversionPattern<OpTrait::Vectorizable> { |
93 | using OpTraitConversionPattern::OpTraitConversionPattern; |
94 | |
95 | public: |
96 | LinearizeVectorizable(const TypeConverter &typeConverter, |
97 | MLIRContext *context, PatternBenefit benefit = 1) |
98 | : OpTraitConversionPattern(typeConverter, context, benefit) {} |
99 | LogicalResult |
100 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
101 | ConversionPatternRewriter &rewriter) const override { |
102 | FailureOr<Operation *> newOp = |
103 | convertOpResultTypes(op, operands, converter: *getTypeConverter(), rewriter); |
104 | if (failed(Result: newOp)) |
105 | return failure(); |
106 | |
107 | rewriter.replaceOp(op, newValues: (*newOp)->getResults()); |
108 | return success(); |
109 | } |
110 | }; |
111 | |
112 | template <typename TOp> |
113 | static bool stridesAllOne(TOp op) { |
114 | static_assert( |
115 | std::is_same_v<TOp, vector::ExtractStridedSliceOp> || |
116 | std::is_same_v<TOp, vector::InsertStridedSliceOp>, |
117 | "expected vector.extract_strided_slice or vector.insert_strided_slice"); |
118 | ArrayAttr strides = op.getStrides(); |
119 | return llvm::all_of(strides, isOneInteger); |
120 | } |
121 | |
122 | /// Convert an array of attributes into a vector of integers, if possible. |
123 | static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) { |
124 | if (!attrs) |
125 | return failure(); |
126 | SmallVector<int64_t> ints; |
127 | ints.reserve(N: attrs.size()); |
128 | for (auto attr : attrs) { |
129 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
130 | ints.push_back(intAttr.getInt()); |
131 | } else { |
132 | return failure(); |
133 | } |
134 | } |
135 | return ints; |
136 | } |
137 | |
138 | /// Consider inserting a vector of shape `small` into a vector of shape `large`, |
139 | /// at position `offsets`: this function enumeratates all the indices in `large` |
140 | /// that are written to. The enumeration is with row-major ordering. |
141 | /// |
142 | /// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 |
143 | /// positions written to are (1,3) and (1,4), which have linearized indices 8 |
144 | /// and 9. So [8,9] is returned. |
145 | /// |
146 | /// The length of the returned vector is equal to the number of elements in |
147 | /// the shape `small` (i.e. the product of dimensions of `small`). |
148 | SmallVector<int64_t> static getStridedSliceInsertionIndices( |
149 | ArrayRef<int64_t> small, ArrayRef<int64_t> large, |
150 | ArrayRef<int64_t> offsets) { |
151 | |
152 | // Example of alignment between, `large`, `small` and `offsets`: |
153 | // large = 4, 5, 6, 7, 8 |
154 | // small = 1, 6, 7, 8 |
155 | // offsets = 2, 3, 0 |
156 | // |
157 | // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. |
158 | assert((large.size() >= small.size()) && |
159 | "rank of 'large' cannot be lower than rank of 'small'"); |
160 | assert((large.size() >= offsets.size()) && |
161 | "rank of 'large' cannot be lower than the number of offsets"); |
162 | unsigned delta = large.size() - small.size(); |
163 | unsigned nOffsets = offsets.size(); |
164 | auto getSmall = [&](int64_t i) -> int64_t { |
165 | return i >= delta ? small[i - delta] : 1; |
166 | }; |
167 | auto getOffset = [&](int64_t i) -> int64_t { |
168 | return i < nOffsets ? offsets[i] : 0; |
169 | }; |
170 | |
171 | // Using 2 vectors of indices, at each iteration populate the updated set of |
172 | // indices based on the old set of indices, and the size of the small vector |
173 | // in the current iteration. |
174 | SmallVector<int64_t> indices{0}; |
175 | int64_t stride = 1; |
176 | for (int i = large.size() - 1; i >= 0; --i) { |
177 | int64_t currentSize = indices.size(); |
178 | int64_t smallSize = getSmall(i); |
179 | int64_t nextSize = currentSize * smallSize; |
180 | SmallVector<int64_t> nextIndices(nextSize); |
181 | int64_t *base = nextIndices.begin(); |
182 | int64_t offset = getOffset(i) * stride; |
183 | for (int j = 0; j < smallSize; ++j) { |
184 | for (int k = 0; k < currentSize; ++k) { |
185 | base[k] = indices[k] + offset; |
186 | } |
187 | offset += stride; |
188 | base += currentSize; |
189 | } |
190 | stride *= large[i]; |
191 | indices = std::move(nextIndices); |
192 | } |
193 | return indices; |
194 | } |
195 | |
196 | /// This pattern converts a vector.extract_strided_slice operation into a |
197 | /// vector.shuffle operation that has a rank-1 (linearized) operand and result. |
198 | /// |
199 | /// For example, the following: |
200 | /// |
201 | /// ``` |
202 | /// vector.extract_strided_slice %source |
203 | /// { offsets = [..], strides = [..], sizes = [..] } |
204 | /// ``` |
205 | /// |
206 | /// is converted to : |
207 | /// ``` |
208 | /// %source_1d = vector.shape_cast %source |
209 | /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
210 | /// %out_nd = vector.shape_cast %out_1d |
211 | /// ``` |
212 | /// |
213 | /// `shuffle_indices_1d` is computed using the offsets and sizes of the original |
214 | /// vector.extract_strided_slice operation. |
215 | struct LinearizeVectorExtractStridedSlice final |
216 | : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { |
217 | using OpConversionPattern::OpConversionPattern; |
218 | LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, |
219 | MLIRContext *context, |
220 | PatternBenefit benefit = 1) |
221 | : OpConversionPattern(typeConverter, context, benefit) {} |
222 | |
223 | LogicalResult |
224 | matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp, |
225 | OpAdaptor adaptor, |
226 | ConversionPatternRewriter &rewriter) const override { |
227 | |
228 | VectorType flatOutputType = getTypeConverter()->convertType<VectorType>( |
229 | extractStridedSliceOp.getType()); |
230 | assert(flatOutputType && "vector type expected"); |
231 | |
232 | // Expect a legalization failure if the strides are not all 1 (if ever the |
233 | // verifier for extract_strided_slice allows non-1 strides). |
234 | if (!stridesAllOne(extractStridedSliceOp)) { |
235 | return rewriter.notifyMatchFailure( |
236 | extractStridedSliceOp, |
237 | "extract_strided_slice with strides != 1 not supported"); |
238 | } |
239 | |
240 | FailureOr<SmallVector<int64_t>> offsets = |
241 | intsFromArrayAttr(extractStridedSliceOp.getOffsets()); |
242 | if (failed(Result: offsets)) { |
243 | return rewriter.notifyMatchFailure(extractStridedSliceOp, |
244 | "failed to get integer offsets"); |
245 | } |
246 | |
247 | ArrayRef<int64_t> inputShape = |
248 | extractStridedSliceOp.getSourceVectorType().getShape(); |
249 | |
250 | ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape(); |
251 | |
252 | SmallVector<int64_t> indices = getStridedSliceInsertionIndices( |
253 | small: outputShape, large: inputShape, offsets: offsets.value()); |
254 | |
255 | Value srcVector = adaptor.getVector(); |
256 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
257 | extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); |
258 | return success(); |
259 | } |
260 | }; |
261 | |
262 | /// This pattern converts a vector.insert_strided_slice operation into a |
263 | /// vector.shuffle operation that has rank-1 (linearized) operands and result. |
264 | /// |
265 | /// For example, the following: |
266 | /// ``` |
267 | /// %0 = vector.insert_strided_slice %to_store, %into |
268 | /// {offsets = [1, 0, 0, 0], strides = [1, 1]} |
269 | /// : vector<2x2xi8> into vector<2x1x3x2xi8> |
270 | /// ``` |
271 | /// |
272 | /// is converted to |
273 | /// ``` |
274 | /// %to_store_1d |
275 | /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> |
276 | /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> |
277 | /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] |
278 | /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> |
279 | /// ``` |
280 | /// |
281 | /// where shuffle_indices_1d in this case is |
282 | /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. |
283 | /// ^^^^^^^^^^^^^^ |
284 | /// to_store_1d |
285 | /// |
286 | struct LinearizeVectorInsertStridedSlice final |
287 | : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { |
288 | using OpConversionPattern::OpConversionPattern; |
289 | LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, |
290 | MLIRContext *context, |
291 | PatternBenefit benefit = 1) |
292 | : OpConversionPattern(typeConverter, context, benefit) {} |
293 | |
294 | LogicalResult |
295 | matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp, |
296 | OpAdaptor adaptor, |
297 | ConversionPatternRewriter &rewriter) const override { |
298 | |
299 | // Expect a legalization failure if the strides are not all 1 (if ever the |
300 | // verifier for insert_strided_slice allows non-1 strides). |
301 | if (!stridesAllOne(insertStridedSliceOp)) { |
302 | return rewriter.notifyMatchFailure( |
303 | insertStridedSliceOp, |
304 | "insert_strided_slice with strides != 1 not supported"); |
305 | } |
306 | |
307 | VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); |
308 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
309 | |
310 | VectorType outputType = insertStridedSliceOp.getType(); |
311 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
312 | int64_t nOutputElements = outputType.getNumElements(); |
313 | |
314 | FailureOr<SmallVector<int64_t>> offsets = |
315 | intsFromArrayAttr(insertStridedSliceOp.getOffsets()); |
316 | if (failed(Result: offsets)) { |
317 | return rewriter.notifyMatchFailure(insertStridedSliceOp, |
318 | "failed to get integer offsets"); |
319 | } |
320 | SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices( |
321 | small: inputShape, large: outputShape, offsets: offsets.value()); |
322 | |
323 | SmallVector<int64_t> indices(nOutputElements); |
324 | std::iota(first: indices.begin(), last: indices.end(), value: 0); |
325 | for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { |
326 | indices[sliceIndex] = index + nOutputElements; |
327 | } |
328 | |
329 | Value flatToStore = adaptor.getValueToStore(); |
330 | Value flatDest = adaptor.getDest(); |
331 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp, |
332 | flatDest.getType(), flatDest, |
333 | flatToStore, indices); |
334 | return success(); |
335 | } |
336 | }; |
337 | |
338 | /// This pattern converts the ShuffleOp that works on nD (n > 1) |
339 | /// vectors to a ShuffleOp that works on linearized vectors. |
340 | /// Following, |
341 | /// vector.shuffle %v1, %v2 [ shuffle_indices ] |
342 | /// is converted to : |
343 | /// %v1_1d = vector.shape_cast %v1 |
344 | /// %v2_1d = vector.shape_cast %v2 |
345 | /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] |
346 | /// %out_nd = vector.shape_cast %out_1d |
347 | // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` |
348 | /// of the original shuffle operation. |
349 | struct LinearizeVectorShuffle final |
350 | : public OpConversionPattern<vector::ShuffleOp> { |
351 | using OpConversionPattern::OpConversionPattern; |
352 | LinearizeVectorShuffle(const TypeConverter &typeConverter, |
353 | MLIRContext *context, PatternBenefit benefit = 1) |
354 | : OpConversionPattern(typeConverter, context, benefit) {} |
355 | |
356 | LogicalResult |
357 | matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, |
358 | ConversionPatternRewriter &rewriter) const override { |
359 | VectorType dstType = |
360 | getTypeConverter()->convertType<VectorType>(shuffleOp.getType()); |
361 | assert(dstType && "vector type destination expected."); |
362 | |
363 | Value vec1 = adaptor.getV1(); |
364 | Value vec2 = adaptor.getV2(); |
365 | int shuffleSliceLen = 1; |
366 | int rank = shuffleOp.getV1().getType().getRank(); |
367 | |
368 | // If rank > 1, we need to do the shuffle in the granularity of slices |
369 | // instead of scalars. Size of the slice is equal to the rank-1 innermost |
370 | // dims. Mask of the shuffle op specifies which slice to take from the |
371 | // outermost dim. |
372 | if (rank > 1) { |
373 | llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape(); |
374 | for (unsigned i = 1; i < shape.size(); ++i) { |
375 | shuffleSliceLen *= shape[i]; |
376 | } |
377 | } |
378 | |
379 | // For each value in the mask, we generate the indices of the source vectors |
380 | // that need to be shuffled to the destination vector. If shuffleSliceLen > |
381 | // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of |
382 | // elements) instead of scalars. |
383 | ArrayRef<int64_t> mask = shuffleOp.getMask(); |
384 | int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; |
385 | llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts); |
386 | for (auto [i, value] : llvm::enumerate(mask)) { |
387 | std::iota(indices.begin() + shuffleSliceLen * i, |
388 | indices.begin() + shuffleSliceLen * (i + 1), |
389 | shuffleSliceLen * value); |
390 | } |
391 | |
392 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1, |
393 | vec2, indices); |
394 | return success(); |
395 | } |
396 | }; |
397 | |
398 | /// This pattern converts the ExtractOp to a ShuffleOp that works on a |
399 | /// linearized vector. |
400 | /// Following, |
401 | /// vector.extract %source [ position ] |
402 | /// is converted to : |
403 | /// %source_1d = vector.shape_cast %source |
404 | /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
405 | /// %out_nd = vector.shape_cast %out_1d |
406 | /// `shuffle_indices_1d` is computed using the position of the original extract. |
407 | struct LinearizeVectorExtract final |
408 | : public OpConversionPattern<vector::ExtractOp> { |
409 | using OpConversionPattern::OpConversionPattern; |
410 | LinearizeVectorExtract(const TypeConverter &typeConverter, |
411 | MLIRContext *context, PatternBenefit benefit = 1) |
412 | : OpConversionPattern(typeConverter, context, benefit) {} |
413 | LogicalResult |
414 | matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, |
415 | ConversionPatternRewriter &rewriter) const override { |
416 | // Skip if result is not a vector type |
417 | if (!isa<VectorType>(extractOp.getType())) |
418 | return rewriter.notifyMatchFailure(extractOp, |
419 | "scalar extract not supported"); |
420 | Type dstTy = getTypeConverter()->convertType(extractOp.getType()); |
421 | assert(dstTy && "expected 1-D vector type"); |
422 | |
423 | // Dynamic position is not supported. |
424 | if (extractOp.hasDynamicPosition()) |
425 | return rewriter.notifyMatchFailure(extractOp, |
426 | "dynamic position is not supported."); |
427 | |
428 | llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape(); |
429 | int64_t size = extractOp.getVector().getType().getNumElements(); |
430 | |
431 | // Compute linearized offset. |
432 | int64_t linearizedOffset = 0; |
433 | llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition(); |
434 | for (auto [i, off] : llvm::enumerate(offsets)) { |
435 | size /= shape[i]; |
436 | linearizedOffset += offsets[i] * size; |
437 | } |
438 | |
439 | llvm::SmallVector<int64_t, 2> indices(size); |
440 | std::iota(first: indices.begin(), last: indices.end(), value: linearizedOffset); |
441 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
442 | extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); |
443 | |
444 | return success(); |
445 | } |
446 | }; |
447 | |
448 | /// This pattern converts the InsertOp to a ShuffleOp that works on a |
449 | /// linearized vector. |
450 | /// Following, |
451 | /// vector.insert %source %destination [ position ] |
452 | /// is converted to : |
453 | /// %source_1d = vector.shape_cast %source |
454 | /// %destination_1d = vector.shape_cast %destination |
455 | /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d |
456 | /// ] %out_nd = vector.shape_cast %out_1d |
457 | /// `shuffle_indices_1d` is computed using the position of the original insert. |
458 | struct LinearizeVectorInsert final |
459 | : public OpConversionPattern<vector::InsertOp> { |
460 | using OpConversionPattern::OpConversionPattern; |
461 | LinearizeVectorInsert(const TypeConverter &typeConverter, |
462 | MLIRContext *context, PatternBenefit benefit = 1) |
463 | : OpConversionPattern(typeConverter, context, benefit) {} |
464 | LogicalResult |
465 | matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, |
466 | ConversionPatternRewriter &rewriter) const override { |
467 | VectorType dstTy = getTypeConverter()->convertType<VectorType>( |
468 | insertOp.getDestVectorType()); |
469 | assert(dstTy && "vector type destination expected."); |
470 | |
471 | // dynamic position is not supported |
472 | if (insertOp.hasDynamicPosition()) |
473 | return rewriter.notifyMatchFailure(insertOp, |
474 | "dynamic position is not supported."); |
475 | auto srcTy = insertOp.getValueToStoreType(); |
476 | auto srcAsVec = dyn_cast<VectorType>(srcTy); |
477 | uint64_t srcSize = 0; |
478 | if (srcAsVec) { |
479 | srcSize = srcAsVec.getNumElements(); |
480 | } else { |
481 | return rewriter.notifyMatchFailure(insertOp, |
482 | "scalars are not supported."); |
483 | } |
484 | |
485 | auto dstShape = insertOp.getDestVectorType().getShape(); |
486 | const auto dstSize = insertOp.getDestVectorType().getNumElements(); |
487 | auto dstSizeForOffsets = dstSize; |
488 | |
489 | // compute linearized offset |
490 | int64_t linearizedOffset = 0; |
491 | auto offsetsNd = insertOp.getStaticPosition(); |
492 | for (auto [dim, offset] : llvm::enumerate(offsetsNd)) { |
493 | dstSizeForOffsets /= dstShape[dim]; |
494 | linearizedOffset += offset * dstSizeForOffsets; |
495 | } |
496 | |
497 | llvm::SmallVector<int64_t, 2> indices(dstSize); |
498 | auto *origValsUntil = indices.begin(); |
499 | std::advance(origValsUntil, linearizedOffset); |
500 | std::iota(indices.begin(), origValsUntil, |
501 | 0); // original values that remain [0, offset) |
502 | auto *newValsUntil = origValsUntil; |
503 | std::advance(newValsUntil, srcSize); |
504 | std::iota(origValsUntil, newValsUntil, |
505 | dstSize); // new values [offset, offset+srcNumElements) |
506 | std::iota(newValsUntil, indices.end(), |
507 | linearizedOffset + srcSize); // the rest of original values |
508 | // [offset+srcNumElements, end) |
509 | |
510 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
511 | insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices); |
512 | |
513 | return success(); |
514 | } |
515 | }; |
516 | |
517 | /// This pattern converts the BitCastOp that works on nD (n > 1) |
518 | /// vectors to a BitCastOp that works on linearized vectors. |
519 | /// Following, |
520 | /// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> |
521 | /// is converted to : |
522 | /// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> |
523 | /// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> |
524 | /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> |
525 | struct LinearizeVectorBitCast final |
526 | : public OpConversionPattern<vector::BitCastOp> { |
527 | using OpConversionPattern::OpConversionPattern; |
528 | LinearizeVectorBitCast(const TypeConverter &typeConverter, |
529 | MLIRContext *context, PatternBenefit benefit = 1) |
530 | : OpConversionPattern(typeConverter, context, benefit) {} |
531 | LogicalResult |
532 | matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, |
533 | ConversionPatternRewriter &rewriter) const override { |
534 | auto resType = getTypeConverter()->convertType(castOp.getType()); |
535 | assert(resType && "expected 1-D vector type"); |
536 | rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, |
537 | adaptor.getSource()); |
538 | return mlir::success(); |
539 | } |
540 | }; |
541 | |
542 | /// This pattern converts the SplatOp to work on a linearized vector. |
543 | /// Following, |
544 | /// vector.splat %value : vector<4x4xf32> |
545 | /// is converted to: |
546 | /// %out_1d = vector.splat %value : vector<16xf32> |
547 | /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> |
548 | struct LinearizeVectorSplat final |
549 | : public OpConversionPattern<vector::SplatOp> { |
550 | using OpConversionPattern::OpConversionPattern; |
551 | |
552 | LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, |
553 | PatternBenefit benefit = 1) |
554 | : OpConversionPattern(typeConverter, context, benefit) {} |
555 | |
556 | LogicalResult |
557 | matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, |
558 | ConversionPatternRewriter &rewriter) const override { |
559 | auto dstTy = getTypeConverter()->convertType(splatOp.getType()); |
560 | if (!dstTy) |
561 | return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); |
562 | rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(), |
563 | dstTy); |
564 | return success(); |
565 | } |
566 | }; |
567 | |
568 | /// This pattern converts the CreateMaskOp to work on a linearized vector. |
569 | /// It currently supports only 2D masks with a unit outer dimension. |
570 | /// Following, |
571 | /// vector.create_mask %arg0, %arg1 : vector<1x4xi1> |
572 | /// is converted to: |
573 | /// %zero = arith.constant 0 : index |
574 | /// %cmpi = arith.cmpi sgt, %arg0, %zero : index |
575 | /// %index = arith.index_cast %cmpi : i1 to index |
576 | /// %mul = arith.andi %index, %arg1 : index |
577 | /// %mask = vector.create_mask %mul : vector<4xi1> |
578 | /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> |
579 | struct LinearizeVectorCreateMask final |
580 | : OpConversionPattern<vector::CreateMaskOp> { |
581 | using OpConversionPattern::OpConversionPattern; |
582 | |
583 | LinearizeVectorCreateMask(const TypeConverter &typeConverter, |
584 | MLIRContext *context, PatternBenefit benefit = 1) |
585 | : OpConversionPattern(typeConverter, context, benefit) {} |
586 | |
587 | LogicalResult |
588 | matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, |
589 | ConversionPatternRewriter &rewriter) const override { |
590 | Location loc = createMaskOp.getLoc(); |
591 | VectorType srcTy = createMaskOp.getType(); |
592 | auto srcShape = srcTy.getShape(); |
593 | if (srcShape.size() != 2) |
594 | return rewriter.notifyMatchFailure(createMaskOp, |
595 | "only 2D mask is supported."); |
596 | |
597 | if (srcShape[0] != 1) |
598 | return rewriter.notifyMatchFailure( |
599 | createMaskOp, "only unit outer dimension is supported."); |
600 | |
601 | auto dstTy = getTypeConverter()->convertType(srcTy); |
602 | if (!dstTy) |
603 | return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); |
604 | |
605 | // Compare the first operand with 0. If it is greater than 0, the |
606 | // corresponding mask element is set to true, otherwise false. |
607 | // The result of the comparison is then multiplied with |
608 | // the second operand of create_mask to get the 1D mask. |
609 | auto firstOperand = adaptor.getOperands().front(); |
610 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(location: loc, args: 0); |
611 | auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>( |
612 | loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); |
613 | auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>( |
614 | loc, rewriter.getIndexType(), isNonZero); |
615 | auto secondOperand = adaptor.getOperands().back(); |
616 | auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>( |
617 | loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); |
618 | |
619 | auto newMask = |
620 | rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize); |
621 | rewriter.replaceOp(createMaskOp, newMask); |
622 | return success(); |
623 | } |
624 | }; |
625 | |
626 | } // namespace |
627 | |
628 | /// This method defines the set of operations that are linearizable, and hence |
629 | /// that are considered illegal for the conversion target. |
630 | static bool isLinearizable(Operation *op) { |
631 | |
632 | // Only ops that are in the vector dialect, are ConstantLike, or |
633 | // are Vectorizable might be linearized currently. |
634 | StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace(); |
635 | StringRef opDialect = op->getDialect()->getNamespace(); |
636 | bool supported = (opDialect == vectorDialect) || |
637 | op->hasTrait<OpTrait::ConstantLike>() || |
638 | op->hasTrait<OpTrait::Vectorizable>(); |
639 | if (!supported) |
640 | return false; |
641 | |
642 | return TypeSwitch<Operation *, bool>(op) |
643 | // As type legalization is done with vector.shape_cast, shape_cast |
644 | // itself cannot be linearized (will create new shape_casts to linearize |
645 | // ad infinitum). |
646 | .Case<vector::ShapeCastOp>([&](auto) { return false; }) |
647 | // The operations |
648 | // - vector.extract_strided_slice |
649 | // - vector.extract |
650 | // - vector.insert_strided_slice |
651 | // - vector.insert |
652 | // are linearized to a rank-1 vector.shuffle by the current patterns. |
653 | // vector.shuffle only supports fixed size vectors, so it is impossible to |
654 | // use this approach to linearize these ops if they operate on scalable |
655 | // vectors. |
656 | .Case<vector::ExtractStridedSliceOp>( |
657 | [&](vector::ExtractStridedSliceOp extractOp) { |
658 | return !extractOp.getType().isScalable(); |
659 | }) |
660 | .Case<vector::InsertStridedSliceOp>( |
661 | [&](vector::InsertStridedSliceOp insertOp) { |
662 | return !insertOp.getType().isScalable(); |
663 | }) |
664 | .Case<vector::InsertOp>([&](vector::InsertOp insertOp) { |
665 | return !insertOp.getType().isScalable(); |
666 | }) |
667 | .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) { |
668 | return !extractOp.getSourceVectorType().isScalable(); |
669 | }) |
670 | .Default([&](auto) { return true; }); |
671 | } |
672 | |
673 | void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, |
674 | ConversionTarget &target) { |
675 | |
676 | auto convertType = [](Type type) -> std::optional<Type> { |
677 | VectorType vectorType = dyn_cast<VectorType>(type); |
678 | if (!vectorType || !isLinearizableVector(vectorType)) |
679 | return type; |
680 | |
681 | VectorType linearizedType = |
682 | VectorType::get(vectorType.getNumElements(), |
683 | vectorType.getElementType(), vectorType.isScalable()); |
684 | return linearizedType; |
685 | }; |
686 | typeConverter.addConversion(callback&: convertType); |
687 | |
688 | auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, |
689 | Location loc) -> Value { |
690 | if (inputs.size() != 1) |
691 | return nullptr; |
692 | |
693 | Value value = inputs.front(); |
694 | if (!isa<VectorType>(Val: type) || !isa<VectorType>(Val: value.getType())) |
695 | return nullptr; |
696 | |
697 | return builder.create<vector::ShapeCastOp>(loc, type, value); |
698 | }; |
699 | typeConverter.addSourceMaterialization(callback&: materializeCast); |
700 | typeConverter.addTargetMaterialization(callback&: materializeCast); |
701 | |
702 | target.markUnknownOpDynamicallyLegal( |
703 | fn: [=](Operation *op) -> std::optional<bool> { |
704 | if (!isLinearizable(op)) |
705 | return true; |
706 | // This will return true if, for all operand and result types `t`, |
707 | // convertType(t) = t. This is true if there are no rank>=2 vectors. |
708 | return typeConverter.isLegal(op); |
709 | }); |
710 | } |
711 | |
712 | void mlir::vector::populateVectorLinearizeBasePatterns( |
713 | const TypeConverter &typeConverter, const ConversionTarget &target, |
714 | RewritePatternSet &patterns) { |
715 | patterns |
716 | .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, |
717 | LinearizeVectorSplat, LinearizeVectorCreateMask>( |
718 | arg: typeConverter, args: patterns.getContext()); |
719 | } |
720 | |
721 | void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( |
722 | const TypeConverter &typeConverter, const ConversionTarget &target, |
723 | RewritePatternSet &patterns) { |
724 | patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, |
725 | LinearizeVectorInsert, LinearizeVectorExtractStridedSlice, |
726 | LinearizeVectorInsertStridedSlice>(arg: typeConverter, |
727 | args: patterns.getContext()); |
728 | } |
729 |