//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which are of equal
/// length). For example, in the 2D case this would return:
/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
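///
/// For example (rough sketch; the subtraction is materialized with arith ops):
/// given a mask from `vector.create_mask %a, %b : vector<[8]x[8]xi1>` and a
/// sub-tile starting at (4, 0), the extracted sub-tile mask is roughly
/// `vector.create_mask (%a - 4 * vscale), %b : vector<[4]x[4]xi1>`.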
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  return llvm::map_range(
      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
                                              smeTileType.getDimSize(1)}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
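///
/// For example (illustrative): for `vector<[8]x[8]xf32>` the minimum number of
/// elements per SME tile slice is 4 (f32 tiles are [4]x[4]), so this returns
/// (8 * 8) / (4 * 4) = 4.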
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
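///
/// For example (rough sketch; SSA names illustrative), a splat such as
/// ```mlir
/// %cst = arith.constant dense<1.0> : vector<[8]x[8]xf32>
/// ```
/// is rewritten to a single tile-sized splat
/// ```mlir
/// %tile = arith.constant dense<1.0> : vector<[4]x[4]xf32>
/// ```
/// which is then reused for all four sub-tiles via the one-to-many result
/// mapping.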
struct LegalizeArithConstantOpsByDecomposition
    : public OneToNOpConversionPattern<arith::ConstantOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
                       adaptor.getResultMapping());

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
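///
/// For example (rough sketch; LHS/RHS slicing and masking elided): an outer
/// product producing a `vector<[8]x[8]xf32>`
/// ```mlir
/// %op = vector.outerproduct %lhs, %rhs, %acc
///         : vector<[8]xf32>, vector<[8]xf32>
/// ```
/// is decomposed into four `vector.outerproduct` ops, each producing a
/// `vector<[4]x[4]xf32>` sub-tile, with the LHS/RHS slices taken via
/// `vector.scalable.extract` and the accumulator sub-tiles coming from the
/// one-to-many operand mapping.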
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::OuterProductOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (which is invalid). This pattern instead matches on `vector.mask` and then
// calls into the `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::MaskOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    if (auto outerProductOp =
            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
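///
/// For example (rough sketch; mask and padding handling elided):
///
/// BEFORE:
/// ```mlir
/// %vec = vector.transfer_read %src[%a, %b]
///          : memref<?x?xf32>, vector<[8]x[8]xf32>
/// ```
///
/// AFTER (one read per [4]x[4] sub-tile):
/// ```mlir
/// %t00 = vector.transfer_read %src[%a, %b]
///          : memref<?x?xf32>, vector<[4]x[4]xf32>
/// // ...plus three more reads at indices offset by 4 x vscale, covering the
/// // (0, 4), (4, 0), and (4, 4) sub-tiles.
/// ```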
struct LegalizeTransferReadOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferReadOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
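///
/// For example (rough sketch; masks elided): writing a `vector<[8]x[8]xf32>`
/// becomes four `vector.transfer_write` ops of `vector<[4]x[4]xf32>` sub-tiles
/// at indices offset by multiples of 4 x vscale, analogous to the
/// decomposition of `vector.transfer_read` above.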
struct LegalizeTransferWriteOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
/// %subMask = vector.extract %mask[2]
///              : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
/// ```
///
/// AFTER:
/// ```mlir
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
/// ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
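/// For example (illustrative): `vector<4x[4]xf32>` and `vector<[4]x[4]xf32>`
/// are legal, but `vector<[4]x4xf32>` (a fixed dim after a scalable dim) is
/// not.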
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}
460
461/// Lifts an illegal vector.transpose and vector.transfer_read to a
462/// memref.subview + memref.transpose, followed by a legal read.
463///
464/// 'Illegal' here means a leading scalable dimension and a fixed trailing
465/// dimension, which has no valid lowering.
466///
467/// The memref.transpose is metadata-only transpose that produces a strided
468/// memref, which eventually becomes a loop reading individual elements.
469///
470/// Example:
471///
472/// BEFORE:
473/// ```mlir
474/// %illegalRead = vector.transfer_read %memref[%a, %b]
475/// : memref<?x?xf32>, vector<[8]x4xf32>
476/// %legalType = vector.transpose %illegalRead, [1, 0]
477/// : vector<[8]x4xf32> to vector<4x[8]xf32>
478/// ```
479///
480/// AFTER:
481/// ```mlir
482/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
483/// : memref<?x?xf32> to memref<?x?xf32>
484/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
485/// : memref<?x?xf32> to memref<?x?xf32>
486/// %legalType = vector.transfer_read %transpose[%c0, %c0]
487/// : memref<?x?xf32>, vector<4x[8]xf32>
488/// ```
489struct LiftIllegalVectorTransposeToMemory
490 : public OpRewritePattern<vector::TransposeOp> {
491 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
492
493 static Value getExtensionSource(Operation *op) {
494 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
495 return op->getOperand(idx: 0);
496 return {};
497 }
498
499 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
500 PatternRewriter &rewriter) const override {
501 auto sourceType = transposeOp.getSourceVectorType();
502 auto resultType = transposeOp.getResultVectorType();
503 if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
504 return rewriter.notifyMatchFailure(transposeOp,
505 kMatchFailureNotIllegalToLegal);
506
507 // Look through extend for transfer_read.
508 Value maybeRead = transposeOp.getVector();
509 auto *transposeSourceOp = maybeRead.getDefiningOp();
510 Operation *extendOp = nullptr;
511 if (Value extendSource = getExtensionSource(op: transposeSourceOp)) {
512 maybeRead = extendSource;
513 extendOp = transposeSourceOp;
514 }
515
516 auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
517 if (!illegalRead)
518 return rewriter.notifyMatchFailure(
519 transposeOp,
520 "expected source to be (possibly extended) transfer_read");
521
522 if (!illegalRead.getPermutationMap().isIdentity())
523 return rewriter.notifyMatchFailure(
524 illegalRead, "expected read to have identity permutation map");
525
526 auto loc = transposeOp.getLoc();
527 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
528 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
529
530 // Create a subview that matches the size of the illegal read vector type.
531 auto readType = illegalRead.getVectorType();
532 auto readSizes = llvm::map_to_vector(
533 llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
534 [&](auto dim) -> Value {
535 auto [size, isScalable] = dim;
536 auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
537 if (!isScalable)
538 return dimSize;
539 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
540 return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
541 });
542 SmallVector<Value> strides(readType.getRank(), Value(one));
543 auto readSubview = rewriter.create<memref::SubViewOp>(
544 loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
545 strides);
546
547 // Apply the transpose to all values/attributes of the transfer_read:
548 // - The mask
549 Value mask = illegalRead.getMask();
550 if (mask) {
551 // Note: The transpose for the mask should fold into the
552 // vector.create_mask/constant_mask op, which will then become legal.
553 mask = rewriter.create<vector::TransposeOp>(loc, mask,
554 transposeOp.getPermutation());
555 }
556 // - The source memref
557 mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
558 transposeOp.getPermutation(), getContext());
559 auto transposedSubview = rewriter.create<memref::TransposeOp>(
560 loc, readSubview, AffineMapAttr::get(transposeMap));
561 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
562 // - The `in_bounds` attribute
563 if (inBoundsAttr) {
564 SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
565 inBoundsAttr.end());
566 applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
567 inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
568 }
569
570 VectorType legalReadType = resultType.clone(readType.getElementType());
571 // Note: The indices are all zero as the subview is already offset.
572 SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
573 auto legalRead = rewriter.create<vector::TransferReadOp>(
574 loc, legalReadType, transposedSubview, readIndices,
575 illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
576 inBoundsAttr);
577
578 // Replace the transpose with the new read, extending the result if
579 // necessary.
580 rewriter.replaceOp(transposeOp, [&]() -> Operation * {
581 if (extendOp)
582 return rewriter.create(loc, extendOp->getName().getIdentifier(),
583 Value(legalRead), resultType);
584 return legalRead;
585 }());
586
587 return success();
588 }
589};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we have reached this pass and still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D),
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    OneToNTypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
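    // 1:N conversion for vector types that are a multiple of an SME tile.
    // For example (illustrative): vector<[8]x[8]xf32> maps to four
    // vector<[4]x[4]xf32> values.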
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                 LiftIllegalVectorTransposeToMemory,
                 ConvertIllegalShapeCastOpsToTransposes>(context);
    // Note: High benefit to ensure masked outer products are lowered first.
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
        converter, context, 1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFuncTypeConversionPatterns(converter, patterns);
    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);

    if (failed(applyPartialOneToNConversion(getOperation(), converter,
                                            std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}
