1 | //===- VectorLinearize.cpp - vector linearization transforms --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns and pass for linearizing ND vectors into 1D. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>
26 | |
27 | using namespace mlir; |
28 | |
29 | static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { |
30 | auto resultTypes = op->getResultTypes(); |
31 | for (auto resType : resultTypes) { |
32 | VectorType vecType = dyn_cast<VectorType>(resType); |
33 | // Reject index since getElementTypeBitWidth will abort for Index types. |
34 | if (!vecType || vecType.getElementType().isIndex()) |
35 | return false; |
36 | // There are no dimension to fold if it is a 0-D vector. |
37 | if (vecType.getRank() == 0) |
38 | return false; |
39 | unsigned trailingVecDimBitWidth = |
40 | vecType.getShape().back() * vecType.getElementTypeBitWidth(); |
41 | if (trailingVecDimBitWidth >= targetBitWidth) |
42 | return false; |
43 | } |
44 | return true; |
45 | } |
46 | |
47 | namespace { |
48 | struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { |
49 | using OpConversionPattern::OpConversionPattern; |
50 | LinearizeConstant( |
51 | const TypeConverter &typeConverter, MLIRContext *context, |
52 | unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
53 | PatternBenefit benefit = 1) |
54 | : OpConversionPattern(typeConverter, context, benefit), |
55 | targetVectorBitWidth(targetVectBitWidth) {} |
56 | LogicalResult |
57 | matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, |
58 | ConversionPatternRewriter &rewriter) const override { |
59 | Location loc = constOp.getLoc(); |
60 | auto resType = |
61 | getTypeConverter()->convertType<VectorType>(constOp.getType()); |
62 | |
63 | if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue())) |
64 | return rewriter.notifyMatchFailure( |
65 | arg&: loc, |
66 | msg: "Cannot linearize a constant scalable vector that's not a splat" ); |
67 | |
68 | if (!resType) |
69 | return rewriter.notifyMatchFailure(arg&: loc, msg: "can't convert return type" ); |
70 | if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) |
71 | return rewriter.notifyMatchFailure( |
72 | arg&: loc, msg: "Can't flatten since targetBitWidth <= OpSize" ); |
73 | auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue()); |
74 | if (!dstElementsAttr) |
75 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unsupported attr type" ); |
76 | |
77 | dstElementsAttr = dstElementsAttr.reshape(resType); |
78 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType, |
79 | dstElementsAttr); |
80 | return success(); |
81 | } |
82 | |
83 | private: |
84 | unsigned targetVectorBitWidth; |
85 | }; |
86 | |
87 | struct LinearizeVectorizable final |
88 | : OpTraitConversionPattern<OpTrait::Vectorizable> { |
89 | using OpTraitConversionPattern::OpTraitConversionPattern; |
90 | |
91 | public: |
92 | LinearizeVectorizable( |
93 | const TypeConverter &typeConverter, MLIRContext *context, |
94 | unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
95 | PatternBenefit benefit = 1) |
96 | : OpTraitConversionPattern(typeConverter, context, benefit), |
97 | targetVectorBitWidth(targetVectBitWidth) {} |
98 | LogicalResult |
99 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
100 | ConversionPatternRewriter &rewriter) const override { |
101 | if (!isLessThanTargetBitWidth(op, targetBitWidth: targetVectorBitWidth)) |
102 | return rewriter.notifyMatchFailure( |
103 | arg: op->getLoc(), msg: "Can't flatten since targetBitWidth <= OpSize" ); |
104 | FailureOr<Operation *> newOp = |
105 | convertOpResultTypes(op, operands, converter: *getTypeConverter(), rewriter); |
106 | if (failed(result: newOp)) |
107 | return failure(); |
108 | |
109 | rewriter.replaceOp(op, newValues: (*newOp)->getResults()); |
110 | return success(); |
111 | } |
112 | |
113 | private: |
114 | unsigned targetVectorBitWidth; |
115 | }; |
116 | |
117 | /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works |
118 | /// on a linearized vector. |
119 | /// Following, |
120 | /// vector.extract_strided_slice %source |
121 | /// { offsets = [..], strides = [..], sizes = [..] } |
122 | /// is converted to : |
123 | /// %source_1d = vector.shape_cast %source |
124 | /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
125 | /// %out_nd = vector.shape_cast %out_1d |
126 | /// `shuffle_indices_1d` is computed using the offsets and sizes of the |
127 | /// extraction. |
128 | struct final |
129 | : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { |
130 | using OpConversionPattern::OpConversionPattern; |
131 | ( |
132 | const TypeConverter &typeConverter, MLIRContext *context, |
133 | unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
134 | PatternBenefit benefit = 1) |
135 | : OpConversionPattern(typeConverter, context, benefit), |
136 | targetVectorBitWidth(targetVectBitWidth) {} |
137 | |
138 | LogicalResult |
139 | matchAndRewrite(vector::ExtractStridedSliceOp , OpAdaptor adaptor, |
140 | ConversionPatternRewriter &rewriter) const override { |
141 | Type dstType = getTypeConverter()->convertType(extractOp.getType()); |
142 | assert(!(extractOp.getVector().getType().isScalable() || |
143 | dstType.cast<VectorType>().isScalable()) && |
144 | "scalable vectors are not supported." ); |
145 | if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) |
146 | return rewriter.notifyMatchFailure( |
147 | extractOp, "Can't flatten since targetBitWidth <= OpSize" ); |
148 | |
149 | ArrayAttr offsets = extractOp.getOffsets(); |
150 | ArrayAttr sizes = extractOp.getSizes(); |
151 | ArrayAttr strides = extractOp.getStrides(); |
152 | if (!isConstantIntValue(strides[0], 1)) |
153 | return rewriter.notifyMatchFailure( |
154 | extractOp, "Strided slice with stride != 1 is not supported." ); |
155 | Value srcVector = adaptor.getVector(); |
156 | // If kD offsets are specified for nD source vector (n > k), the granularity |
157 | // of the extraction is greater than 1. In this case last (n-k) dimensions |
158 | // form the extraction granularity. |
159 | // Example : |
160 | // vector.extract_strided_slice %src { |
161 | // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : |
162 | // vector<4x8x8xf32> to vector<2x2x8xf32> |
163 | // Here, extraction granularity is 8. |
164 | int64_t = 1; |
165 | int64_t nD = extractOp.getSourceVectorType().getRank(); |
166 | int64_t kD = (int64_t)offsets.size(); |
167 | int64_t k = kD; |
168 | while (k < nD) { |
169 | extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; |
170 | ++k; |
171 | } |
172 | // Get total number of extracted slices. |
173 | int64_t = 1; |
174 | for (Attribute size : sizes) { |
175 | nExtractedSlices *= size.cast<IntegerAttr>().getInt(); |
176 | } |
177 | // Compute the strides of the source vector considering first k dimensions. |
178 | llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize); |
179 | for (int i = kD - 2; i >= 0; --i) { |
180 | sourceStrides[i] = sourceStrides[i + 1] * |
181 | extractOp.getSourceVectorType().getShape()[i + 1]; |
182 | } |
183 | // Final shuffle indices has nExtractedSlices * extractGranularitySize |
184 | // elements. |
185 | llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * |
186 | extractGranularitySize); |
187 | // Compute the strides of the extracted kD vector. |
188 | llvm::SmallVector<int64_t, 4> (kD, 1); |
189 | // Compute extractedStrides. |
190 | for (int i = kD - 2; i >= 0; --i) { |
191 | extractedStrides[i] = |
192 | extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt(); |
193 | } |
194 | // Iterate over all extracted slices from 0 to nExtractedSlices - 1 |
195 | // and compute the multi-dimensional index and the corresponding linearized |
196 | // index within the source vector. |
197 | for (int64_t i = 0; i < nExtractedSlices; ++i) { |
198 | int64_t index = i; |
199 | // Compute the corresponding multi-dimensional index. |
200 | llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0); |
201 | for (int64_t j = 0; j < kD; ++j) { |
202 | multiDimIndex[j] = (index / extractedStrides[j]); |
203 | index -= multiDimIndex[j] * extractedStrides[j]; |
204 | } |
205 | // Compute the corresponding linearized index in the source vector |
206 | // i.e. shift the multiDimIndex by the offsets. |
207 | int64_t linearizedIndex = 0; |
208 | for (int64_t j = 0; j < kD; ++j) { |
209 | linearizedIndex += |
210 | (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) * |
211 | sourceStrides[j]; |
212 | } |
213 | // Fill the indices array form linearizedIndex to linearizedIndex + |
214 | // extractGranularitySize. |
215 | for (int64_t j = 0; j < extractGranularitySize; ++j) { |
216 | indices[i * extractGranularitySize + j] = linearizedIndex + j; |
217 | } |
218 | } |
219 | // Perform a shuffle to extract the kD vector. |
220 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
221 | extractOp, dstType, srcVector, srcVector, |
222 | rewriter.getI64ArrayAttr(indices)); |
223 | return success(); |
224 | } |
225 | |
226 | private: |
227 | unsigned ; |
228 | }; |
229 | |
230 | /// This pattern converts the ShuffleOp that works on nD (n > 1) |
231 | /// vectors to a ShuffleOp that works on linearized vectors. |
232 | /// Following, |
233 | /// vector.shuffle %v1, %v2 [ shuffle_indices ] |
234 | /// is converted to : |
235 | /// %v1_1d = vector.shape_cast %v1 |
236 | /// %v2_1d = vector.shape_cast %v2 |
237 | /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] |
238 | /// %out_nd = vector.shape_cast %out_1d |
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
240 | /// of the original shuffle operation. |
241 | struct LinearizeVectorShuffle final |
242 | : public OpConversionPattern<vector::ShuffleOp> { |
243 | using OpConversionPattern::OpConversionPattern; |
244 | LinearizeVectorShuffle( |
245 | const TypeConverter &typeConverter, MLIRContext *context, |
246 | unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
247 | PatternBenefit benefit = 1) |
248 | : OpConversionPattern(typeConverter, context, benefit), |
249 | targetVectorBitWidth(targetVectBitWidth) {} |
250 | |
251 | LogicalResult |
252 | matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, |
253 | ConversionPatternRewriter &rewriter) const override { |
254 | Type dstType = getTypeConverter()->convertType(shuffleOp.getType()); |
255 | assert(!(shuffleOp.getV1VectorType().isScalable() || |
256 | shuffleOp.getV2VectorType().isScalable() || |
257 | dstType.cast<VectorType>().isScalable()) && |
258 | "scalable vectors are not supported." ); |
259 | if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) |
260 | return rewriter.notifyMatchFailure( |
261 | shuffleOp, "Can't flatten since targetBitWidth <= OpSize" ); |
262 | |
263 | Value vec1 = adaptor.getV1(); |
264 | Value vec2 = adaptor.getV2(); |
265 | int shuffleSliceLen = 1; |
266 | int rank = shuffleOp.getV1().getType().getRank(); |
267 | |
268 | // If rank > 1, we need to do the shuffle in the granularity of slices |
269 | // instead of scalars. Size of the slice is equal to the rank-1 innermost |
270 | // dims. Mask of the shuffle op specifies which slice to take from the |
271 | // outermost dim. |
272 | if (rank > 1) { |
273 | llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape(); |
274 | for (unsigned i = 1; i < shape.size(); ++i) { |
275 | shuffleSliceLen *= shape[i]; |
276 | } |
277 | } |
278 | |
279 | // For each value in the mask, we generate the indices of the source vectors |
280 | // that needs to be shuffled to the destination vector. If shuffleSliceLen > |
281 | // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of |
282 | // elements) instead of scalars. |
283 | ArrayAttr mask = shuffleOp.getMask(); |
284 | int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; |
285 | llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts); |
286 | for (auto [i, value] : |
287 | llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) { |
288 | |
289 | int64_t v = value.getZExtValue(); |
290 | std::iota(indices.begin() + shuffleSliceLen * i, |
291 | indices.begin() + shuffleSliceLen * (i + 1), |
292 | shuffleSliceLen * v); |
293 | } |
294 | |
295 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
296 | shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices)); |
297 | return success(); |
298 | } |
299 | |
300 | private: |
301 | unsigned targetVectorBitWidth; |
302 | }; |
303 | |
304 | /// This pattern converts the ExtractOp to a ShuffleOp that works on a |
305 | /// linearized vector. |
306 | /// Following, |
307 | /// vector.extract %source [ position ] |
308 | /// is converted to : |
309 | /// %source_1d = vector.shape_cast %source |
310 | /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
311 | /// %out_nd = vector.shape_cast %out_1d |
312 | /// `shuffle_indices_1d` is computed using the position of the original extract. |
313 | struct final |
314 | : public OpConversionPattern<vector::ExtractOp> { |
315 | using OpConversionPattern::OpConversionPattern; |
316 | ( |
317 | const TypeConverter &typeConverter, MLIRContext *context, |
318 | unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
319 | PatternBenefit benefit = 1) |
320 | : OpConversionPattern(typeConverter, context, benefit), |
321 | targetVectorBitWidth(targetVectBitWidth) {} |
322 | LogicalResult |
323 | matchAndRewrite(vector::ExtractOp , OpAdaptor adaptor, |
324 | ConversionPatternRewriter &rewriter) const override { |
325 | Type dstTy = getTypeConverter()->convertType(extractOp.getType()); |
326 | assert(!(extractOp.getVector().getType().isScalable() || |
327 | dstTy.cast<VectorType>().isScalable()) && |
328 | "scalable vectors are not supported." ); |
329 | if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) |
330 | return rewriter.notifyMatchFailure( |
331 | extractOp, "Can't flatten since targetBitWidth <= OpSize" ); |
332 | |
333 | // Dynamic position is not supported. |
334 | if (extractOp.hasDynamicPosition()) |
335 | return rewriter.notifyMatchFailure(extractOp, |
336 | "dynamic position is not supported." ); |
337 | |
338 | llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape(); |
339 | int64_t size = extractOp.getVector().getType().getNumElements(); |
340 | |
341 | // Compute linearized offset. |
342 | int64_t linearizedOffset = 0; |
343 | llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition(); |
344 | for (auto [i, off] : llvm::enumerate(offsets)) { |
345 | size /= shape[i]; |
346 | linearizedOffset += offsets[i] * size; |
347 | } |
348 | |
349 | llvm::SmallVector<int64_t, 2> indices(size); |
350 | std::iota(first: indices.begin(), last: indices.end(), value: linearizedOffset); |
351 | rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
352 | extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), |
353 | rewriter.getI64ArrayAttr(indices)); |
354 | |
355 | return success(); |
356 | } |
357 | |
358 | private: |
359 | unsigned ; |
360 | }; |
361 | } // namespace |
362 | |
363 | void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( |
364 | TypeConverter &typeConverter, RewritePatternSet &patterns, |
365 | ConversionTarget &target, unsigned targetBitWidth) { |
366 | |
367 | typeConverter.addConversion(callback: [](VectorType type) -> std::optional<Type> { |
368 | if (!isLinearizableVector(type)) |
369 | return type; |
370 | |
371 | return VectorType::get(type.getNumElements(), type.getElementType(), |
372 | type.isScalable()); |
373 | }); |
374 | |
375 | auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, |
376 | Location loc) -> Value { |
377 | if (inputs.size() != 1 || !isa<VectorType>(Val: inputs.front().getType()) || |
378 | !isa<VectorType>(Val: type)) |
379 | return nullptr; |
380 | |
381 | return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); |
382 | }; |
383 | typeConverter.addArgumentMaterialization(callback&: materializeCast); |
384 | typeConverter.addSourceMaterialization(callback&: materializeCast); |
385 | typeConverter.addTargetMaterialization(callback&: materializeCast); |
386 | target.markUnknownOpDynamicallyLegal( |
387 | fn: [=](Operation *op) -> std::optional<bool> { |
388 | if ((isa<arith::ConstantOp>(op) || |
389 | op->hasTrait<OpTrait::Vectorizable>())) { |
390 | return (isLessThanTargetBitWidth(op, targetBitWidth) |
391 | ? typeConverter.isLegal(op) |
392 | : true); |
393 | } |
394 | return std::nullopt; |
395 | }); |
396 | |
397 | patterns.add<LinearizeConstant, LinearizeVectorizable>( |
398 | arg&: typeConverter, args: patterns.getContext(), args&: targetBitWidth); |
399 | } |
400 | |
401 | void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( |
402 | TypeConverter &typeConverter, RewritePatternSet &patterns, |
403 | ConversionTarget &target, unsigned int targetBitWidth) { |
404 | target.addDynamicallyLegalOp<vector::ShuffleOp>( |
405 | [=](vector::ShuffleOp shuffleOp) -> bool { |
406 | return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) |
407 | ? (typeConverter.isLegal(shuffleOp) && |
408 | shuffleOp.getResult() |
409 | .getType() |
410 | .cast<mlir::VectorType>() |
411 | .getRank() == 1) |
412 | : true; |
413 | }); |
414 | patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, |
415 | LinearizeVectorExtractStridedSlice>( |
416 | arg&: typeConverter, args: patterns.getContext(), args&: targetBitWidth); |
417 | } |
418 | |