1//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
10
11#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12#include "mlir/Dialect/ArmSME/Utils/Utils.h"
13#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "llvm/Support/Casting.h"
17
18using namespace mlir;
19
20namespace {
21
22/// Conversion pattern for vector.transfer_read.
23///
24/// ---
25///
26/// Example 1: op with identity permutation map to horizontal
27/// arm_sme.tile_load:
28///
29/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
30///
31/// is converted to:
32///
33/// arm_sme.tile_load ...
34///
35/// ---
36///
37/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38/// (in-flight transpose):
39///
40/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
41///
42/// is converted to:
43///
44/// arm_sme.tile_load ... layout<vertical>
45struct TransferReadToArmSMELowering
46 : public OpRewritePattern<vector::TransferReadOp> {
47 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
48
49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50 PatternRewriter &rewriter) const final {
51 // The permutation map must have two results.
52 if (transferReadOp.getTransferRank() != 2)
53 return rewriter.notifyMatchFailure(transferReadOp,
54 "not a 2 result permutation map");
55
56 auto vectorType = transferReadOp.getVectorType();
57 if (!arm_sme::isValidSMETileVectorType(vectorType))
58 return rewriter.notifyMatchFailure(transferReadOp,
59 "not a valid vector type for SME");
60
61 if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
62 return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
63
64 // Out-of-bounds dims are not supported.
65 if (transferReadOp.hasOutOfBoundsDim())
66 return rewriter.notifyMatchFailure(transferReadOp,
67 "not inbounds transfer read");
68
69 AffineMap map = transferReadOp.getPermutationMap();
70 if (!map.isPermutation())
71 return rewriter.notifyMatchFailure(transferReadOp,
72 "unsupported permutation map");
73
74 // Note: For 2D vector types the only non-identity permutation is a simple
75 // transpose [1, 0].
76 bool transposed = !map.isIdentity();
77 arm_sme::TileSliceLayout layout =
78 transposed ? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal;
80
81 // Padding isn't optional for transfer_read, but is only used in the case
82 // of out-of-bounds accesses (not supported here) and/or masking. Mask is
83 // optional, if it's not present don't pass padding.
84 auto mask = transferReadOp.getMask();
85 auto padding = mask ? transferReadOp.getPadding() : nullptr;
86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87 transferReadOp, vectorType, transferReadOp.getBase(),
88 transferReadOp.getIndices(), padding, mask, layout);
89
90 return success();
91 }
92};
93
94/// Conversion pattern for vector.transfer_write.
95///
96/// ---
97///
98/// Example 1: op with identity permutation map to horizontal
99/// arm_sme.tile_store:
100///
101/// vector.transfer_write %vector, %source[%c0, %c0]
102/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103///
104/// is converted to:
105///
106/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107/// vector<[16]x[16]xi8>
108/// ---
109///
110/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111/// (in-flight transpose):
112///
113/// vector.transfer_write %vector, %source[%c0, %c0]
114/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
116///
117/// is converted to:
118///
119/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120/// : memref<?x?xi8>, vector<[16]x[16]xi8>
121struct TransferWriteToArmSMELowering
122 : public OpRewritePattern<vector::TransferWriteOp> {
123 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
124
125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126 PatternRewriter &rewriter) const final {
127 auto vType = writeOp.getVectorType();
128 if (!arm_sme::isValidSMETileVectorType(vType))
129 return failure();
130
131 if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
132 return failure();
133
134 // Out-of-bounds dims are not supported.
135 if (writeOp.hasOutOfBoundsDim())
136 return rewriter.notifyMatchFailure(writeOp,
137 "not inbounds transfer write");
138
139 AffineMap map = writeOp.getPermutationMap();
140 if (!map.isPermutation())
141 return rewriter.notifyMatchFailure(writeOp,
142 "unsupported permutation map");
143
144 // Note: For 2D vector types the only non-identity permutation is a simple
145 // transpose [1, 0].
146 bool transposed = !map.isIdentity();
147 arm_sme::TileSliceLayout layout =
148 transposed ? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal;
150
151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152 writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
153 writeOp.getMask(), layout);
154 return success();
155 }
156};
157
158/// Conversion pattern for vector.load.
159struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
160 using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
161
162 LogicalResult matchAndRewrite(vector::LoadOp load,
163 PatternRewriter &rewriter) const override {
164 if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165 return failure();
166
167 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168 load, load.getVectorType(), load.getBase(), load.getIndices());
169
170 return success();
171 }
172};
173
174/// Conversion pattern for vector.store.
175struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
176 using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
177
178 LogicalResult matchAndRewrite(vector::StoreOp store,
179 PatternRewriter &rewriter) const override {
180 if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181 return failure();
182
183 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184 store, store.getValueToStore(), store.getBase(), store.getIndices());
185
186 return success();
187 }
188};
189
190/// Conversion pattern for vector.broadcast.
191///
192/// Example:
193///
194/// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
195///
196/// is converted to:
197///
198/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201/// {
202/// %tile_update = arm_sme.insert_tile_slice
203/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204/// vector<[4]xi32> into vector<[4]x[4]xi32>
205/// scf.yield %tile_update : vector<[4]x[4]xi32>
206/// }
207///
208/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209struct BroadcastOpToArmSMELowering
210 : public OpRewritePattern<vector::BroadcastOp> {
211 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
212
213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214 PatternRewriter &rewriter) const final {
215 auto tileType = broadcastOp.getResultVectorType();
216 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
217 return failure();
218
219 auto loc = broadcastOp.getLoc();
220
221 auto srcType = broadcastOp.getSourceType();
222 auto srcVectorType = dyn_cast<VectorType>(srcType);
223
224 Value broadcastOp1D;
225 if (srcType.isIntOrFloat() ||
226 (srcVectorType && (srcVectorType.getRank() == 0))) {
227 // Broadcast scalar or 0-d vector to 1-d vector.
228 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229 broadcastOp1D = rewriter.create<vector::BroadcastOp>(
230 loc, tileSliceType, broadcastOp.getSource());
231 } else if (srcVectorType && (srcVectorType.getRank() == 1))
232 // Value to broadcast is already a 1-d vector, nothing to do.
233 broadcastOp1D = broadcastOp.getSource();
234 else
235 return failure();
236
237 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
238
239 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
240 Value currentTile) {
241 // Create 'arm_sme.insert_tile_slice' to broadcast the value
242 // to each tile slice.
243 auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
244 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245 return nextTile.getResult();
246 };
247
248 // Create a loop over ZA tile slices.
249 auto forOp =
250 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
251
252 rewriter.replaceOp(broadcastOp, forOp.getResult(0));
253
254 return success();
255 }
256};
257
258/// Conversion pattern for vector.splat.
259///
260/// Example:
261///
262/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263///
264/// is converted to:
265///
266/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269/// {
270/// %tile_update = arm_sme.insert_tile_slice
271/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272/// vector<[4]xi32> into vector<[4]x[4]xi32>
273/// scf.yield %tile_update : vector<[4]x[4]xi32>
274/// }
275///
276/// This is identical to vector.broadcast of a scalar.
277struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
278 using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(vector::SplatOp splatOp,
281 PatternRewriter &rewriter) const final {
282 auto tileType = splatOp.getResult().getType();
283 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
284 return failure();
285
286 auto loc = splatOp.getLoc();
287 auto srcType = splatOp.getOperand().getType();
288
289 assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290 // Avoid unused-variable warning when building without assertions.
291 (void)srcType;
292
293 // First, broadcast the scalar to a 1-d vector.
294 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
295 Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
296 loc, tileSliceType, splatOp.getInput());
297
298 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
299
300 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301 Value currentTile) {
302 auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
303 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304 return nextTile.getResult();
305 };
306
307 // Next, create a loop over ZA tile slices and "move" the generated 1-d
308 // vector to each slice.
309 auto forOp =
310 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
311
312 rewriter.replaceOp(splatOp, forOp.getResult(0));
313
314 return success();
315 }
316};
317
318/// Conversion pattern for vector.transpose.
319///
320/// Stores the input tile to memory and reloads vertically.
321///
322/// Example:
323///
324/// %transposed_src = vector.transpose %src, [1, 0]
325/// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
326///
327/// is converted to:
328///
329/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
330/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
331/// : memref<?x?xi32>, vector<[4]x[4]xi32>
332/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
333/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
334///
335/// NOTE: Transposing via memory is obviously expensive, the current intention
336/// is to avoid the transpose if possible, this is therefore intended as a
337/// fallback and to provide base support for Vector ops. If it turns out
338/// transposes can't be avoided then this should be replaced with a more optimal
339/// implementation, perhaps with tile <-> vector (MOVA) ops.
340struct TransposeOpToArmSMELowering
341 : public OpRewritePattern<vector::TransposeOp> {
342 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
343
344 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
345 PatternRewriter &rewriter) const final {
346 auto tileType = transposeOp.getResultVectorType();
347 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
348 return failure();
349
350 // Bail unless this is a true 2-D matrix transpose.
351 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
352 if (permutation[0] != 1 || permutation[1] != 0)
353 return failure();
354
355 auto loc = transposeOp.getLoc();
356 Value input = transposeOp.getVector();
357
358 if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
359 xferOp && xferOp->hasOneUse()) {
360 // Fold transpose into transfer_read to enable in-flight transpose when
361 // converting to arm_sme.tile_load.
362 rewriter.modifyOpInPlace(xferOp, [&]() {
363 xferOp->setAttr(xferOp.getPermutationMapAttrName(),
364 AffineMapAttr::get(AffineMap::getPermutationMap(
365 permutation, transposeOp.getContext())));
366 });
367 rewriter.replaceOp(transposeOp, xferOp);
368 return success();
369 }
370
371 // Allocate buffer to store input tile to.
372 Value vscale =
373 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
374 Value minTileSlices = rewriter.create<arith::ConstantOp>(
375 loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
376 Value c0 =
377 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
378 Value numTileSlices =
379 rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
380 auto bufferType =
381 MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
382 tileType.getElementType());
383 auto buffer = rewriter.create<memref::AllocaOp>(
384 loc, bufferType, ValueRange{numTileSlices, numTileSlices});
385
386 // Store input tile.
387 auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
388 loc, input, buffer, ValueRange{c0, c0});
389
390 // Reload input tile vertically.
391 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
392 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
393 arm_sme::TileSliceLayout::Vertical);
394
395 return success();
396 }
397};
398
399/// Conversion pattern for vector.outerproduct.
400///
401/// If the vector.outerproduct is masked (and the mask is from a
402/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
403/// operands.
404///
405/// Example:
406///
407/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
408/// %result = vector.mask %mask {
409/// vector.outerproduct %vecA, %vecB
410/// : vector<[4]xf32>, vector<[4]xf32>
411/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
412///
413/// is converted to:
414///
415/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
416/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
417/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
418/// : vector<[4]xf32>, vector<[4]xf32>
419///
420/// Unmasked outerproducts can be directly replaced with the arm_sme op.
421///
422/// Example:
423///
424/// %result = vector.outerproduct %vecA, %vecB
425/// : vector<[4]xf32>, vector<[4]xf32>
426///
427/// is converted to:
428///
429/// %result = arm_sme.outerproduct %vecA, %vecB
430/// : vector<[4]xf32>, vector<[4]xf32>
431///
432struct VectorOuterProductToArmSMELowering
433 : public OpRewritePattern<vector::OuterProductOp> {
434
435 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
436
437 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
438 PatternRewriter &rewriter) const override {
439
440 // We don't yet support lowering AXPY operations to SME. These could be
441 // lowered by masking out all but the first element of the LHS.
442 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
443 return rewriter.notifyMatchFailure(outerProductOp,
444 "AXPY operations not supported");
445
446 if (!arm_sme::isValidSMETileVectorType(
447 outerProductOp.getResultVectorType()))
448 return rewriter.notifyMatchFailure(
449 outerProductOp, "outer product does not fit into SME tile");
450
451 auto kind = outerProductOp.getKind();
452 if (kind != vector::CombiningKind::ADD)
453 return rewriter.notifyMatchFailure(
454 outerProductOp,
455 "unsupported kind (lowering to SME only supports ADD at the moment)");
456
457 Value lhsMask = {};
458 Value rhsMask = {};
459 Operation *rootOp = outerProductOp;
460 auto loc = outerProductOp.getLoc();
461 if (outerProductOp.isMasked()) {
462 auto maskOp = outerProductOp.getMaskingOp();
463 rewriter.setInsertionPoint(maskOp);
464 rootOp = maskOp;
465 auto operandMasks = decomposeResultMask(loc: loc, mask: maskOp.getMask(), rewriter);
466 if (failed(operandMasks))
467 return failure();
468 std::tie(args&: lhsMask, args&: rhsMask) = *operandMasks;
469 }
470
471 rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
472 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
473 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
474
475 return success();
476 }
477
478 static FailureOr<std::pair<Value, Value>>
479 decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
480 // Attempt to extract masks from vector.create_mask.
481 // TODO: Add support for other mask sources.
482 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
483 if (!createMaskOp)
484 return failure();
485
486 auto maskType = createMaskOp.getVectorType();
487 Value lhsMaskDim = createMaskOp.getOperand(0);
488 Value rhsMaskDim = createMaskOp.getOperand(1);
489
490 VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
491 Value lhsMask =
492 rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
493 Value rhsMask =
494 rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
495
496 return std::make_pair(x&: lhsMask, y&: rhsMask);
497 }
498};
499
500/// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
501///
502/// Example:
503/// ```
504/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
505/// ```
506/// Becomes:
507/// ```
508/// %slice = arm_sme.extract_tile_slice %tile[%row]
509/// : vector<[4]xi32> from vector<[4]x[4]xi32>
510/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
511/// ```
512struct VectorExtractToArmSMELowering
513 : public OpRewritePattern<vector::ExtractOp> {
514 using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
515
516 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
517 PatternRewriter &rewriter) const override {
518 VectorType sourceType = extractOp.getSourceVectorType();
519 if (!arm_sme::isValidSMETileVectorType(sourceType))
520 return failure();
521
522 auto loc = extractOp.getLoc();
523 auto position = extractOp.getMixedPosition();
524
525 Value sourceVector = extractOp.getVector();
526
527 // Extract entire vector. Should be handled by folder, but just to be safe.
528 if (position.empty()) {
529 rewriter.replaceOp(extractOp, sourceVector);
530 return success();
531 }
532
533 Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front();
534 auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
535 loc, sourceVector, sliceIndex);
536
537 if (position.size() == 1) {
538 // Single index case: Extracts a 1D slice.
539 rewriter.replaceOp(extractOp, extractTileSlice);
540 return success();
541 }
542
543 // Two indices case: Extracts a single element.
544 assert(position.size() == 2);
545 rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
546 position[1]);
547
548 return success();
549 }
550};
551
552/// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
553/// `arm_sme.extract_tile_slice`.
554///
555/// Example:
556/// ```
557/// %new_tile = vector.insert %el, %tile[%row, %col]
558/// : i32 into vector<[4]x[4]xi32>
559/// ```
560/// Becomes:
561/// ```
562/// %slice = arm_sme.extract_tile_slice %tile[%row]
563/// : vector<[4]xi32> from vector<[4]x[4]xi32>
564/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
566/// : vector<[4]xi32> into vector<[4]x[4]xi32>
567/// ```
568struct VectorInsertToArmSMELowering
569 : public OpRewritePattern<vector::InsertOp> {
570 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
571
572 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
573 PatternRewriter &rewriter) const override {
574 VectorType resultType = insertOp.getResult().getType();
575
576 if (!arm_sme::isValidSMETileVectorType(resultType))
577 return failure();
578
579 auto loc = insertOp.getLoc();
580 auto position = insertOp.getMixedPosition();
581
582 Value source = insertOp.getValueToStore();
583
584 // Overwrite entire vector with value. Should be handled by folder, but
585 // just to be safe.
586 if (position.empty()) {
587 rewriter.replaceOp(insertOp, source);
588 return success();
589 }
590
591 Value tileSlice = source;
592 Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front();
593 if (position.size() == 2) {
594 // Two indices case: Insert single element into tile.
595 // We need to first extract the existing slice and update the element.
596 tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
597 loc, insertOp.getDest(), sliceIndex);
598 tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
599 position[1]);
600 }
601
602 // Insert the slice into the destination tile.
603 rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
604 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
605 return success();
606 }
607};
608
609/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
610/// extracting them via `arm_sme.extract_tile_slice`, then printing with
611/// a 1D `vector.print`.
612///
613/// BEFORE:
614/// ```mlir
615/// vector.print %tile : vector<[4]x[4]xf32>
616/// ```
617/// AFTER:
618/// ```mlir
619/// %c0 = arith.constant 0 : index
620/// %c1 = arith.constant 1 : index
621/// %c4 = arith.constant 4 : index
622/// %vscale = vector.vscale
623/// %svl_s = arith.muli %c4, %vscale : index
624/// scf.for %i = %c0 to %svl_s step %c1 {
625/// %tile_slice = arm_sme.extract_tile_slice %tile[%i]
626/// : vector<[4]xf32> from vector<[4]x[4]xf32>
627/// vector.print %tile_slice : vector<[4]xf32>
628/// }
629/// ```
630struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
631 using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
632
633 LogicalResult matchAndRewrite(vector::PrintOp printOp,
634 PatternRewriter &rewriter) const override {
635 if (!printOp.getSource())
636 return failure();
637
638 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
639 if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
640 return failure();
641
642 auto loc = printOp.getLoc();
643
644 // Create a loop over the rows of the tile.
645 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
646 auto minTileRows =
647 rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
648 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
649 auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
650 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
651 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
652 {
653 // Loop body.
654 rewriter.setInsertionPointToStart(forOp.getBody());
655 // Extract the current row from the tile.
656 Value rowIndex = forOp.getInductionVar();
657 auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
658 loc, printOp.getSource(), rowIndex);
659 // Print the row with a 1D vector.print.
660 rewriter.create<vector::PrintOp>(loc, tileSlice,
661 printOp.getPunctuation());
662 }
663
664 rewriter.eraseOp(op: printOp);
665 return success();
666 }
667};
668
669/// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
670///
671/// BEFORE:
672/// ```mlir
673/// %slice = arm_sme.extract_tile_slice %tile[%index]
674/// : vector<[4]xf32> from vector<[4]x[4]xf32>
675/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676/// : vector<[4]xf32>, memref<?x?xf32>
677/// ```
678/// AFTER:
679/// ```mlir
680/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681/// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682/// ```
683struct FoldTransferWriteOfExtractTileSlice
684 : public OpRewritePattern<vector::TransferWriteOp> {
685 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
686
687 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
688 PatternRewriter &rewriter) const final {
689 if (!isa<MemRefType>(writeOp.getBase().getType()))
690 return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
691
692 if (writeOp.hasOutOfBoundsDim())
693 return rewriter.notifyMatchFailure(writeOp,
694 "not inbounds transfer write");
695
696 auto extractTileSlice =
697 writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
698 if (!extractTileSlice)
699 return rewriter.notifyMatchFailure(
700 writeOp, "vector to store not from ExtractTileSliceOp");
701
702 AffineMap map = writeOp.getPermutationMap();
703 if (!map.isMinorIdentity())
704 return rewriter.notifyMatchFailure(writeOp,
705 "unsupported permutation map");
706
707 Value mask = writeOp.getMask();
708 if (!mask) {
709 auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
710 mask = rewriter.create<arith::ConstantOp>(
711 writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
712 }
713
714 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
715 writeOp, extractTileSlice.getTile(),
716 extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
717 writeOp.getIndices(), extractTileSlice.getLayout());
718 return success();
719 }
720};
721
722/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
723/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
724/// SVE 2.1), so this is currently the most logical place for this lowering.
725///
726/// Example:
727/// ```mlir
728/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
729/// %slice = vector.extract %mask[%index]
730/// : vector<[8]xi1> from vector<[4]x[8]xi1>
731/// ```
732/// Becomes:
733/// ```
734/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
735/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
736/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
737/// : vector<[8]xi1>, vector<[4]xi1>
738/// ```
739struct ExtractFromCreateMaskToPselLowering
740 : public OpRewritePattern<vector::ExtractOp> {
741 using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
742
743 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
744 PatternRewriter &rewriter) const override {
745 if (extractOp.getNumIndices() != 1)
746 return rewriter.notifyMatchFailure(extractOp, "not single extract index");
747
748 auto resultType = extractOp.getResult().getType();
749 auto resultVectorType = dyn_cast<VectorType>(resultType);
750 if (!resultVectorType)
751 return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
752
753 auto createMaskOp =
754 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
755 if (!createMaskOp)
756 return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
757
758 auto maskType = createMaskOp.getVectorType();
759 if (maskType.getRank() != 2 || !maskType.allDimsScalable())
760 return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
761
762 auto isSVEPredicateSize = [](int64_t size) {
763 return size > 0 && size <= 16 && llvm::isPowerOf2_32(Value: uint32_t(size));
764 };
765
766 auto rowsBaseSize = maskType.getDimSize(0);
767 auto colsBaseSize = maskType.getDimSize(1);
768 if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
769 return rewriter.notifyMatchFailure(
770 createMaskOp, "mask dimensions not SVE predicate-sized");
771
772 auto loc = extractOp.getLoc();
773 VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
774 VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
775
776 // Create the two 1-D masks at the location of the 2-D create_mask (which is
777 // usually outside a loop). This prevents the need for later hoisting.
778 rewriter.setInsertionPoint(createMaskOp);
779 auto rowMask = rewriter.create<vector::CreateMaskOp>(
780 loc, rowMaskType, createMaskOp.getOperand(0));
781 auto colMask = rewriter.create<vector::CreateMaskOp>(
782 loc, colMaskType, createMaskOp.getOperand(1));
783
784 rewriter.setInsertionPoint(extractOp);
785 auto position =
786 vector::getAsValues(builder&: rewriter, loc: loc, foldResults: extractOp.getMixedPosition());
787 rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
788 position[0]);
789 return success();
790 }
791};
792
793} // namespace
794
795void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
796 MLIRContext &ctx) {
797 patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
798 TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
799 TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
800 VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
801 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
802 VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
803 ExtractFromCreateMaskToPselLowering>(arg: &ctx);
804}
805

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp