//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//
14 | |
15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
16 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
17 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
18 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
24 | #include "mlir/Transforms/OneToNTypeConversion.h" |
25 | |
26 | #define DEBUG_TYPE "arm-sme-vector-legalization" |
27 | |
28 | namespace mlir::arm_sme { |
29 | #define GEN_PASS_DEF_VECTORLEGALIZATION |
30 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
31 | } // namespace mlir::arm_sme |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::arm_sme; |
35 | |
36 | namespace { |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Decomposition of vector operations larger than an SME tile |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | // Common match failure reasons. |
43 | static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple( |
44 | "op vector size is not multiple of SME tiles" ); |
45 | static constexpr StringLiteral kMatchFailureUnsupportedMaskOp( |
46 | "op mask is unsupported for legalization/decomposition" ); |
47 | static constexpr StringLiteral |
48 | kMatchFailureNonPermutationMap("op affine map is not a permutation" ); |
49 | static constexpr StringLiteral kMatchFailureNotIllegalToLegal( |
50 | "expected transpose from illegal type to legal type" ); |
51 | |
52 | /// An SMESubTile represents a single SME-sized sub-tile from decomposing a |
53 | /// larger vector type. The (`row`, `col`) are the position of the tile in the |
54 | /// original vector type. For example for an [8]x[8] tile with four [4]x[4] |
55 | /// sub-tiles, we would have: |
56 | /// |
57 | /// 8 x vscale |
58 | /// ┌─────────────┬─────────────┐ |
59 | /// │(0,0) │(0,4) │ |
60 | /// │ │ │ |
61 | /// ├─────────────┼─────────────┤ 8 x vscale |
62 | /// │(4,0) │(4,4) │ |
63 | /// │ │ │ |
64 | /// └─────────────┴─────────────┘ |
65 | struct SMESubTile { |
66 | // Note: The units of (row, col) are vscale (as SME tiles are scalable). |
67 | int row{0}; |
68 | int col{0}; |
69 | // The SME tile type. |
70 | VectorType type; |
71 | }; |
72 | |
73 | /// Adds a constant elementwise scalable offset to `indices` (which are of equal |
74 | /// length). For example, in the 2D case this would return: |
75 | // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale } |
76 | SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, |
77 | Location loc, |
78 | ValueRange indices, |
79 | ArrayRef<int> scalableOffsets) { |
80 | auto vscale = builder.create<vector::VectorScaleOp>(loc); |
81 | return llvm::map_to_vector( |
82 | C: llvm::zip_equal(t&: indices, u&: scalableOffsets), F: [&](auto pair) -> Value { |
83 | auto [index, base] = pair; |
84 | auto offset = builder.create<arith::MulIOp>( |
85 | loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); |
86 | return builder.create<arith::AddIOp>(loc, index, offset); |
87 | }); |
88 | } |
89 | |
90 | /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to |
91 | /// indices for one of the SME sub-tiles it will decompose into. |
92 | /// |
93 | /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the |
94 | /// indices for each tile would need to be adjusted as follows: |
95 | /// |
96 | /// initial indices = [a,b], inital size = 8x8, target size = 4x4 |
97 | /// ┌─────────────┬─────────────┐ |
98 | /// │[a,b] │[a,b+4] │ |
99 | /// │ │ │ |
100 | /// ├─────────────┼─────────────┤ |
101 | /// │[a+4,b] │[a+4,b+4] │ |
102 | /// │ │ │ |
103 | /// └─────────────┴─────────────┘ |
104 | SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc, |
105 | ValueRange indices, |
106 | SMESubTile smeTile) { |
107 | return addConstantScalableOffset(builder, loc, indices, |
108 | scalableOffsets: {smeTile.row, smeTile.col}); |
109 | } |
110 | |
111 | /// Returns true if `mask` is generated by an operation that can be decomposed |
112 | /// for SME. Currently, that is just no mask, or vector.create_mask. |
113 | /// TODO: Add support for vector.constant_mask once required for SME. |
114 | bool isSupportedMaskOp(Value mask) { |
115 | return !mask || mask.getDefiningOp<vector::CreateMaskOp>(); |
116 | } |
117 | |
118 | /// Extracts a mask for an SME sub-tile from the mask of a larger vector type. |
119 | Value (OpBuilder &builder, Location loc, Value mask, |
120 | SMESubTile smeTile) { |
121 | assert(isSupportedMaskOp(mask)); |
122 | if (!mask) |
123 | return Value{}; |
124 | auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
125 | // The operands of `vector.create_mask` (from a 2D perspective) are the |
126 | // coordinates where the mask ends. So we subtract where this tile starts, |
127 | // from the mask operands to get the parameters for this sub-tile. |
128 | auto smeTileMaskDims = addConstantScalableOffset( |
129 | builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); |
130 | auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( |
131 | loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); |
132 | return smeTileCreateMask.getResult(); |
133 | } |
134 | |
135 | /// Constructs an iterator that returns each SME tile (with coordinates) |
136 | /// contained within a VectorType. For example, if decomposing an [8]x[8] into |
137 | /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0), |
138 | /// (4, 4). |
139 | auto decomposeToSMETiles(OpBuilder &builder, VectorType type, |
140 | VectorType smeTileType, |
141 | bool transposeIndices = false) { |
142 | assert(isMultipleOfSMETileVectorType(type) && |
143 | "`type` not multiple of SME tiles" ); |
144 | return llvm::map_range( |
145 | StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0), |
146 | smeTileType.getDimSize(1)}), |
147 | [=](auto indices) { |
148 | int row = int(indices[0]); |
149 | int col = int(indices[1]); |
150 | if (transposeIndices) |
151 | std::swap(a&: row, b&: col); |
152 | return SMESubTile{row, col, smeTileType}; |
153 | }); |
154 | } |
155 | |
156 | /// Returns the number of SME tiles that fit into the (2D-scalable) vector type |
157 | /// `type`. |
158 | int getNumberOfSMETilesForVectorType(VectorType type) { |
159 | assert(isMultipleOfSMETileVectorType(type) && |
160 | "`type` not multiple of SME tiles" ); |
161 | int64_t vectorRows = type.getDimSize(0); |
162 | int64_t vectorCols = type.getDimSize(1); |
163 | auto elementType = type.getElementType(); |
164 | unsigned minNumElts = getSMETileSliceMinNumElts(elementType); |
165 | return (vectorRows * vectorCols) / (minNumElts * minNumElts); |
166 | } |
167 | |
168 | /// Legalize `arith.constant dense<value>` splat operations to fit within SME |
169 | /// tiles by decomposing them into tile-sized operations. |
170 | struct LegalizeArithConstantOpsByDecomposition |
171 | : public OneToNOpConversionPattern<arith::ConstantOp> { |
172 | using OneToNOpConversionPattern::OneToNOpConversionPattern; |
173 | |
174 | LogicalResult |
175 | matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, |
176 | OneToNPatternRewriter &rewriter) const override { |
177 | auto vectorType = dyn_cast<VectorType>(constantOp.getType()); |
178 | auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); |
179 | if (!vectorType || !denseAttr || !denseAttr.isSplat()) |
180 | return failure(); |
181 | |
182 | if (!isMultipleOfSMETileVectorType(vectorType)) |
183 | return rewriter.notifyMatchFailure(constantOp, |
184 | kMatchFailureNotSMETileTypeMultiple); |
185 | |
186 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
187 | auto tileCount = getNumberOfSMETilesForVectorType(vectorType); |
188 | auto tileSplat = rewriter.create<arith::ConstantOp>( |
189 | constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); |
190 | rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat), |
191 | adaptor.getResultMapping()); |
192 | |
193 | return success(); |
194 | } |
195 | }; |
196 | |
197 | /// Legalize `vector.outerproduct` operations to fit within SME tiles by |
198 | /// decomposing them into tile-sized operations. |
199 | struct LegalizeVectorOuterProductOpsByDecomposition |
200 | : public OneToNOpConversionPattern<vector::OuterProductOp> { |
201 | using OneToNOpConversionPattern::OneToNOpConversionPattern; |
202 | |
203 | LogicalResult |
204 | matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor, |
205 | OneToNPatternRewriter &rewriter) const override { |
206 | auto vectorType = outerProductOp.getResultVectorType(); |
207 | if (!isMultipleOfSMETileVectorType(vectorType)) |
208 | return rewriter.notifyMatchFailure(outerProductOp, |
209 | kMatchFailureNotSMETileTypeMultiple); |
210 | |
211 | Value mask; |
212 | Operation *rootOp = outerProductOp; |
213 | auto loc = outerProductOp.getLoc(); |
214 | if (outerProductOp.isMasked()) { |
215 | auto maskOp = outerProductOp.getMaskingOp(); |
216 | mask = maskOp.getMask(); |
217 | rootOp = maskOp; |
218 | } |
219 | |
220 | if (!isSupportedMaskOp(mask)) |
221 | return rewriter.notifyMatchFailure(outerProductOp, |
222 | kMatchFailureUnsupportedMaskOp); |
223 | |
224 | ValueRange accSMETiles = adaptor.getAcc(); |
225 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
226 | VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0); |
227 | |
228 | SmallVector<Value> resultSMETiles; |
229 | for (auto [index, smeTile] : llvm::enumerate( |
230 | decomposeToSMETiles(rewriter, vectorType, smeTileType))) { |
231 | |
232 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
233 | auto lhs = rewriter.create<vector::ScalableExtractOp>( |
234 | loc, sliceType, outerProductOp.getLhs(), smeTile.row); |
235 | auto rhs = rewriter.create<vector::ScalableExtractOp>( |
236 | loc, sliceType, outerProductOp.getRhs(), smeTile.col); |
237 | auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( |
238 | loc, smeTileType, lhs, rhs, |
239 | !accSMETiles.empty() ? accSMETiles[index] : Value{}, |
240 | outerProductOp.getKind()); |
241 | |
242 | auto maskedOuterProduct = |
243 | vector::maskOperation(rewriter, smeOuterProduct, smeMask); |
244 | resultSMETiles.push_back(maskedOuterProduct->getResult(0)); |
245 | } |
246 | |
247 | rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping()); |
248 | return success(); |
249 | } |
250 | }; |
251 | |
252 | // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to |
253 | // get the help of the type conversion), but doing so results in the type |
254 | // conversion adding target materializations in the `vector.mask` region |
255 | // (invalid). This pattern matches on `vector.mask` then calls into the |
256 | // `vector.outerproduct` pattern to work around this issue. |
257 | struct LegalizeMaskedVectorOuterProductOpsByDecomposition |
258 | : public OneToNOpConversionPattern<vector::MaskOp> { |
259 | using OneToNOpConversionPattern::OneToNOpConversionPattern; |
260 | |
261 | LogicalResult |
262 | matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, |
263 | OneToNPatternRewriter &rewriter) const override { |
264 | if (auto outerProductOp = |
265 | llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) { |
266 | LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), |
267 | getContext()); |
268 | return static_cast<RewritePattern &>(pattern).matchAndRewrite( |
269 | outerProductOp, rewriter); |
270 | } |
271 | return failure(); |
272 | } |
273 | }; |
274 | |
275 | /// Legalize `vector.transfer_read` operations to fit within SME tiles by |
276 | /// decomposing them into tile-sized operations. |
277 | struct LegalizeTransferReadOpsByDecomposition |
278 | : public OneToNOpConversionPattern<vector::TransferReadOp> { |
279 | using OneToNOpConversionPattern::OneToNOpConversionPattern; |
280 | |
281 | LogicalResult |
282 | matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, |
283 | OneToNPatternRewriter &rewriter) const override { |
284 | auto vectorType = readOp.getVectorType(); |
285 | if (!isMultipleOfSMETileVectorType(vectorType)) |
286 | return rewriter.notifyMatchFailure(readOp, |
287 | kMatchFailureNotSMETileTypeMultiple); |
288 | |
289 | auto mask = readOp.getMask(); |
290 | if (!isSupportedMaskOp(mask)) |
291 | return rewriter.notifyMatchFailure(readOp, |
292 | kMatchFailureUnsupportedMaskOp); |
293 | |
294 | auto permutationMap = readOp.getPermutationMap(); |
295 | if (!permutationMap.isPermutation()) |
296 | return rewriter.notifyMatchFailure(readOp, |
297 | kMatchFailureNonPermutationMap); |
298 | |
299 | // Note: For 2D vector types the only non-identity permutation is a simple |
300 | // tranpose [1, 0]. |
301 | bool transposed = !permutationMap.isIdentity(); |
302 | |
303 | auto loc = readOp.getLoc(); |
304 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
305 | |
306 | SmallVector<Value> resultSMETiles; |
307 | for (SMESubTile smeTile : |
308 | decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { |
309 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
310 | auto smeRead = rewriter.create<vector::TransferReadOp>( |
311 | loc, smeTileType, readOp.getSource(), |
312 | getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), |
313 | readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, |
314 | readOp.getInBoundsAttr()); |
315 | resultSMETiles.push_back(smeRead); |
316 | } |
317 | |
318 | rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping()); |
319 | return success(); |
320 | } |
321 | }; |
322 | |
323 | /// Legalize `vector.transfer_write` operations to fit within SME tiles by |
324 | /// decomposing them into tile-sized operations. |
325 | struct LegalizeTransferWriteOpsByDecomposition |
326 | : public OneToNOpConversionPattern<vector::TransferWriteOp> { |
327 | using OneToNOpConversionPattern::OneToNOpConversionPattern; |
328 | |
329 | LogicalResult |
330 | matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, |
331 | OneToNPatternRewriter &rewriter) const override { |
332 | auto vectorType = writeOp.getVectorType(); |
333 | if (!isMultipleOfSMETileVectorType(vectorType)) |
334 | return rewriter.notifyMatchFailure(writeOp, |
335 | kMatchFailureNotSMETileTypeMultiple); |
336 | |
337 | auto mask = writeOp.getMask(); |
338 | if (!isSupportedMaskOp(mask)) |
339 | return rewriter.notifyMatchFailure(writeOp, |
340 | kMatchFailureUnsupportedMaskOp); |
341 | |
342 | auto permutationMap = writeOp.getPermutationMap(); |
343 | if (!permutationMap.isPermutation()) |
344 | return rewriter.notifyMatchFailure(writeOp, |
345 | kMatchFailureNonPermutationMap); |
346 | |
347 | // Note: For 2D vector types the only non-identity permutation is a simple |
348 | // tranpose [1, 0]. |
349 | bool transposed = !permutationMap.isIdentity(); |
350 | |
351 | auto loc = writeOp.getLoc(); |
352 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
353 | auto inputSMETiles = adaptor.getVector(); |
354 | |
355 | Value destTensorOrMemref = writeOp.getSource(); |
356 | for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( |
357 | rewriter, vectorType, smeTileType, transposed))) { |
358 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
359 | auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
360 | loc, inputSMETiles[index], destTensorOrMemref, |
361 | getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), |
362 | writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); |
363 | if (writeOp.hasPureTensorSemantics()) |
364 | destTensorOrMemref = smeWrite.getResult(); |
365 | } |
366 | |
367 | if (writeOp.hasPureTensorSemantics()) |
368 | rewriter.replaceOp(writeOp, destTensorOrMemref); |
369 | else |
370 | rewriter.eraseOp(op: writeOp); |
371 | |
372 | return success(); |
373 | } |
374 | }; |
375 | |
376 | //===----------------------------------------------------------------------===// |
377 | // ArmSME-specific fixup canonicalizations/folds |
378 | //===----------------------------------------------------------------------===// |
379 | |
380 | /// Folds an extract from a 3D `vector.create_mask` (which is a vector of |
381 | /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is |
382 | /// necessary for the mask to be lowered to ArmSME. |
383 | /// |
384 | /// Example: |
385 | /// |
386 | /// BEFORE: |
387 | /// ```mlir |
388 | /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> |
389 | /// %subMask = vector.extract %mask[2] |
390 | /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> |
391 | /// ``` |
392 | /// |
393 | /// AFTER: |
394 | /// ```mlir |
395 | /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index |
396 | /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index |
397 | /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> |
398 | /// ``` |
399 | struct |
400 | : public OpRewritePattern<vector::ExtractOp> { |
401 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
402 | |
403 | LogicalResult matchAndRewrite(vector::ExtractOp , |
404 | PatternRewriter &rewriter) const override { |
405 | auto loc = extractOp.getLoc(); |
406 | auto createMaskOp = |
407 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
408 | if (!createMaskOp) |
409 | return rewriter.notifyMatchFailure( |
410 | extractOp, "extract not from vector.create_mask op" ); |
411 | |
412 | VectorType = |
413 | llvm::dyn_cast<VectorType>(extractOp.getResult().getType()); |
414 | if (!extractedMaskType) |
415 | return rewriter.notifyMatchFailure(extractOp, |
416 | "extracted type is not a vector type" ); |
417 | |
418 | auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true); |
419 | if (numScalable != 2) |
420 | return rewriter.notifyMatchFailure( |
421 | extractOp, "expected extracted type to be an SME-like mask" ); |
422 | |
423 | // TODO: Support multiple extraction indices. |
424 | if (extractOp.getStaticPosition().size() != 1) |
425 | return rewriter.notifyMatchFailure( |
426 | extractOp, "only a single extraction index is supported" ); |
427 | |
428 | auto frontMaskDim = createMaskOp.getOperand(0); |
429 | if (frontMaskDim.getDefiningOp<arith::ConstantOp>()) |
430 | return rewriter.notifyMatchFailure( |
431 | extractOp, |
432 | "constant vector.create_masks dims should be folded elsewhere" ); |
433 | |
434 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
435 | auto = getValueOrCreateConstantIndexOp( |
436 | rewriter, loc, extractOp.getMixedPosition()[0]); |
437 | auto = rewriter.create<arith::CmpIOp>( |
438 | loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, |
439 | frontMaskDim); |
440 | auto newMaskFrontDim = rewriter.create<arith::SelectOp>( |
441 | loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); |
442 | |
443 | rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( |
444 | extractOp, extractedMaskType, |
445 | ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)}); |
446 | return success(); |
447 | } |
448 | }; |
449 | |
450 | /// A vector type where no fixed dimension comes after a scalable dimension. |
451 | bool isLegalVectorType(VectorType vType) { |
452 | bool seenFixedDim = false; |
453 | for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) { |
454 | seenFixedDim |= !scalableFlag; |
455 | if (seenFixedDim && scalableFlag) |
456 | return false; |
457 | } |
458 | return true; |
459 | } |
460 | |
461 | /// Lifts an illegal vector.transpose and vector.transfer_read to a |
462 | /// memref.subview + memref.transpose, followed by a legal read. |
463 | /// |
464 | /// 'Illegal' here means a leading scalable dimension and a fixed trailing |
465 | /// dimension, which has no valid lowering. |
466 | /// |
467 | /// The memref.transpose is metadata-only transpose that produces a strided |
468 | /// memref, which eventually becomes a loop reading individual elements. |
469 | /// |
470 | /// Example: |
471 | /// |
472 | /// BEFORE: |
473 | /// ```mlir |
474 | /// %illegalRead = vector.transfer_read %memref[%a, %b] |
475 | /// : memref<?x?xf32>, vector<[8]x4xf32> |
476 | /// %legalType = vector.transpose %illegalRead, [1, 0] |
477 | /// : vector<[8]x4xf32> to vector<4x[8]xf32> |
478 | /// ``` |
479 | /// |
480 | /// AFTER: |
481 | /// ```mlir |
482 | /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] |
483 | /// : memref<?x?xf32> to memref<?x?xf32> |
484 | /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) |
485 | /// : memref<?x?xf32> to memref<?x?xf32> |
486 | /// %legalType = vector.transfer_read %transpose[%c0, %c0] |
487 | /// : memref<?x?xf32>, vector<4x[8]xf32> |
488 | /// ``` |
489 | struct LiftIllegalVectorTransposeToMemory |
490 | : public OpRewritePattern<vector::TransposeOp> { |
491 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
492 | |
493 | static Value getExtensionSource(Operation *op) { |
494 | if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op)) |
495 | return op->getOperand(idx: 0); |
496 | return {}; |
497 | } |
498 | |
499 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
500 | PatternRewriter &rewriter) const override { |
501 | auto sourceType = transposeOp.getSourceVectorType(); |
502 | auto resultType = transposeOp.getResultVectorType(); |
503 | if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
504 | return rewriter.notifyMatchFailure(transposeOp, |
505 | kMatchFailureNotIllegalToLegal); |
506 | |
507 | // Look through extend for transfer_read. |
508 | Value maybeRead = transposeOp.getVector(); |
509 | auto *transposeSourceOp = maybeRead.getDefiningOp(); |
510 | Operation *extendOp = nullptr; |
511 | if (Value extendSource = getExtensionSource(op: transposeSourceOp)) { |
512 | maybeRead = extendSource; |
513 | extendOp = transposeSourceOp; |
514 | } |
515 | |
516 | auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>(); |
517 | if (!illegalRead) |
518 | return rewriter.notifyMatchFailure( |
519 | transposeOp, |
520 | "expected source to be (possibly extended) transfer_read" ); |
521 | |
522 | if (!illegalRead.getPermutationMap().isIdentity()) |
523 | return rewriter.notifyMatchFailure( |
524 | illegalRead, "expected read to have identity permutation map" ); |
525 | |
526 | auto loc = transposeOp.getLoc(); |
527 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
528 | auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
529 | |
530 | // Create a subview that matches the size of the illegal read vector type. |
531 | auto readType = illegalRead.getVectorType(); |
532 | auto readSizes = llvm::map_to_vector( |
533 | llvm::zip_equal(readType.getShape(), readType.getScalableDims()), |
534 | [&](auto dim) -> Value { |
535 | auto [size, isScalable] = dim; |
536 | auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); |
537 | if (!isScalable) |
538 | return dimSize; |
539 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
540 | return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); |
541 | }); |
542 | SmallVector<Value> strides(readType.getRank(), Value(one)); |
543 | auto readSubview = rewriter.create<memref::SubViewOp>( |
544 | loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes, |
545 | strides); |
546 | |
547 | // Apply the transpose to all values/attributes of the transfer_read: |
548 | // - The mask |
549 | Value mask = illegalRead.getMask(); |
550 | if (mask) { |
551 | // Note: The transpose for the mask should fold into the |
552 | // vector.create_mask/constant_mask op, which will then become legal. |
553 | mask = rewriter.create<vector::TransposeOp>(loc, mask, |
554 | transposeOp.getPermutation()); |
555 | } |
556 | // - The source memref |
557 | mlir::AffineMap transposeMap = AffineMap::getPermutationMap( |
558 | transposeOp.getPermutation(), getContext()); |
559 | auto transposedSubview = rewriter.create<memref::TransposeOp>( |
560 | loc, readSubview, AffineMapAttr::get(transposeMap)); |
561 | ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); |
562 | // - The `in_bounds` attribute |
563 | if (inBoundsAttr) { |
564 | SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(), |
565 | inBoundsAttr.end()); |
566 | applyPermutationToVector(inBoundsValues, transposeOp.getPermutation()); |
567 | inBoundsAttr = rewriter.getArrayAttr(inBoundsValues); |
568 | } |
569 | |
570 | VectorType legalReadType = resultType.clone(readType.getElementType()); |
571 | // Note: The indices are all zero as the subview is already offset. |
572 | SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); |
573 | auto legalRead = rewriter.create<vector::TransferReadOp>( |
574 | loc, legalReadType, transposedSubview, readIndices, |
575 | illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, |
576 | inBoundsAttr); |
577 | |
578 | // Replace the transpose with the new read, extending the result if |
579 | // necessary. |
580 | rewriter.replaceOp(transposeOp, [&]() -> Operation * { |
581 | if (extendOp) |
582 | return rewriter.create(loc, extendOp->getName().getIdentifier(), |
583 | Value(legalRead), resultType); |
584 | return legalRead; |
585 | }()); |
586 | |
587 | return success(); |
588 | } |
589 | }; |
590 | |
591 | /// A rewrite to turn unit dim transpose-like vector.shape_casts into |
592 | /// vector.transposes. The shape_cast has to be from an illegal vector type to a |
593 | /// legal one (as defined by isLegalVectorType). |
594 | /// |
595 | /// The reasoning for this is if we've got to this pass and we still have |
596 | /// shape_casts of illegal types, then they likely will not cancel out. Turning |
597 | /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to |
598 | /// eliminate them. |
599 | /// |
600 | /// Example: |
601 | /// |
602 | /// BEFORE: |
603 | /// ```mlir |
604 | /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32> |
605 | /// ``` |
606 | /// |
607 | /// AFTER: |
608 | /// ```mlir |
609 | /// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> |
610 | /// ``` |
611 | struct ConvertIllegalShapeCastOpsToTransposes |
612 | : public OpRewritePattern<vector::ShapeCastOp> { |
613 | using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; |
614 | |
615 | LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, |
616 | PatternRewriter &rewriter) const override { |
617 | auto sourceType = shapeCastOp.getSourceVectorType(); |
618 | auto resultType = shapeCastOp.getResultVectorType(); |
619 | if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
620 | return rewriter.notifyMatchFailure(shapeCastOp, |
621 | kMatchFailureNotIllegalToLegal); |
622 | |
623 | // Note: If we know that `sourceType` is an illegal vector type (and 2D) |
624 | // then dim 0 is scalable and dim 1 is fixed. |
625 | if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1) |
626 | return rewriter.notifyMatchFailure( |
627 | shapeCastOp, "expected source to be a 2D scalable vector with a " |
628 | "trailing unit dim" ); |
629 | |
630 | auto loc = shapeCastOp.getLoc(); |
631 | auto transpose = rewriter.create<vector::TransposeOp>( |
632 | loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0}); |
633 | |
634 | if (resultType.getRank() == 1) |
635 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType, |
636 | transpose); |
637 | else |
638 | rewriter.replaceOp(shapeCastOp, transpose); |
639 | |
640 | return success(); |
641 | } |
642 | }; |
643 | |
644 | struct VectorLegalizationPass |
645 | : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { |
646 | void runOnOperation() override { |
647 | auto *context = &getContext(); |
648 | OneToNTypeConverter converter; |
649 | RewritePatternSet patterns(context); |
650 | converter.addConversion(callback: [](Type type) { return type; }); |
651 | converter.addConversion( |
652 | callback: [](VectorType vectorType, |
653 | SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { |
654 | if (!isMultipleOfSMETileVectorType(vectorType)) |
655 | return std::nullopt; |
656 | auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType); |
657 | auto smeTileType = |
658 | getSMETileTypeForElement(vectorType.getElementType()); |
659 | types = SmallVector<Type>(smeTileCount, smeTileType); |
660 | return success(); |
661 | }); |
662 | |
663 | patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, |
664 | LiftIllegalVectorTransposeToMemory, |
665 | ConvertIllegalShapeCastOpsToTransposes>(context); |
666 | // Note: High benefit to ensure masked outer products are lowered first. |
667 | patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>( |
668 | converter, context, 1024); |
669 | patterns.add<LegalizeArithConstantOpsByDecomposition, |
670 | LegalizeVectorOuterProductOpsByDecomposition, |
671 | LegalizeTransferReadOpsByDecomposition, |
672 | LegalizeTransferWriteOpsByDecomposition>(converter, context); |
673 | populateFuncTypeConversionPatterns(typeConverter&: converter, patterns); |
674 | scf::populateSCFStructuralOneToNTypeConversions(typeConverter&: converter, patterns); |
675 | |
676 | if (failed(applyPartialOneToNConversion(getOperation(), converter, |
677 | std::move(patterns)))) |
678 | return signalPassFailure(); |
679 | } |
680 | }; |
681 | |
682 | } // namespace |
683 | |
684 | std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { |
685 | return std::make_unique<VectorLegalizationPass>(); |
686 | } |
687 | |