1//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
10
11#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12#include "mlir/Dialect/ArmSME/Utils/Utils.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "llvm/Support/Casting.h"
16
17using namespace mlir;
18
19namespace {
20
21/// Conversion pattern for vector.transfer_read.
22///
23/// ---
24///
25/// Example 1: op with identity permutation map to horizontal
26/// arm_sme.tile_load:
27///
28/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
29///
30/// is converted to:
31///
32/// arm_sme.tile_load ...
33///
34/// ---
35///
36/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
37/// (in-flight transpose):
38///
39/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
40///
41/// is converted to:
42///
43/// arm_sme.tile_load ... layout<vertical>
44struct TransferReadToArmSMELowering
45 : public OpRewritePattern<vector::TransferReadOp> {
46 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
47
48 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
49 PatternRewriter &rewriter) const final {
50 // The permutation map must have two results.
51 if (transferReadOp.getTransferRank() != 2)
52 return rewriter.notifyMatchFailure(transferReadOp,
53 "not a 2 result permutation map");
54
55 auto vectorType = transferReadOp.getVectorType();
56 if (!arm_sme::isValidSMETileVectorType(vType: vectorType))
57 return rewriter.notifyMatchFailure(transferReadOp,
58 "not a valid vector type for SME");
59
60 if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
61 return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
62
63 // Out-of-bounds dims are not supported.
64 if (transferReadOp.hasOutOfBoundsDim())
65 return rewriter.notifyMatchFailure(transferReadOp,
66 "not inbounds transfer read");
67
68 arm_sme::TileSliceLayout layout;
69
70 AffineExpr d0, d1;
71 bindDims(transferReadOp.getContext(), d0, d1);
72 AffineMap map = transferReadOp.getPermutationMap();
73 if (map.isIdentity())
74 layout = arm_sme::TileSliceLayout::Horizontal;
75 else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
76 transferReadOp.getContext()))
77 layout = arm_sme::TileSliceLayout::Vertical;
78 else
79 return rewriter.notifyMatchFailure(transferReadOp,
80 "unsupported permutation map");
81
82 // Padding isn't optional for transfer_read, but is only used in the case
83 // of out-of-bounds accesses (not supported here) and/or masking. Mask is
84 // optional, if it's not present don't pass padding.
85 auto mask = transferReadOp.getMask();
86 auto padding = mask ? transferReadOp.getPadding() : nullptr;
87 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
88 transferReadOp, vectorType, transferReadOp.getSource(),
89 transferReadOp.getIndices(), padding, mask, layout);
90
91 return success();
92 }
93};
94
95/// Conversion pattern for vector.transfer_write.
96///
97/// ---
98///
99/// Example 1: op with identity permutation map to horizontal
100/// arm_sme.tile_store:
101///
102/// vector.transfer_write %vector, %source[%c0, %c0]
103/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
104///
105/// is converted to:
106///
107/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
108/// vector<[16]x[16]xi8>
109/// ---
110///
111/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
112/// (in-flight transpose):
113///
114/// vector.transfer_write %vector, %source[%c0, %c0]
115/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
116/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
117///
118/// is converted to:
119///
120/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
121/// : memref<?x?xi8>, vector<[16]x[16]xi8>
122struct TransferWriteToArmSMELowering
123 : public OpRewritePattern<vector::TransferWriteOp> {
124 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
125
126 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
127 PatternRewriter &rewriter) const final {
128 auto vType = writeOp.getVectorType();
129 if (!arm_sme::isValidSMETileVectorType(vType: vType))
130 return failure();
131
132 if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
133 return failure();
134
135 // Out-of-bounds dims are not supported.
136 if (writeOp.hasOutOfBoundsDim())
137 return rewriter.notifyMatchFailure(writeOp,
138 "not inbounds transfer write");
139
140 AffineExpr d0, d1;
141 bindDims(writeOp.getContext(), d0, d1);
142 AffineMap map = writeOp.getPermutationMap();
143 bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
144 writeOp.getContext()));
145
146 if (!map.isIdentity() && !isTranspose)
147 return rewriter.notifyMatchFailure(writeOp,
148 "unsupported permutation map");
149
150 arm_sme::TileSliceLayout layout =
151 isTranspose ? arm_sme::TileSliceLayout::Vertical
152 : arm_sme::TileSliceLayout::Horizontal;
153
154 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
155 writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
156 writeOp.getMask(), layout);
157 return success();
158 }
159};
160
161/// Conversion pattern for vector.load.
162struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
163 using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
164
165 LogicalResult matchAndRewrite(vector::LoadOp load,
166 PatternRewriter &rewriter) const override {
167 if (!arm_sme::isValidSMETileVectorType(vType: load.getVectorType()))
168 return failure();
169
170 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
171 load, load.getVectorType(), load.getBase(), load.getIndices());
172
173 return success();
174 }
175};
176
177/// Conversion pattern for vector.store.
178struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
179 using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
180
181 LogicalResult matchAndRewrite(vector::StoreOp store,
182 PatternRewriter &rewriter) const override {
183 if (!arm_sme::isValidSMETileVectorType(vType: store.getVectorType()))
184 return failure();
185
186 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
187 store, store.getValueToStore(), store.getBase(), store.getIndices());
188
189 return success();
190 }
191};
192
193/// Conversion pattern for vector.broadcast.
194///
195/// Example:
196///
197/// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
198///
199/// is converted to:
200///
201/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
202/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
203/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
204/// {
205/// %tile_update = arm_sme.move_vector_to_tile_slice
206/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
207/// vector<[4]xi32> into vector<[4]x[4]xi32>
208/// scf.yield %tile_update : vector<[4]x[4]xi32>
209/// }
210///
211/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
212struct BroadcastOpToArmSMELowering
213 : public OpRewritePattern<vector::BroadcastOp> {
214 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
215
216 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
217 PatternRewriter &rewriter) const final {
218 auto tileType = broadcastOp.getResultVectorType();
219 if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType))
220 return failure();
221
222 auto loc = broadcastOp.getLoc();
223
224 auto srcType = broadcastOp.getSourceType();
225 auto srcVectorType = dyn_cast<VectorType>(srcType);
226
227 Value broadcastOp1D;
228 if (srcType.isIntOrFloat() ||
229 (srcVectorType && (srcVectorType.getRank() == 0))) {
230 // Broadcast scalar or 0-d vector to 1-d vector.
231 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
232 broadcastOp1D = rewriter.create<vector::BroadcastOp>(
233 loc, tileSliceType, broadcastOp.getSource());
234 } else if (srcVectorType && (srcVectorType.getRank() == 1))
235 // Value to broadcast is already a 1-d vector, nothing to do.
236 broadcastOp1D = broadcastOp.getSource();
237 else
238 return failure();
239
240 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
241
242 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
243 Value currentTile) {
244 // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
245 // to each tile slice.
246 auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
247 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
248 return nextTile.getResult();
249 };
250
251 // Create a loop over ZA tile slices.
252 auto forOp =
253 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
254
255 rewriter.replaceOp(broadcastOp, forOp.getResult(0));
256
257 return success();
258 }
259};
260
261/// Conversion pattern for vector.splat.
262///
263/// Example:
264///
265/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
266///
267/// is converted to:
268///
269/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
270/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
271/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
272/// {
273/// %tile_update = arm_sme.move_vector_to_tile_slice
274/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
275/// vector<[4]xi32> into vector<[4]x[4]xi32>
276/// scf.yield %tile_update : vector<[4]x[4]xi32>
277/// }
278///
279/// This is identical to vector.broadcast of a scalar.
280struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
281 using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
282
283 LogicalResult matchAndRewrite(vector::SplatOp splatOp,
284 PatternRewriter &rewriter) const final {
285 auto tileType = splatOp.getResult().getType();
286 if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType))
287 return failure();
288
289 auto loc = splatOp.getLoc();
290 auto srcType = splatOp.getOperand().getType();
291
292 assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
293 // Avoid unused-variable warning when building without assertions.
294 (void)srcType;
295
296 // First, broadcast the scalar to a 1-d vector.
297 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
298 Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
299 loc, tileSliceType, splatOp.getInput());
300
301 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
302
303 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
304 Value currentTile) {
305 auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
306 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
307 return nextTile.getResult();
308 };
309
310 // Next, create a loop over ZA tile slices and "move" the generated 1-d
311 // vector to each slice.
312 auto forOp =
313 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
314
315 rewriter.replaceOp(splatOp, forOp.getResult(0));
316
317 return success();
318 }
319};
320
321/// Conversion pattern for vector.transpose.
322///
323/// Stores the input tile to memory and reloads vertically.
324///
325/// Example:
326///
327/// %transposed_src = vector.transpose %src, [1, 0]
328/// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
329///
330/// is converted to:
331///
332/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
333/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
334/// : memref<?x?xi32>, vector<[4]x[4]xi32>
335/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
336/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
337///
338/// NOTE: Tranposing via memory is obviously expensive, the current intention
339/// is to avoid the transpose if possible, this is therefore intended as a
340/// fallback and to provide base support for Vector ops. If it turns out
341/// transposes can't be avoided then this should be replaced with a more optimal
342/// implementation, perhaps with tile <-> vector (MOVA) ops.
343struct TransposeOpToArmSMELowering
344 : public OpRewritePattern<vector::TransposeOp> {
345 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
346
347 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
348 PatternRewriter &rewriter) const final {
349 auto tileType = transposeOp.getResultVectorType();
350 if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType))
351 return failure();
352
353 // Bail unless this is a true 2-D matrix transpose.
354 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
355 if (permutation[0] != 1 || permutation[1] != 0)
356 return failure();
357
358 auto loc = transposeOp.getLoc();
359
360 // Allocate buffer to store input tile to.
361 Value vscale =
362 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
363 Value minTileSlices = rewriter.create<arith::ConstantOp>(
364 loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
365 Value c0 =
366 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
367 Value numTileSlices =
368 rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
369 auto bufferType =
370 MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
371 tileType.getElementType());
372 auto buffer = rewriter.create<memref::AllocaOp>(
373 loc, bufferType, ValueRange{numTileSlices, numTileSlices});
374
375 Value input = transposeOp.getVector();
376
377 // Store input tile.
378 auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
379 loc, input, buffer, ValueRange{c0, c0});
380
381 // Reload input tile vertically.
382 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
383 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
384 arm_sme::TileSliceLayout::Vertical);
385
386 return success();
387 }
388};
389
390/// Conversion pattern for vector.outerproduct.
391///
392/// If the vector.outerproduct is masked (and the mask is from a
393/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
394/// operands.
395///
396/// Example:
397///
398/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
399/// %result = vector.mask %mask {
400/// vector.outerproduct %vecA, %vecB
401/// : vector<[4]xf32>, vector<[4]xf32>
402/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
403///
404/// is converted to:
405///
406/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
407/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
408/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
409/// : vector<[4]xf32>, vector<[4]xf32>
410///
411/// Unmasked outerproducts can be directly replaced with the arm_sme op.
412///
413/// Example:
414///
415/// %result = vector.outerproduct %vecA, %vecB
416/// : vector<[4]xf32>, vector<[4]xf32>
417///
418/// is converted to:
419///
420/// %result = arm_sme.outerproduct %vecA, %vecB
421/// : vector<[4]xf32>, vector<[4]xf32>
422///
423struct VectorOuterProductToArmSMELowering
424 : public OpRewritePattern<vector::OuterProductOp> {
425
426 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
427
428 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
429 PatternRewriter &rewriter) const override {
430
431 // We don't yet support lowering AXPY operations to SME. These could be
432 // lowered by masking out all but the first element of the LHS.
433 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
434 return rewriter.notifyMatchFailure(outerProductOp,
435 "AXPY operations not supported");
436
437 if (!arm_sme::isValidSMETileVectorType(
438 vType: outerProductOp.getResultVectorType()))
439 return rewriter.notifyMatchFailure(
440 outerProductOp, "outer product does not fit into SME tile");
441
442 auto kind = outerProductOp.getKind();
443 if (kind != vector::CombiningKind::ADD)
444 return rewriter.notifyMatchFailure(
445 outerProductOp,
446 "unsupported kind (lowering to SME only supports ADD at the moment)");
447
448 Value lhsMask = {};
449 Value rhsMask = {};
450 Operation *rootOp = outerProductOp;
451 auto loc = outerProductOp.getLoc();
452 if (outerProductOp.isMasked()) {
453 auto maskOp = outerProductOp.getMaskingOp();
454 rewriter.setInsertionPoint(maskOp);
455 rootOp = maskOp;
456 auto operandMasks = decomposeResultMask(loc: loc, mask: maskOp.getMask(), rewriter);
457 if (failed(operandMasks))
458 return failure();
459 std::tie(args&: lhsMask, args&: rhsMask) = *operandMasks;
460 }
461
462 rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
463 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
464 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
465
466 return success();
467 }
468
469 static FailureOr<std::pair<Value, Value>>
470 decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
471 // Attempt to extract masks from vector.create_mask.
472 // TODO: Add support for other mask sources.
473 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
474 if (!createMaskOp)
475 return failure();
476
477 auto maskType = createMaskOp.getVectorType();
478 Value lhsMaskDim = createMaskOp.getOperand(0);
479 Value rhsMaskDim = createMaskOp.getOperand(1);
480
481 VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
482 Value lhsMask =
483 rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
484 Value rhsMask =
485 rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
486
487 return std::make_pair(x&: lhsMask, y&: rhsMask);
488 }
489};
490
491/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
492///
493/// Example:
494/// ```
495/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
496/// ```
497/// Becomes:
498/// ```
499/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
500/// : vector<[4]xi32> from vector<[4]x[4]xi32>
501/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
502/// ```
503struct VectorExtractToArmSMELowering
504 : public OpRewritePattern<vector::ExtractOp> {
505 using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
506
507 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
508 PatternRewriter &rewriter) const override {
509 VectorType sourceType = extractOp.getSourceVectorType();
510 if (!arm_sme::isValidSMETileVectorType(vType: sourceType))
511 return failure();
512
513 auto loc = extractOp.getLoc();
514 auto position = extractOp.getMixedPosition();
515
516 Value sourceVector = extractOp.getVector();
517
518 // Extract entire vector. Should be handled by folder, but just to be safe.
519 if (position.empty()) {
520 rewriter.replaceOp(extractOp, sourceVector);
521 return success();
522 }
523
524 Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front();
525 auto moveTileSliceToVector =
526 rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
527 sliceIndex);
528
529 if (position.size() == 1) {
530 // Single index case: Extracts a 1D slice.
531 rewriter.replaceOp(extractOp, moveTileSliceToVector);
532 return success();
533 }
534
535 // Two indices case: Extracts a single element.
536 assert(position.size() == 2);
537 rewriter.replaceOpWithNewOp<vector::ExtractOp>(
538 extractOp, moveTileSliceToVector, position[1]);
539
540 return success();
541 }
542};
543
544/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
545/// `arm_sme.move_tile_slice_to_vector`.
546///
547/// Example:
548/// ```
549/// %new_tile = vector.insert %el, %tile[%row, %col]
550/// : i32 into vector<[4]x[4]xi32>
551/// ```
552/// Becomes:
553/// ```
554/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
555/// : vector<[4]xi32> from vector<[4]x[4]xi32>
556/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
557/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
558/// : vector<[4]xi32> into vector<[4]x[4]xi32>
559/// ```
560struct VectorInsertToArmSMELowering
561 : public OpRewritePattern<vector::InsertOp> {
562 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
563
564 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
565 PatternRewriter &rewriter) const override {
566 VectorType resultType = insertOp.getResult().getType();
567
568 if (!arm_sme::isValidSMETileVectorType(vType: resultType))
569 return failure();
570
571 auto loc = insertOp.getLoc();
572 auto position = insertOp.getMixedPosition();
573
574 Value source = insertOp.getSource();
575
576 // Overwrite entire vector with value. Should be handled by folder, but
577 // just to be safe.
578 if (position.empty()) {
579 rewriter.replaceOp(insertOp, source);
580 return success();
581 }
582
583 Value tileSlice = source;
584 Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front();
585 if (position.size() == 2) {
586 // Two indices case: Insert single element into tile.
587 // We need to first extract the existing slice and update the element.
588 tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
589 loc, insertOp.getDest(), sliceIndex);
590 tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
591 position[1]);
592 }
593
594 // Insert the slice into the destination tile.
595 rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
596 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
597 return success();
598 }
599};
600
601/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
602/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
603/// a 1D `vector.print`.
604///
605/// BEFORE:
606/// ```mlir
607/// vector.print %tile : vector<[4]x[4]xf32>
608/// ```
609/// AFTER:
610/// ```mlir
611/// %c0 = arith.constant 0 : index
612/// %c1 = arith.constant 1 : index
613/// %c4 = arith.constant 4 : index
614/// %vscale = vector.vscale
615/// %svl_s = arith.muli %c4, %vscale : index
616/// scf.for %i = %c0 to %svl_s step %c1 {
617/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
618/// : vector<[4]xf32> from vector<[4]x[4]xf32>
619/// vector.print %tile_slice : vector<[4]xf32>
620/// }
621/// ```
622struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
623 using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
624
625 LogicalResult matchAndRewrite(vector::PrintOp printOp,
626 PatternRewriter &rewriter) const override {
627 if (!printOp.getSource())
628 return failure();
629
630 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
631 if (!vectorType || !arm_sme::isValidSMETileVectorType(vType: vectorType))
632 return failure();
633
634 auto loc = printOp.getLoc();
635
636 // Create a loop over the rows of the tile.
637 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
638 auto minTileRows =
639 rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
640 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
641 auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
642 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
643 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
644 {
645 // Loop body.
646 rewriter.setInsertionPointToStart(forOp.getBody());
647 // Extract the current row from the tile.
648 Value rowIndex = forOp.getInductionVar();
649 auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
650 loc, printOp.getSource(), rowIndex);
651 // Print the row with a 1D vector.print.
652 rewriter.create<vector::PrintOp>(loc, tileSlice,
653 printOp.getPunctuation());
654 }
655
656 rewriter.eraseOp(op: printOp);
657 return success();
658 }
659};
660
661} // namespace
662
663void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
664 MLIRContext &ctx) {
665 patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
666 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
667 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
668 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
669 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
670 VectorPrintToArmSMELowering>(arg: &ctx);
671}
672

source code of mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp