//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
17#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
18#include "mlir/Dialect/ArmSME/Utils/Utils.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
21#include "mlir/Dialect/Index/IR/IndexDialect.h"
22#include "mlir/Dialect/Index/IR/IndexOps.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SCF/Transforms/Patterns.h"
26#include "mlir/Dialect/Utils/IndexingUtils.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31#define DEBUG_TYPE "arm-sme-vector-legalization"
32
33namespace mlir::arm_sme {
34#define GEN_PASS_DEF_VECTORLEGALIZATION
35#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
36} // namespace mlir::arm_sme
37
38using namespace mlir;
39using namespace mlir::arm_sme;
40
41namespace {
42
//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//
46
47// Common match failure reasons.
48static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
49 "op vector size is not multiple of SME tiles");
50static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
51 "op mask is unsupported for legalization/decomposition");
52static constexpr StringLiteral
53 kMatchFailureNonPermutationMap("op affine map is not a permutation");
54static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
55 "expected transpose from illegal type to legal type");
56
57/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
58/// larger vector type. The (`row`, `col`) are the position of the tile in the
59/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
60/// sub-tiles, we would have:
61///
62/// 8 x vscale
63/// ┌─────────────┬─────────────┐
64/// │(0,0) │(0,4) │
65/// │ │ │
66/// ├─────────────┼─────────────┤ 8 x vscale
67/// │(4,0) │(4,4) │
68/// │ │ │
69/// └─────────────┴─────────────┘
70struct SMESubTile {
71 // Note: The units of (row, col) are vscale (as SME tiles are scalable).
72 int row{0};
73 int col{0};
74 // The SME tile type.
75 VectorType type;
76};
77
78/// Adds a constant elementwise scalable offset to `indices` (which are of equal
79/// length). For example, in the 2D case this would return:
80// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
81SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
82 Location loc,
83 ValueRange indices,
84 ArrayRef<int> scalableOffsets) {
85 auto vscale = builder.create<vector::VectorScaleOp>(location: loc);
86 return llvm::map_to_vector(
87 C: llvm::zip_equal(t&: indices, u&: scalableOffsets), F: [&](auto pair) -> Value {
88 auto [index, base] = pair;
89 auto offset = builder.create<arith::MulIOp>(
90 loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
91 return builder.create<arith::AddIOp>(loc, index, offset);
92 });
93}
94
95/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
96/// indices for one of the SME sub-tiles it will decompose into.
97///
98/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
99/// indices for each tile would need to be adjusted as follows:
100///
101/// initial indices = [a,b], inital size = 8x8, target size = 4x4
102/// ┌─────────────┬─────────────┐
103/// │[a,b] │[a,b+4] │
104/// │ │ │
105/// ├─────────────┼─────────────┤
106/// │[a+4,b] │[a+4,b+4] │
107/// │ │ │
108/// └─────────────┴─────────────┘
109SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
110 ValueRange indices,
111 SMESubTile smeTile) {
112 return addConstantScalableOffset(builder, loc, indices,
113 scalableOffsets: {smeTile.row, smeTile.col});
114}
115
116/// Returns true if `mask` is generated by an operation that can be decomposed
117/// for SME. Currently, that is just no mask, or vector.create_mask.
118/// TODO: Add support for vector.constant_mask once required for SME.
119bool isSupportedMaskOp(Value mask) {
120 return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
121}
122
123/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
124Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
125 SMESubTile smeTile) {
126 assert(isSupportedMaskOp(mask));
127 if (!mask)
128 return Value{};
129 auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
130 // The operands of `vector.create_mask` (from a 2D perspective) are the
131 // coordinates where the mask ends. So we subtract where this tile starts,
132 // from the mask operands to get the parameters for this sub-tile.
133 auto smeTileMaskDims = addConstantScalableOffset(
134 builder, loc, indices: createMask.getOperands(), scalableOffsets: {-smeTile.row, -smeTile.col});
135 auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
136 location: loc, args: smeTile.type.clone(elementType: builder.getI1Type()), args&: smeTileMaskDims);
137 return smeTileCreateMask.getResult();
138}
139
140/// Constructs an iterator that returns each SME tile (with coordinates)
141/// contained within a VectorType. For example, if decomposing an [8]x[8] into
142/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
143/// (4, 4).
144auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
145 VectorType smeTileType,
146 bool transposeIndices = false) {
147 return llvm::map_range(
148 C: StaticTileOffsetRange(
149 type.getShape(),
150 {std::min(a: type.getDimSize(idx: 0), b: smeTileType.getDimSize(idx: 0)),
151 std::min(a: type.getDimSize(idx: 1), b: smeTileType.getDimSize(idx: 1))}),
152 F: [=](auto indices) {
153 int row = int(indices[0]);
154 int col = int(indices[1]);
155 if (transposeIndices)
156 std::swap(a&: row, b&: col);
157 return SMESubTile{.row: row, .col: col, .type: smeTileType};
158 });
159}
160
161/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
162/// `type`.
163int getNumberOfSMETilesForVectorType(VectorType type) {
164 assert(isMultipleOfSMETileVectorType(type) &&
165 "`type` not multiple of SME tiles");
166 int64_t vectorRows = type.getDimSize(idx: 0);
167 int64_t vectorCols = type.getDimSize(idx: 1);
168 auto elementType = type.getElementType();
169 unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType);
170 return (vectorRows * vectorCols) / (minNumElts * minNumElts);
171}
172
173/// Legalize `arith.constant dense<value>` splat operations to fit within SME
174/// tiles by decomposing them into tile-sized operations.
175struct LegalizeArithConstantOpsByDecomposition
176 : public OpConversionPattern<arith::ConstantOp> {
177 using OpConversionPattern::OpConversionPattern;
178
179 LogicalResult
180 matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 auto vectorType = dyn_cast<VectorType>(Val: constantOp.getType());
183 auto denseAttr = dyn_cast<DenseElementsAttr>(Val: constantOp.getValueAttr());
184 if (!vectorType || !denseAttr || !denseAttr.isSplat())
185 return failure();
186
187 if (!isMultipleOfSMETileVectorType(vType: vectorType))
188 return rewriter.notifyMatchFailure(arg&: constantOp,
189 msg: kMatchFailureNotSMETileTypeMultiple);
190
191 auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType());
192 auto tileCount = getNumberOfSMETilesForVectorType(type: vectorType);
193 auto tileSplat = rewriter.create<arith::ConstantOp>(
194 location: constantOp.getLoc(), args: denseAttr.resizeSplat(newType: smeTileType));
195 SmallVector<Value> repl(tileCount, tileSplat);
196 rewriter.replaceOpWithMultiple(op: constantOp, newValues: {repl});
197
198 return success();
199 }
200};
201
202/// Legalize `vector.outerproduct` operations to fit within SME tiles by
203/// decomposing them into tile-sized operations.
204struct LegalizeVectorOuterProductOpsByDecomposition
205 : public OpConversionPattern<vector::OuterProductOp> {
206 using OpConversionPattern::OpConversionPattern;
207
208 LogicalResult
209 matchAndRewrite(vector::OuterProductOp outerProductOp,
210 OneToNOpAdaptor adaptor,
211 ConversionPatternRewriter &rewriter) const override {
212 auto vectorType = outerProductOp.getResultVectorType();
213 if (!isMultipleOfSMETileVectorType(vType: vectorType))
214 return rewriter.notifyMatchFailure(arg&: outerProductOp,
215 msg: kMatchFailureNotSMETileTypeMultiple);
216
217 Value mask;
218 Operation *rootOp = outerProductOp;
219 auto loc = outerProductOp.getLoc();
220 if (outerProductOp.isMasked()) {
221 auto maskOp = outerProductOp.getMaskingOp();
222 mask = maskOp.getMask();
223 rootOp = maskOp;
224 rewriter.setInsertionPoint(rootOp);
225 }
226
227 if (!isSupportedMaskOp(mask))
228 return rewriter.notifyMatchFailure(arg&: outerProductOp,
229 msg: kMatchFailureUnsupportedMaskOp);
230
231 ValueRange accSMETiles = adaptor.getAcc();
232 auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType());
233 VectorType sliceType = VectorType::Builder(smeTileType).dropDim(pos: 0);
234
235 SmallVector<Value> resultSMETiles;
236 for (auto [index, smeTile] : llvm::enumerate(
237 First: decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType))) {
238
239 auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile);
240 auto lhs = rewriter.create<vector::ScalableExtractOp>(
241 location: loc, args&: sliceType, args: outerProductOp.getLhs(), args&: smeTile.row);
242 auto rhs = rewriter.create<vector::ScalableExtractOp>(
243 location: loc, args&: sliceType, args: outerProductOp.getRhs(), args&: smeTile.col);
244 auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
245 location: loc, args&: smeTileType, args&: lhs, args&: rhs,
246 args: !accSMETiles.empty() ? accSMETiles[index] : Value{},
247 args: outerProductOp.getKind());
248
249 auto maskedOuterProduct =
250 vector::maskOperation(builder&: rewriter, maskableOp: smeOuterProduct, mask: smeMask);
251 resultSMETiles.push_back(Elt: maskedOuterProduct->getResult(idx: 0));
252 }
253
254 rewriter.replaceOpWithMultiple(op: rootOp, newValues: {resultSMETiles});
255 return success();
256 }
257};
258
259// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
260// get the help of the type conversion), but doing so results in the type
261// conversion adding target materializations in the `vector.mask` region
262// (invalid). This pattern matches on `vector.mask` then calls into the
263// `vector.outerproduct` pattern to work around this issue.
264struct LegalizeMaskedVectorOuterProductOpsByDecomposition
265 : public OpConversionPattern<vector::MaskOp> {
266 using OpConversionPattern::OpConversionPattern;
267
268 LogicalResult
269 matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter) const override {
271 if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
272 Val: maskOp.getMaskableOp())) {
273 LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
274 getContext());
275 return static_cast<RewritePattern &>(pattern).matchAndRewrite(
276 op: outerProductOp, rewriter);
277 }
278 return failure();
279 }
280};
281
282/// Legalize `vector.transfer_read` operations to fit within SME tiles by
283/// decomposing them into tile-sized operations.
284struct LegalizeTransferReadOpsByDecomposition
285 : public OpConversionPattern<vector::TransferReadOp> {
286 using OpConversionPattern::OpConversionPattern;
287
288 LogicalResult
289 matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
290 ConversionPatternRewriter &rewriter) const override {
291 auto vectorType = readOp.getVectorType();
292 if (!isMultipleOfSMETileVectorType(vType: vectorType))
293 return rewriter.notifyMatchFailure(arg&: readOp,
294 msg: kMatchFailureNotSMETileTypeMultiple);
295
296 auto mask = readOp.getMask();
297 if (!isSupportedMaskOp(mask))
298 return rewriter.notifyMatchFailure(arg&: readOp,
299 msg: kMatchFailureUnsupportedMaskOp);
300
301 auto permutationMap = readOp.getPermutationMap();
302 if (!permutationMap.isPermutation())
303 return rewriter.notifyMatchFailure(arg&: readOp,
304 msg: kMatchFailureNonPermutationMap);
305
306 // Note: For 2D vector types the only non-identity permutation is a simple
307 // transpose [1, 0].
308 bool transposed = !permutationMap.isIdentity();
309
310 auto loc = readOp.getLoc();
311 auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType());
312
313 SmallVector<Value> resultSMETiles;
314 for (SMESubTile smeTile :
315 decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType, transposeIndices: transposed)) {
316 auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile);
317 auto smeRead = rewriter.create<vector::TransferReadOp>(
318 location: loc, args&: smeTileType, args: readOp.getBase(),
319 args: getSMESubTileIndices(builder&: rewriter, loc, indices: readOp.getIndices(), smeTile),
320 args: readOp.getPermutationMapAttr(), args: readOp.getPadding(), args&: smeMask,
321 args: readOp.getInBoundsAttr());
322 resultSMETiles.push_back(Elt: smeRead);
323 }
324
325 rewriter.replaceOpWithMultiple(op: readOp, newValues: {resultSMETiles});
326 return success();
327 }
328};
329
330/// Legalize `vector.transfer_write` operations to fit within SME tiles by
331/// decomposing them into tile-sized operations.
332struct LegalizeTransferWriteOpsByDecomposition
333 : public OpConversionPattern<vector::TransferWriteOp> {
334 using OpConversionPattern::OpConversionPattern;
335
336 LogicalResult
337 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
338 ConversionPatternRewriter &rewriter) const override {
339 auto vectorType = writeOp.getVectorType();
340 if (!isMultipleOfSMETileVectorType(vType: vectorType))
341 return rewriter.notifyMatchFailure(arg&: writeOp,
342 msg: kMatchFailureNotSMETileTypeMultiple);
343
344 auto mask = writeOp.getMask();
345 if (!isSupportedMaskOp(mask))
346 return rewriter.notifyMatchFailure(arg&: writeOp,
347 msg: kMatchFailureUnsupportedMaskOp);
348
349 auto permutationMap = writeOp.getPermutationMap();
350 if (!permutationMap.isPermutation())
351 return rewriter.notifyMatchFailure(arg&: writeOp,
352 msg: kMatchFailureNonPermutationMap);
353
354 // Note: For 2D vector types the only non-identity permutation is a simple
355 // transpose [1, 0].
356 bool transposed = !permutationMap.isIdentity();
357
358 auto loc = writeOp.getLoc();
359 auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType());
360 auto inputSMETiles = adaptor.getValueToStore();
361
362 Value destTensorOrMemref = writeOp.getBase();
363 for (auto [index, smeTile] : llvm::enumerate(First: decomposeToSMETiles(
364 builder&: rewriter, type: vectorType, smeTileType, transposeIndices: transposed))) {
365 auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile);
366 auto smeWrite = rewriter.create<vector::TransferWriteOp>(
367 location: loc, args: inputSMETiles[index], args&: destTensorOrMemref,
368 args: getSMESubTileIndices(builder&: rewriter, loc, indices: writeOp.getIndices(), smeTile),
369 args: writeOp.getPermutationMapAttr(), args&: smeMask, args: writeOp.getInBoundsAttr());
370 if (writeOp.hasPureTensorSemantics())
371 destTensorOrMemref = smeWrite.getResult();
372 }
373
374 if (writeOp.hasPureTensorSemantics())
375 rewriter.replaceOp(op: writeOp, newValues: destTensorOrMemref);
376 else
377 rewriter.eraseOp(op: writeOp);
378
379 return success();
380 }
381};
382
383/// Legalize a multi-tile transfer_write as a single store loop. This is done as
384/// part of type decomposition as at this level we know each tile write is
385/// disjoint, but that information is lost after decomposition (without analysis
386/// to reconstruct it).
387///
388/// Example (pseudo-MLIR):
389///
390/// ```
391/// vector.transfer_write %vector, %dest[%y, %x], %mask
392/// : vector<[16]x[8]xi16>, memref<?x?xi16>
393/// ```
394/// Is rewritten to:
395/// ```
396/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
397/// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
398/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
399/// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
400/// : vector<[8]xi16> from vector<[8]x[8]xi16> |
401/// vector.transfer_write %upper_slice, |
402/// %dest[%slice_idx + %y, %x], %upper_slice_mask |
403/// : vector<[8]xi16>, memref<?x?xi16> ┘
404/// %lower_slice_idx = %slice_idx + %c8_vscale ─┐
405/// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
406/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
407/// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower
408/// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile
409/// vector.transfer_write %lower_slice, |
410/// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask |
411/// : vector<[8]xi16>, memref<?x?xi16> ┘
412/// }
413/// ```
414struct LegalizeMultiTileTransferWriteAsStoreLoop
415 : public OpConversionPattern<vector::TransferWriteOp> {
416 using OpConversionPattern::OpConversionPattern;
417
418 LogicalResult
419 matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override {
421 if (writeOp.hasPureTensorSemantics())
422 return rewriter.notifyMatchFailure(
423 arg&: writeOp, msg: "TODO: tensor semantics are unsupported");
424
425 auto permutationMap = writeOp.getPermutationMap();
426 if (!permutationMap.isPermutation())
427 return rewriter.notifyMatchFailure(arg&: writeOp,
428 msg: kMatchFailureNonPermutationMap);
429
430 bool transposed = !permutationMap.isIdentity();
431 if (transposed)
432 return rewriter.notifyMatchFailure(arg&: writeOp,
433 msg: "TODO: transpose unsupported");
434
435 auto vectorType = writeOp.getVectorType();
436 if (!isMultipleOfSMETileVectorType(vType: vectorType))
437 return rewriter.notifyMatchFailure(arg&: writeOp,
438 msg: kMatchFailureNotSMETileTypeMultiple);
439
440 // Note: We also disallow masks where any dimension is > 16 because that
441 // prevents the masking from being lowered to use arm_sve.psel.
442 auto mask = writeOp.getMask();
443 if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(idx: 0) > 16 ||
444 vectorType.getDimSize(idx: 1) > 16)))
445 return rewriter.notifyMatchFailure(arg&: writeOp,
446 msg: kMatchFailureUnsupportedMaskOp);
447
448 auto loc = writeOp.getLoc();
449 auto createVscaleMultiple =
450 vector::makeVscaleConstantBuilder(rewriter, loc);
451
452 // Get SME tile and slice types.
453 auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType());
454 auto minTileSlices = smeTileType.getDimSize(idx: 0);
455 VectorType sliceMaskType =
456 VectorType::get(shape: minTileSlices, elementType: rewriter.getI1Type(), scalableDims: true);
457
458 // Create loop over all tile slices.
459 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
460 auto upperBound = createVscaleMultiple(minTileSlices);
461 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
462 auto storeLoop =
463 rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: upperBound, args&: step);
464 rewriter.setInsertionPointToStart(storeLoop.getBody());
465
466 // For each sub-tile of the multi-tile `vectorType`.
467 auto inputSMETiles = adaptor.getValueToStore();
468 auto tileSliceIndex = storeLoop.getInductionVar();
469 for (auto [index, smeTile] : llvm::enumerate(
470 First: decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType))) {
471 // The coordinates of the tile within `vectorType`.
472 auto tileRow = createVscaleMultiple(smeTile.row);
473 auto tileCol = createVscaleMultiple(smeTile.col);
474
475 // The current slice of `vectorType` we are processing.
476 auto sliceIndex =
477 rewriter.create<arith::AddIOp>(location: loc, args&: tileRow, args&: tileSliceIndex);
478
479 // Where in the destination memref the current slice will be stored.
480 auto storeRow = rewriter.create<arith::AddIOp>(location: loc, args&: sliceIndex,
481 args: writeOp.getIndices()[0]);
482 auto storeCol =
483 rewriter.create<arith::AddIOp>(location: loc, args&: tileCol, args: writeOp.getIndices()[1]);
484
485 // Extract the mask for the current slice.
486 Value sliceMask = nullptr;
487 if (mask) {
488 sliceMask = rewriter.create<vector::ExtractOp>(
489 location: loc, args&: mask, args: OpFoldResult(sliceIndex));
490 if (sliceMaskType != sliceMask.getType())
491 sliceMask = rewriter.create<vector::ScalableExtractOp>(
492 location: loc, args&: sliceMaskType, args&: sliceMask, args&: smeTile.col);
493 }
494
495 // Extract and store the current slice.
496 Value tile = inputSMETiles[index];
497 auto slice =
498 rewriter.create<vector::ExtractOp>(location: loc, args&: tile, args&: tileSliceIndex);
499 rewriter.create<vector::TransferWriteOp>(
500 location: loc, args&: slice, args: writeOp.getBase(), args: ValueRange{storeRow, storeCol},
501 args: AffineMapAttr::get(value: writeOp.getPermutationMap().dropResult(pos: 0)),
502 args&: sliceMask,
503 args: rewriter.getBoolArrayAttr(
504 values: ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
505 }
506
507 rewriter.eraseOp(op: writeOp);
508 return success();
509 }
510};
511
//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//
515
516/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
517/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
518/// necessary for the mask to be lowered to ArmSME.
519///
520/// Example:
521///
522/// BEFORE:
523/// ```mlir
524/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
525/// %subMask = vector.extract %mask[2]
526/// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
527/// ```
528///
529/// AFTER:
530/// ```mlir
531/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
532/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
533/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
534/// ```
535struct FoldExtractFromVectorOfSMELikeCreateMasks
536 : public OpRewritePattern<vector::ExtractOp> {
537 using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
538
539 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
540 PatternRewriter &rewriter) const override {
541 auto loc = extractOp.getLoc();
542 auto createMaskOp =
543 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
544 if (!createMaskOp)
545 return rewriter.notifyMatchFailure(
546 arg&: extractOp, msg: "extract not from vector.create_mask op");
547
548 VectorType extractedMaskType =
549 llvm::dyn_cast<VectorType>(Val: extractOp.getResult().getType());
550 if (!extractedMaskType)
551 return rewriter.notifyMatchFailure(arg&: extractOp,
552 msg: "extracted type is not a vector type");
553
554 auto numScalable = extractedMaskType.getNumScalableDims();
555 if (numScalable != 2)
556 return rewriter.notifyMatchFailure(
557 arg&: extractOp, msg: "expected extracted type to be an SME-like mask");
558
559 // TODO: Support multiple extraction indices.
560 if (extractOp.getStaticPosition().size() != 1)
561 return rewriter.notifyMatchFailure(
562 arg&: extractOp, msg: "only a single extraction index is supported");
563
564 auto frontMaskDim = createMaskOp.getOperand(i: 0);
565 if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
566 return rewriter.notifyMatchFailure(
567 arg&: extractOp,
568 msg: "constant vector.create_masks dims should be folded elsewhere");
569
570 auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
571 auto extractionIndex = getValueOrCreateConstantIndexOp(
572 b&: rewriter, loc, ofr: extractOp.getMixedPosition()[0]);
573 auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
574 location: loc, args: rewriter.getI1Type(), args: arith::CmpIPredicate::slt, args&: extractionIndex,
575 args&: frontMaskDim);
576 auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
577 location: loc, args&: extractionInTrueRegion, args: createMaskOp.getOperand(i: 1), args&: zero);
578
579 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
580 op: extractOp, args&: extractedMaskType,
581 args: ValueRange{newMaskFrontDim, createMaskOp.getOperand(i: 2)});
582 return success();
583 }
584};
585
586/// A vector type where no fixed dimension comes after a scalable dimension.
587bool isLegalVectorType(VectorType vType) {
588 bool seenFixedDim = false;
589 for (bool scalableFlag : llvm::reverse(C: vType.getScalableDims())) {
590 seenFixedDim |= !scalableFlag;
591 if (seenFixedDim && scalableFlag)
592 return false;
593 }
594 return true;
595}
596
597/// Lifts an illegal vector.transpose and vector.transfer_read to a
598/// memref.subview + memref.transpose, followed by a legal read.
599///
600/// 'Illegal' here means a leading scalable dimension and a fixed trailing
601/// dimension, which has no valid lowering.
602///
603/// The memref.transpose is metadata-only transpose that produces a strided
604/// memref, which eventually becomes a loop reading individual elements.
605///
606/// Example:
607///
608/// BEFORE:
609/// ```mlir
610/// %illegalRead = vector.transfer_read %memref[%a, %b]
611/// : memref<?x?xf32>, vector<[8]x4xf32>
612/// %legalType = vector.transpose %illegalRead, [1, 0]
613/// : vector<[8]x4xf32> to vector<4x[8]xf32>
614/// ```
615///
616/// AFTER:
617/// ```mlir
618/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
619/// : memref<?x?xf32> to memref<?x?xf32>
620/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
621/// : memref<?x?xf32> to memref<?x?xf32>
622/// %legalType = vector.transfer_read %transpose[%c0, %c0]
623/// : memref<?x?xf32>, vector<4x[8]xf32>
624/// ```
625struct LiftIllegalVectorTransposeToMemory
626 : public OpRewritePattern<vector::TransposeOp> {
627 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
628
629 static Value getExtensionSource(Operation *op) {
630 if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(Val: op))
631 return op->getOperand(idx: 0);
632 return {};
633 }
634
635 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
636 PatternRewriter &rewriter) const override {
637 auto sourceType = transposeOp.getSourceVectorType();
638 auto resultType = transposeOp.getResultVectorType();
639 if (isLegalVectorType(vType: sourceType) || !isLegalVectorType(vType: resultType))
640 return rewriter.notifyMatchFailure(arg&: transposeOp,
641 msg: kMatchFailureNotIllegalToLegal);
642
643 // Look through extend for transfer_read.
644 Value maybeRead = transposeOp.getVector();
645 auto *transposeSourceOp = maybeRead.getDefiningOp();
646 Operation *extendOp = nullptr;
647 if (Value extendSource = getExtensionSource(op: transposeSourceOp)) {
648 maybeRead = extendSource;
649 extendOp = transposeSourceOp;
650 }
651
652 auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
653 if (!illegalRead)
654 return rewriter.notifyMatchFailure(
655 arg&: transposeOp,
656 msg: "expected source to be (possibly extended) transfer_read");
657
658 if (!illegalRead.getPermutationMap().isIdentity())
659 return rewriter.notifyMatchFailure(
660 arg&: illegalRead, msg: "expected read to have identity permutation map");
661
662 auto loc = transposeOp.getLoc();
663 auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
664 auto one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
665
666 // Create a subview that matches the size of the illegal read vector type.
667 auto readType = illegalRead.getVectorType();
668 auto readSizes = llvm::map_to_vector(
669 C: llvm::zip_equal(t: readType.getShape(), u: readType.getScalableDims()),
670 F: [&](auto dim) -> Value {
671 auto [size, isScalable] = dim;
672 auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
673 if (!isScalable)
674 return dimSize;
675 auto vscale = rewriter.create<vector::VectorScaleOp>(location: loc);
676 return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
677 });
678 SmallVector<Value> strides(readType.getRank(), Value(one));
679 auto readSubview = rewriter.create<memref::SubViewOp>(
680 location: loc, args: illegalRead.getBase(), args: illegalRead.getIndices(), args&: readSizes,
681 args&: strides);
682
683 // Apply the transpose to all values/attributes of the transfer_read:
684 // - The mask
685 Value mask = illegalRead.getMask();
686 if (mask) {
687 // Note: The transpose for the mask should fold into the
688 // vector.create_mask/constant_mask op, which will then become legal.
689 mask = rewriter.create<vector::TransposeOp>(location: loc, args&: mask,
690 args: transposeOp.getPermutation());
691 }
692 // - The source memref
693 mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
694 permutation: transposeOp.getPermutation(), context: getContext());
695 auto transposedSubview = rewriter.create<memref::TransposeOp>(
696 location: loc, args&: readSubview, args: AffineMapAttr::get(value: transposeMap));
697 ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
698 // - The `in_bounds` attribute
699 if (inBoundsAttr) {
700 SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
701 inBoundsAttr.end());
702 applyPermutationToVector(inVec&: inBoundsValues, permutation: transposeOp.getPermutation());
703 inBoundsAttr = rewriter.getArrayAttr(value: inBoundsValues);
704 }
705
706 VectorType legalReadType = resultType.clone(elementType: readType.getElementType());
707 // Note: The indices are all zero as the subview is already offset.
708 SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
709 auto legalRead = rewriter.create<vector::TransferReadOp>(
710 location: loc, args&: legalReadType, args&: transposedSubview, args&: readIndices,
711 args: illegalRead.getPermutationMapAttr(), args: illegalRead.getPadding(), args&: mask,
712 args&: inBoundsAttr);
713
714 // Replace the transpose with the new read, extending the result if
715 // necessary.
716 rewriter.replaceOp(op: transposeOp, newOp: [&]() -> Operation * {
717 if (extendOp)
718 return rewriter.create(loc, opName: extendOp->getName().getIdentifier(),
719 operands: Value(legalRead), types: resultType);
720 return legalRead;
721 }());
722
723 return success();
724 }
725};
726
727/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
728/// the ZA state. This workaround rewrite to support these transposes when ZA is
729/// available.
730///
731/// Example:
732///
733/// BEFORE:
734/// ```mlir
735/// %transpose = vector.transpose %vec, [1, 0]
736/// : vector<2x[4]xf32> to vector<[4]x2xf32>
737/// vector.transfer_write %transpose, %dest[%y, %x]
738/// : vector<[4]x2xf32>, memref<?x?xf32>
739/// ```
740///
741/// AFTER:
742/// ```mlir
743/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
744/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
745/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
746/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
747/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
748/// %c4_vscale = arith.muli %vscale, %c4 : index
749/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
750/// vector.transfer_write %4, %dest[%y, %x], %mask
751/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
752/// : vector<[4]x[4]xf32>, memref<?x?xf32>
753/// ```
754///
755/// Values larger than a single tile are supported via decomposition.
756struct LowerIllegalTransposeStoreViaZA
757 : public OpRewritePattern<vector::TransferWriteOp> {
758 using OpRewritePattern::OpRewritePattern;
759
760 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
761 PatternRewriter &rewriter) const override {
762 if (!isSupportedMaskOp(mask: writeOp.getMask()))
763 return rewriter.notifyMatchFailure(arg&: writeOp,
764 msg: kMatchFailureUnsupportedMaskOp);
765
766 auto permutationMap = writeOp.getPermutationMap();
767 if (!permutationMap.isIdentity())
768 return rewriter.notifyMatchFailure(arg&: writeOp,
769 msg: kMatchFailureNonPermutationMap);
770
771 auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
772 if (!transposeOp)
773 return failure();
774
775 auto sourceType = transposeOp.getSourceVectorType();
776 auto resultType = transposeOp.getResultVectorType();
777
778 if (resultType.getRank() != 2)
779 return rewriter.notifyMatchFailure(arg&: transposeOp, msg: "TransposeOp not rank 2");
780
781 if (!isLegalVectorType(vType: sourceType) || isLegalVectorType(vType: resultType))
782 return rewriter.notifyMatchFailure(
783 arg&: transposeOp, msg: "not illegal/unsupported SVE transpose");
784
785 auto smeTileType = getSMETileTypeForElement(elementType: resultType.getElementType());
786 VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(pos: 0);
787
788 if (sourceType.getDimSize(idx: 0) <= 1 ||
789 sourceType.getDimSize(idx: 1) % smeSliceType.getDimSize(idx: 0) != 0)
790 return rewriter.notifyMatchFailure(arg&: writeOp, msg: "unsupported source shape");
791
792 auto loc = writeOp.getLoc();
793 auto createVscaleMultiple =
794 vector::makeVscaleConstantBuilder(rewriter, loc);
795
796 auto transposeMap = AffineMapAttr::get(
797 value: AffineMap::getPermutationMap(permutation: ArrayRef<int64_t>{1, 0}, context: getContext()));
798
799 // Note: We need to use `get_tile` as there's no vector-level `undef`.
800 Value undefTile = rewriter.create<arm_sme::GetTileOp>(location: loc, args&: smeTileType);
801 Value destTensorOrMemref = writeOp.getBase();
802 auto numSlicesPerTile =
803 std::min(a: sourceType.getDimSize(idx: 0), b: smeTileType.getDimSize(idx: 0));
804 auto numSlices =
805 rewriter.create<arith::ConstantIndexOp>(location: loc, args&: numSlicesPerTile);
806 for (auto [index, smeTile] : llvm::enumerate(
807 First: decomposeToSMETiles(builder&: rewriter, type: sourceType, smeTileType))) {
808 // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
809 // of slices from the source type into the SME tile. Without checking
810 // vscale (and emitting multiple implementations) we can't make use of the
811 // rows of the tile after 1*vscale rows.
812 Value tile = undefTile;
813 for (int d = 0; d < numSlicesPerTile; ++d) {
814 Value vector = rewriter.create<vector::ExtractOp>(
815 location: loc, args: transposeOp.getVector(),
816 args: rewriter.getIndexAttr(value: d + smeTile.row));
817 if (vector.getType() != smeSliceType) {
818 vector = rewriter.create<vector::ScalableExtractOp>(
819 location: loc, args&: smeSliceType, args&: vector, args&: smeTile.col);
820 }
821 tile = rewriter.create<vector::InsertOp>(location: loc, args&: vector, args&: tile, args&: d);
822 }
823
824 // 2. Transpose the tile position.
825 auto transposedRow = createVscaleMultiple(smeTile.col);
826 auto transposedCol =
827 rewriter.create<arith::ConstantIndexOp>(location: loc, args&: smeTile.row);
828
829 // 3. Compute mask for tile store.
830 Value maskRows;
831 Value maskCols;
832 if (auto mask = writeOp.getMask()) {
833 auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
834 maskRows = rewriter.create<arith::SubIOp>(location: loc, args: createMask.getOperand(i: 0),
835 args&: transposedRow);
836 maskCols = rewriter.create<arith::SubIOp>(location: loc, args: createMask.getOperand(i: 1),
837 args&: transposedCol);
838 maskCols = rewriter.create<index::MinSOp>(location: loc, args&: maskCols, args&: numSlices);
839 } else {
840 maskRows = createVscaleMultiple(smeTileType.getDimSize(idx: 0));
841 maskCols = numSlices;
842 }
843 auto subMask = rewriter.create<vector::CreateMaskOp>(
844 location: loc, args: smeTileType.clone(elementType: rewriter.getI1Type()),
845 args: ValueRange{maskRows, maskCols});
846
847 // 4. Emit a transposed tile write.
848 auto writeIndices = writeOp.getIndices();
849 Value destRow =
850 rewriter.create<arith::AddIOp>(location: loc, args&: transposedRow, args: writeIndices[0]);
851 Value destCol =
852 rewriter.create<arith::AddIOp>(location: loc, args&: transposedCol, args: writeIndices[1]);
853 auto smeWrite = rewriter.create<vector::TransferWriteOp>(
854 location: loc, args&: tile, args&: destTensorOrMemref, args: ValueRange{destRow, destCol},
855 args&: transposeMap, args&: subMask, args: writeOp.getInBounds());
856
857 if (writeOp.hasPureTensorSemantics())
858 destTensorOrMemref = smeWrite.getResult();
859 }
860
861 if (writeOp.hasPureTensorSemantics())
862 rewriter.replaceOp(op: writeOp, newValues: destTensorOrMemref);
863 else
864 rewriter.eraseOp(op: writeOp);
865
866 return success();
867 }
868};
869
870/// Lower `vector.transfer_read` of a scalable column to `scf::for`
871///
872/// Lowers a "read" of a scalable column from a MemRef for which there is no
873/// hardware pperation that we could use to a loop over the rows to read and
874/// loads one element at a time.
875///
876/// BEFORE:
877/// ```
878/// %res = vector.transfer_read %mem[%a, %b] (...)
879/// : memref<?x?xf32>, vector<[4]x1xf32>
880/// ```
881///
882/// AFTER:
883/// ```
884/// %cst = arith.constant (...) : vector<[4]xf32>
885/// %vscale = vector.vscale
886/// %c4_vscale = arith.muli %vscale, %c4 : index
887/// %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
888/// -> (vector<[4]xf32>) {
889///
890/// %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
891/// %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
892/// scf.yield %vec : vector<[4]xf32>
893/// }
894/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
895/// ```
896///
897/// TODO: This transformation isn't specific to SME - move it to the SVE
898/// dialect.
899/// TODO: Check the in_bounds attribute and generate vector.maskedload if
900/// required.
901struct LowerColumnTransferReadToLoops
902 : public OpRewritePattern<vector::TransferReadOp> {
903 using OpRewritePattern::OpRewritePattern;
904
905 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
906 PatternRewriter &rewriter) const override {
907 // NOTE: This is a fairly low-level transformation, so we shouldn't be
908 // adding support for Tensors without good rationale.
909 if (readOp.hasPureTensorSemantics())
910 return rewriter.notifyMatchFailure(
911 arg&: readOp, msg: "Tensor semantics are unsupported (either bufferize or "
912 "extend this pattern)");
913
914 auto resType = readOp.getVectorType();
915
916 if (resType.getRank() != 2)
917 return rewriter.notifyMatchFailure(arg&: readOp,
918 msg: "Only 2D vectors are supported!");
919
920 if (resType.getShape()[1] != 1)
921 return rewriter.notifyMatchFailure(
922 arg&: readOp, msg: "The trailing output dim is != 1 (not supported ATM)");
923
924 if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
925 return rewriter.notifyMatchFailure(
926 arg&: readOp, msg: "Expected the leading dim to be scalable and the trailing "
927 "dim to be fixed.");
928
929 // Create new result type - similar to the original vector with the
930 // trailing unit dim collapsed.
931 int64_t numRows = resType.getShape()[0];
932 VectorType newResType = VectorType::get(shape: numRows, elementType: resType.getElementType(),
933 /*scalableDims=*/{true});
934
935 // Create a loop over all rows and load one element at a time.
936 auto loc = readOp.getLoc();
937 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
938 auto createVscaleMultiple =
939 vector::makeVscaleConstantBuilder(rewriter, loc);
940 auto upperBound = createVscaleMultiple(numRows);
941 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
942 Value init = rewriter.create<arith::ConstantOp>(
943 location: loc, args&: newResType, args: DenseElementsAttr::get(type: newResType, value: 0.0f));
944
945 scf::ForOp loadLoop;
946 {
947 OpBuilder::InsertionGuard g(rewriter);
948 loadLoop = rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: upperBound, args&: step,
949 args: ValueRange{init});
950 rewriter.setInsertionPointToStart(loadLoop.getBody());
951
952 auto tileSliceIndex = loadLoop.getInductionVar();
953
954 auto idx0 = rewriter.create<arith::AddIOp>(location: loc, args&: tileSliceIndex,
955 args: readOp.getIndices()[0]);
956 auto idx1 = readOp.getIndices()[1];
957
958 Value scalar = rewriter.create<memref::LoadOp>(
959 location: loc, args: readOp.getBase(), args: SmallVector<Value>({idx0, idx1}));
960
961 Operation *updateInit = rewriter.create<vector::InsertOp>(
962 location: loc, args&: scalar, args: loadLoop.getRegionIterArg(index: 0), args&: tileSliceIndex);
963
964 rewriter.create<scf::YieldOp>(location: loc, args: updateInit->getResult(idx: 0));
965 }
966
967 // The read operation has been "legalized", but since the original result
968 // type was a 2D vector, we need to cast before returning the result. This
969 // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
970 // no-op).
971 auto sc = rewriter.create<vector::ShapeCastOp>(
972 location: loc, args: readOp.getResult().getType(), args: loadLoop.getResult(i: 0));
973
974 rewriter.replaceOp(op: readOp, newOp: sc);
975
976 return success();
977 }
978};
979
980struct VectorLegalizationPass
981 : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
982 void runOnOperation() override {
983 auto *context = &getContext();
984 TypeConverter converter;
985 RewritePatternSet patterns(context);
986 converter.addConversion(callback: [](Type type) { return type; });
987 converter.addConversion(
988 callback: [](VectorType vectorType,
989 SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
990 if (!isMultipleOfSMETileVectorType(vType: vectorType))
991 return std::nullopt;
992 auto smeTileCount = getNumberOfSMETilesForVectorType(type: vectorType);
993 auto smeTileType =
994 getSMETileTypeForElement(elementType: vectorType.getElementType());
995 types = SmallVector<Type>(smeTileCount, smeTileType);
996 return success();
997 });
998
999 // Apply preprocessing patterns.
1000 RewritePatternSet rewritePatterns(context);
1001 rewritePatterns
1002 .add<FoldExtractFromVectorOfSMELikeCreateMasks,
1003 LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
1004 LowerIllegalTransposeStoreViaZA>(arg&: context);
1005 if (failed(
1006 Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(rewritePatterns))))
1007 return signalPassFailure();
1008
1009 // Note: These two patterns are added with a high benefit to ensure:
1010 // - Masked outer products are handled before unmasked ones
1011 // - Multi-tile writes are lowered as a store loop (if possible)
1012 patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
1013 LegalizeMultiTileTransferWriteAsStoreLoop>(arg&: converter, args&: context,
1014 /*benefit=*/args: 1024);
1015 patterns.add<LegalizeArithConstantOpsByDecomposition,
1016 LegalizeVectorOuterProductOpsByDecomposition,
1017 LegalizeTransferReadOpsByDecomposition,
1018 LegalizeTransferWriteOpsByDecomposition>(arg&: converter, args&: context);
1019 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1020 converter);
1021 populateCallOpTypeConversionPattern(patterns, converter);
1022 populateReturnOpTypeConversionPattern(patterns, converter);
1023 scf::populateSCFStructuralTypeConversions(typeConverter: converter, patterns);
1024
1025 ConversionTarget target(getContext());
1026 target.markUnknownOpDynamicallyLegal(
1027 fn: [&](Operation *op) { return converter.isLegal(op); });
1028 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
1029 return converter.isSignatureLegal(ty: op.getFunctionType());
1030 });
1031 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1032 patterns: std::move(patterns))))
1033 return signalPassFailure();
1034 }
1035};
1036
1037} // namespace
1038
1039std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
1040 return std::make_unique<VectorLegalizationPass>();
1041}
1042

// source code of mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp