//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example, for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (`indices` and
/// `scalableOffsets` must be of equal length). For example, in the 2D case
/// this would return:
/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
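///
/// For example (an illustrative sketch, where `%a` and `%b` stand in for the
/// original mask dimensions), extracting the mask for the (4, 4) sub-tile of
///
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[8]x[8]xi1>
/// ```
///
/// would roughly produce:
///
/// ```mlir
/// %vscale = vector.vscale
/// %cneg4 = arith.constant -4 : index
/// %offset = arith.muli %cneg4, %vscale : index
/// %rows = arith.addi %a, %offset : index
/// %cols = arith.addi %b, %offset : index
/// %subMask = vector.create_mask %rows, %cols : vector<[4]x[4]xi1>
/// ```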
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
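///
/// For example (a sketch of the arithmetic): an f32 SME tile is [4]x[4]
/// elements (minimum tile slice length 4), so a vector<[8]x[8]xf32> contains
/// (8 * 8) / (4 * 4) = 4 SME tiles.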
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
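///
/// Example (a sketch, assuming f32 elements so the SME tile is [4]x[4]):
///
/// BEFORE:
/// ```mlir
/// %const = arith.constant dense<1.0> : vector<[8]x[8]xf32>
/// ```
///
/// AFTER (a single tile-sized splat, used once per sub-tile):
/// ```mlir
/// %tileSplat = arith.constant dense<1.0> : vector<[4]x[4]xf32>
/// // `%const` is then replaced by four uses of `%tileSplat`.
/// ```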
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
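///
/// Example (a sketch; masks and accumulators omitted):
///
/// BEFORE:
/// ```mlir
/// %op = vector.outerproduct %lhs, %rhs
///   : vector<[8]xf32>, vector<[8]xf32>
/// ```
///
/// AFTER (one of the four resulting tile-sized outer products, here for the
/// sub-tile at (0, 4)):
/// ```mlir
/// %lhs0 = vector.scalable.extract %lhs[0]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %rhs4 = vector.scalable.extract %rhs[4]
///   : vector<[4]xf32> from vector<[8]xf32>
/// %tile = vector.outerproduct %lhs0, %rhs4
///   : vector<[4]xf32>, vector<[4]xf32>
/// ```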
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (which is invalid). This pattern instead matches on `vector.mask`, then
// calls into the `vector.outerproduct` pattern above to work around the issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
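///
/// Example (a sketch; `%a` and `%b` are hypothetical indices, and the
/// adjusted sub-tile indices are shown symbolically):
///
/// BEFORE:
/// ```mlir
/// %read = vector.transfer_read %src[%a, %b], %pad
///   : memref<?x?xf32>, vector<[8]x[8]xf32>
/// ```
///
/// AFTER (one of four tile-sized reads, here for the sub-tile at (4, 0)):
/// ```mlir
/// %tile = vector.transfer_read %src[%a_plus_4_vscale, %b], %pad
///   : memref<?x?xf32>, vector<[4]x[4]xf32>
/// ```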
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getBase(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
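///
/// Example (a sketch; `%tile_0_0` is a hypothetical name for one decomposed
/// sub-tile value; for ops with tensor semantics the sub-tile writes are
/// chained through the destination tensor):
///
/// BEFORE:
/// ```mlir
/// vector.transfer_write %vec, %dest[%a, %b]
///   : vector<[8]x[8]xf32>, memref<?x?xf32>
/// ```
///
/// AFTER (one write per [4]x[4] sub-tile, with indices adjusted as in the
/// transfer_read case above; shown for the sub-tile at (0, 0)):
/// ```mlir
/// vector.transfer_write %tile_0_0, %dest[%a, %b]
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```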
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getValueToStore();

    Value destTensorOrMemref = writeOp.getBase();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition since at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]            ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                      |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]            |- Store
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                     |  upper
///   vector.transfer_write %upper_slice,                              |  tile
///     %dest[%slice_idx + %y, %x], %upper_slice_mask                  |
///     : vector<[8]xi16>, memref<?x?xi16>                            ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                      ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]       |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                      |- Store
///   %lower_slice = vector.extract %lower_tile[%slice_idx]            |  lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                     |  tile
///   vector.transfer_write %lower_slice,                              |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask            |
///     : vector<[8]xi16>, memref<?x?xi16>                            ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getValueToStore();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_mask dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
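/// For example, vector<2x[4]xf32> and vector<[4]x[4]xf32> are legal, but
/// vector<[4]x2xf32> is not (a fixed dimension follows a scalable one).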
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through any extend op to find the transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                  transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we've got to this pass and we still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>, memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///  %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///  %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///  %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %c4_vscale = arith.muli %vscale, %c4 : index
///  %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///  vector.transfer_write %4, %dest[%y, %x], %mask
///    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///    : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not an illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getBase();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute the mask for the tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
                                                  transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
                                                  transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
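    // 1:N conversion for vectors spanning multiple SME tiles, e.g. a
    // vector<[8]x[8]xf32> is converted to four vector<[4]x[4]xf32> values.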
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                        LiftIllegalVectorTransposeToMemory,
                        ConvertIllegalShapeCastOpsToTransposes,
                        LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}