1 | //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" |
10 | |
11 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
12 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "llvm/Support/Casting.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | namespace { |
20 | |
21 | /// Conversion pattern for vector.transfer_read. |
22 | /// |
23 | /// --- |
24 | /// |
25 | /// Example 1: op with identity permutation map to horizontal |
26 | /// arm_sme.tile_load: |
27 | /// |
28 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) |
29 | /// |
30 | /// is converted to: |
31 | /// |
32 | /// arm_sme.tile_load ... |
33 | /// |
34 | /// --- |
35 | /// |
36 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load |
37 | /// (in-flight transpose): |
38 | /// |
39 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) |
40 | /// |
41 | /// is converted to: |
42 | /// |
43 | /// arm_sme.tile_load ... layout<vertical> |
44 | struct TransferReadToArmSMELowering |
45 | : public OpRewritePattern<vector::TransferReadOp> { |
46 | using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
47 | |
48 | LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, |
49 | PatternRewriter &rewriter) const final { |
50 | // The permutation map must have two results. |
51 | if (transferReadOp.getTransferRank() != 2) |
52 | return rewriter.notifyMatchFailure(transferReadOp, |
53 | "not a 2 result permutation map" ); |
54 | |
55 | auto vectorType = transferReadOp.getVectorType(); |
56 | if (!arm_sme::isValidSMETileVectorType(vType: vectorType)) |
57 | return rewriter.notifyMatchFailure(transferReadOp, |
58 | "not a valid vector type for SME" ); |
59 | |
60 | if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType())) |
61 | return rewriter.notifyMatchFailure(transferReadOp, "not a memref source" ); |
62 | |
63 | // Out-of-bounds dims are not supported. |
64 | if (transferReadOp.hasOutOfBoundsDim()) |
65 | return rewriter.notifyMatchFailure(transferReadOp, |
66 | "not inbounds transfer read" ); |
67 | |
68 | arm_sme::TileSliceLayout layout; |
69 | |
70 | AffineExpr d0, d1; |
71 | bindDims(transferReadOp.getContext(), d0, d1); |
72 | AffineMap map = transferReadOp.getPermutationMap(); |
73 | if (map.isIdentity()) |
74 | layout = arm_sme::TileSliceLayout::Horizontal; |
75 | else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, |
76 | transferReadOp.getContext())) |
77 | layout = arm_sme::TileSliceLayout::Vertical; |
78 | else |
79 | return rewriter.notifyMatchFailure(transferReadOp, |
80 | "unsupported permutation map" ); |
81 | |
82 | // Padding isn't optional for transfer_read, but is only used in the case |
83 | // of out-of-bounds accesses (not supported here) and/or masking. Mask is |
84 | // optional, if it's not present don't pass padding. |
85 | auto mask = transferReadOp.getMask(); |
86 | auto padding = mask ? transferReadOp.getPadding() : nullptr; |
87 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
88 | transferReadOp, vectorType, transferReadOp.getSource(), |
89 | transferReadOp.getIndices(), padding, mask, layout); |
90 | |
91 | return success(); |
92 | } |
93 | }; |
94 | |
95 | /// Conversion pattern for vector.transfer_write. |
96 | /// |
97 | /// --- |
98 | /// |
99 | /// Example 1: op with identity permutation map to horizontal |
100 | /// arm_sme.tile_store: |
101 | /// |
102 | /// vector.transfer_write %vector, %source[%c0, %c0] |
103 | /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
104 | /// |
105 | /// is converted to: |
106 | /// |
107 | /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, |
108 | /// vector<[16]x[16]xi8> |
109 | /// --- |
110 | /// |
111 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store |
112 | /// (in-flight transpose): |
113 | /// |
114 | /// vector.transfer_write %vector, %source[%c0, %c0] |
115 | /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, |
116 | /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
117 | /// |
118 | /// is converted to: |
119 | /// |
120 | /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> |
121 | /// : memref<?x?xi8>, vector<[16]x[16]xi8> |
122 | struct TransferWriteToArmSMELowering |
123 | : public OpRewritePattern<vector::TransferWriteOp> { |
124 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
125 | |
126 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
127 | PatternRewriter &rewriter) const final { |
128 | auto vType = writeOp.getVectorType(); |
129 | if (!arm_sme::isValidSMETileVectorType(vType: vType)) |
130 | return failure(); |
131 | |
132 | if (!llvm::isa<MemRefType>(writeOp.getSource().getType())) |
133 | return failure(); |
134 | |
135 | // Out-of-bounds dims are not supported. |
136 | if (writeOp.hasOutOfBoundsDim()) |
137 | return rewriter.notifyMatchFailure(writeOp, |
138 | "not inbounds transfer write" ); |
139 | |
140 | AffineExpr d0, d1; |
141 | bindDims(writeOp.getContext(), d0, d1); |
142 | AffineMap map = writeOp.getPermutationMap(); |
143 | bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, |
144 | writeOp.getContext())); |
145 | |
146 | if (!map.isIdentity() && !isTranspose) |
147 | return rewriter.notifyMatchFailure(writeOp, |
148 | "unsupported permutation map" ); |
149 | |
150 | arm_sme::TileSliceLayout layout = |
151 | isTranspose ? arm_sme::TileSliceLayout::Vertical |
152 | : arm_sme::TileSliceLayout::Horizontal; |
153 | |
154 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
155 | writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), |
156 | writeOp.getMask(), layout); |
157 | return success(); |
158 | } |
159 | }; |
160 | |
161 | /// Conversion pattern for vector.load. |
162 | struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> { |
163 | using OpRewritePattern<vector::LoadOp>::OpRewritePattern; |
164 | |
165 | LogicalResult matchAndRewrite(vector::LoadOp load, |
166 | PatternRewriter &rewriter) const override { |
167 | if (!arm_sme::isValidSMETileVectorType(vType: load.getVectorType())) |
168 | return failure(); |
169 | |
170 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
171 | load, load.getVectorType(), load.getBase(), load.getIndices()); |
172 | |
173 | return success(); |
174 | } |
175 | }; |
176 | |
177 | /// Conversion pattern for vector.store. |
178 | struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> { |
179 | using OpRewritePattern<vector::StoreOp>::OpRewritePattern; |
180 | |
181 | LogicalResult matchAndRewrite(vector::StoreOp store, |
182 | PatternRewriter &rewriter) const override { |
183 | if (!arm_sme::isValidSMETileVectorType(vType: store.getVectorType())) |
184 | return failure(); |
185 | |
186 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
187 | store, store.getValueToStore(), store.getBase(), store.getIndices()); |
188 | |
189 | return success(); |
190 | } |
191 | }; |
192 | |
193 | /// Conversion pattern for vector.broadcast. |
194 | /// |
195 | /// Example: |
196 | /// |
197 | /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> |
198 | /// |
199 | /// is converted to: |
200 | /// |
201 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
202 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
203 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
204 | /// { |
205 | /// %tile_update = arm_sme.move_vector_to_tile_slice |
206 | /// %broadcast_to_1d, %iter_tile, %tile_slice_index : |
207 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
208 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
209 | /// } |
210 | /// |
211 | /// Supports scalar, 0-d vector, and 1-d vector broadcasts. |
212 | struct BroadcastOpToArmSMELowering |
213 | : public OpRewritePattern<vector::BroadcastOp> { |
214 | using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; |
215 | |
216 | LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, |
217 | PatternRewriter &rewriter) const final { |
218 | auto tileType = broadcastOp.getResultVectorType(); |
219 | if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType)) |
220 | return failure(); |
221 | |
222 | auto loc = broadcastOp.getLoc(); |
223 | |
224 | auto srcType = broadcastOp.getSourceType(); |
225 | auto srcVectorType = dyn_cast<VectorType>(srcType); |
226 | |
227 | Value broadcastOp1D; |
228 | if (srcType.isIntOrFloat() || |
229 | (srcVectorType && (srcVectorType.getRank() == 0))) { |
230 | // Broadcast scalar or 0-d vector to 1-d vector. |
231 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
232 | broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
233 | loc, tileSliceType, broadcastOp.getSource()); |
234 | } else if (srcVectorType && (srcVectorType.getRank() == 1)) |
235 | // Value to broadcast is already a 1-d vector, nothing to do. |
236 | broadcastOp1D = broadcastOp.getSource(); |
237 | else |
238 | return failure(); |
239 | |
240 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
241 | |
242 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
243 | Value currentTile) { |
244 | // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value |
245 | // to each tile slice. |
246 | auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>( |
247 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
248 | return nextTile.getResult(); |
249 | }; |
250 | |
251 | // Create a loop over ZA tile slices. |
252 | auto forOp = |
253 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
254 | |
255 | rewriter.replaceOp(broadcastOp, forOp.getResult(0)); |
256 | |
257 | return success(); |
258 | } |
259 | }; |
260 | |
261 | /// Conversion pattern for vector.splat. |
262 | /// |
263 | /// Example: |
264 | /// |
265 | /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> |
266 | /// |
267 | /// is converted to: |
268 | /// |
269 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
270 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
271 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
272 | /// { |
273 | /// %tile_update = arm_sme.move_vector_to_tile_slice |
274 | /// %broadcast_to_1d, %iter_tile, %tile_slice_index : |
275 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
276 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
277 | /// } |
278 | /// |
279 | /// This is identical to vector.broadcast of a scalar. |
280 | struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { |
281 | using OpRewritePattern<vector::SplatOp>::OpRewritePattern; |
282 | |
283 | LogicalResult matchAndRewrite(vector::SplatOp splatOp, |
284 | PatternRewriter &rewriter) const final { |
285 | auto tileType = splatOp.getResult().getType(); |
286 | if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType)) |
287 | return failure(); |
288 | |
289 | auto loc = splatOp.getLoc(); |
290 | auto srcType = splatOp.getOperand().getType(); |
291 | |
292 | assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat" ); |
293 | // Avoid unused-variable warning when building without assertions. |
294 | (void)srcType; |
295 | |
296 | // First, broadcast the scalar to a 1-d vector. |
297 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
298 | Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
299 | loc, tileSliceType, splatOp.getInput()); |
300 | |
301 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
302 | |
303 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
304 | Value currentTile) { |
305 | auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>( |
306 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
307 | return nextTile.getResult(); |
308 | }; |
309 | |
310 | // Next, create a loop over ZA tile slices and "move" the generated 1-d |
311 | // vector to each slice. |
312 | auto forOp = |
313 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
314 | |
315 | rewriter.replaceOp(splatOp, forOp.getResult(0)); |
316 | |
317 | return success(); |
318 | } |
319 | }; |
320 | |
321 | /// Conversion pattern for vector.transpose. |
322 | /// |
323 | /// Stores the input tile to memory and reloads vertically. |
324 | /// |
325 | /// Example: |
326 | /// |
327 | /// %transposed_src = vector.transpose %src, [1, 0] |
328 | /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> |
329 | /// |
330 | /// is converted to: |
331 | /// |
332 | /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32> |
333 | /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0] |
334 | /// : memref<?x?xi32>, vector<[4]x[4]xi32> |
335 | /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] |
336 | /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> |
337 | /// |
338 | /// NOTE: Tranposing via memory is obviously expensive, the current intention |
339 | /// is to avoid the transpose if possible, this is therefore intended as a |
340 | /// fallback and to provide base support for Vector ops. If it turns out |
341 | /// transposes can't be avoided then this should be replaced with a more optimal |
342 | /// implementation, perhaps with tile <-> vector (MOVA) ops. |
343 | struct TransposeOpToArmSMELowering |
344 | : public OpRewritePattern<vector::TransposeOp> { |
345 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
346 | |
347 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
348 | PatternRewriter &rewriter) const final { |
349 | auto tileType = transposeOp.getResultVectorType(); |
350 | if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType)) |
351 | return failure(); |
352 | |
353 | // Bail unless this is a true 2-D matrix transpose. |
354 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
355 | if (permutation[0] != 1 || permutation[1] != 0) |
356 | return failure(); |
357 | |
358 | auto loc = transposeOp.getLoc(); |
359 | |
360 | // Allocate buffer to store input tile to. |
361 | Value vscale = |
362 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
363 | Value minTileSlices = rewriter.create<arith::ConstantOp>( |
364 | loc, rewriter.getIndexAttr(tileType.getDimSize(0))); |
365 | Value c0 = |
366 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); |
367 | Value numTileSlices = |
368 | rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); |
369 | auto bufferType = |
370 | MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, |
371 | tileType.getElementType()); |
372 | auto buffer = rewriter.create<memref::AllocaOp>( |
373 | loc, bufferType, ValueRange{numTileSlices, numTileSlices}); |
374 | |
375 | Value input = transposeOp.getVector(); |
376 | |
377 | // Store input tile. |
378 | auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( |
379 | loc, input, buffer, ValueRange{c0, c0}); |
380 | |
381 | // Reload input tile vertically. |
382 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
383 | transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), |
384 | arm_sme::TileSliceLayout::Vertical); |
385 | |
386 | return success(); |
387 | } |
388 | }; |
389 | |
390 | /// Conversion pattern for vector.outerproduct. |
391 | /// |
392 | /// If the vector.outerproduct is masked (and the mask is from a |
393 | /// vector.create_mask), then the mask is decomposed into two 1-D masks for the |
394 | /// operands. |
395 | /// |
396 | /// Example: |
397 | /// |
398 | /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> |
399 | /// %result = vector.mask %mask { |
400 | /// vector.outerproduct %vecA, %vecB |
401 | /// : vector<[4]xf32>, vector<[4]xf32> |
402 | /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> |
403 | /// |
404 | /// is converted to: |
405 | /// |
406 | /// %maskA = vector.create_mask %dimA : vector<[4]xi1> |
407 | /// %maskB = vector.create_mask %dimB : vector<[4]xi1> |
408 | /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) |
409 | /// : vector<[4]xf32>, vector<[4]xf32> |
410 | /// |
411 | /// Unmasked outerproducts can be directly replaced with the arm_sme op. |
412 | /// |
413 | /// Example: |
414 | /// |
415 | /// %result = vector.outerproduct %vecA, %vecB |
416 | /// : vector<[4]xf32>, vector<[4]xf32> |
417 | /// |
418 | /// is converted to: |
419 | /// |
420 | /// %result = arm_sme.outerproduct %vecA, %vecB |
421 | /// : vector<[4]xf32>, vector<[4]xf32> |
422 | /// |
423 | struct VectorOuterProductToArmSMELowering |
424 | : public OpRewritePattern<vector::OuterProductOp> { |
425 | |
426 | using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; |
427 | |
428 | LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, |
429 | PatternRewriter &rewriter) const override { |
430 | |
431 | // We don't yet support lowering AXPY operations to SME. These could be |
432 | // lowered by masking out all but the first element of the LHS. |
433 | if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) |
434 | return rewriter.notifyMatchFailure(outerProductOp, |
435 | "AXPY operations not supported" ); |
436 | |
437 | if (!arm_sme::isValidSMETileVectorType( |
438 | vType: outerProductOp.getResultVectorType())) |
439 | return rewriter.notifyMatchFailure( |
440 | outerProductOp, "outer product does not fit into SME tile" ); |
441 | |
442 | auto kind = outerProductOp.getKind(); |
443 | if (kind != vector::CombiningKind::ADD) |
444 | return rewriter.notifyMatchFailure( |
445 | outerProductOp, |
446 | "unsupported kind (lowering to SME only supports ADD at the moment)" ); |
447 | |
448 | Value lhsMask = {}; |
449 | Value rhsMask = {}; |
450 | Operation *rootOp = outerProductOp; |
451 | auto loc = outerProductOp.getLoc(); |
452 | if (outerProductOp.isMasked()) { |
453 | auto maskOp = outerProductOp.getMaskingOp(); |
454 | rewriter.setInsertionPoint(maskOp); |
455 | rootOp = maskOp; |
456 | auto operandMasks = decomposeResultMask(loc: loc, mask: maskOp.getMask(), rewriter); |
457 | if (failed(operandMasks)) |
458 | return failure(); |
459 | std::tie(args&: lhsMask, args&: rhsMask) = *operandMasks; |
460 | } |
461 | |
462 | rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>( |
463 | rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), |
464 | outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); |
465 | |
466 | return success(); |
467 | } |
468 | |
469 | static FailureOr<std::pair<Value, Value>> |
470 | decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { |
471 | // Attempt to extract masks from vector.create_mask. |
472 | // TODO: Add support for other mask sources. |
473 | auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>(); |
474 | if (!createMaskOp) |
475 | return failure(); |
476 | |
477 | auto maskType = createMaskOp.getVectorType(); |
478 | Value lhsMaskDim = createMaskOp.getOperand(0); |
479 | Value rhsMaskDim = createMaskOp.getOperand(1); |
480 | |
481 | VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); |
482 | Value lhsMask = |
483 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim); |
484 | Value rhsMask = |
485 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim); |
486 | |
487 | return std::make_pair(x&: lhsMask, y&: rhsMask); |
488 | } |
489 | }; |
490 | |
491 | /// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`. |
492 | /// |
493 | /// Example: |
494 | /// ``` |
495 | /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> |
496 | /// ``` |
497 | /// Becomes: |
498 | /// ``` |
499 | /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row] |
500 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
501 | /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> |
502 | /// ``` |
503 | struct |
504 | : public OpRewritePattern<vector::ExtractOp> { |
505 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
506 | |
507 | LogicalResult matchAndRewrite(vector::ExtractOp , |
508 | PatternRewriter &rewriter) const override { |
509 | VectorType sourceType = extractOp.getSourceVectorType(); |
510 | if (!arm_sme::isValidSMETileVectorType(vType: sourceType)) |
511 | return failure(); |
512 | |
513 | auto loc = extractOp.getLoc(); |
514 | auto position = extractOp.getMixedPosition(); |
515 | |
516 | Value sourceVector = extractOp.getVector(); |
517 | |
518 | // Extract entire vector. Should be handled by folder, but just to be safe. |
519 | if (position.empty()) { |
520 | rewriter.replaceOp(extractOp, sourceVector); |
521 | return success(); |
522 | } |
523 | |
524 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
525 | auto moveTileSliceToVector = |
526 | rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector, |
527 | sliceIndex); |
528 | |
529 | if (position.size() == 1) { |
530 | // Single index case: Extracts a 1D slice. |
531 | rewriter.replaceOp(extractOp, moveTileSliceToVector); |
532 | return success(); |
533 | } |
534 | |
535 | // Two indices case: Extracts a single element. |
536 | assert(position.size() == 2); |
537 | rewriter.replaceOpWithNewOp<vector::ExtractOp>( |
538 | extractOp, moveTileSliceToVector, position[1]); |
539 | |
540 | return success(); |
541 | } |
542 | }; |
543 | |
544 | /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and |
545 | /// `arm_sme.move_tile_slice_to_vector`. |
546 | /// |
547 | /// Example: |
548 | /// ``` |
549 | /// %new_tile = vector.insert %el, %tile[%row, %col] |
550 | /// : i32 into vector<[4]x[4]xi32> |
551 | /// ``` |
552 | /// Becomes: |
553 | /// ``` |
554 | /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row] |
555 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
556 | /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> |
557 | /// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row |
558 | /// : vector<[4]xi32> into vector<[4]x[4]xi32> |
559 | /// ``` |
560 | struct VectorInsertToArmSMELowering |
561 | : public OpRewritePattern<vector::InsertOp> { |
562 | using OpRewritePattern<vector::InsertOp>::OpRewritePattern; |
563 | |
564 | LogicalResult matchAndRewrite(vector::InsertOp insertOp, |
565 | PatternRewriter &rewriter) const override { |
566 | VectorType resultType = insertOp.getResult().getType(); |
567 | |
568 | if (!arm_sme::isValidSMETileVectorType(vType: resultType)) |
569 | return failure(); |
570 | |
571 | auto loc = insertOp.getLoc(); |
572 | auto position = insertOp.getMixedPosition(); |
573 | |
574 | Value source = insertOp.getSource(); |
575 | |
576 | // Overwrite entire vector with value. Should be handled by folder, but |
577 | // just to be safe. |
578 | if (position.empty()) { |
579 | rewriter.replaceOp(insertOp, source); |
580 | return success(); |
581 | } |
582 | |
583 | Value tileSlice = source; |
584 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
585 | if (position.size() == 2) { |
586 | // Two indices case: Insert single element into tile. |
587 | // We need to first extract the existing slice and update the element. |
588 | tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>( |
589 | loc, insertOp.getDest(), sliceIndex); |
590 | tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice, |
591 | position[1]); |
592 | } |
593 | |
594 | // Insert the slice into the destination tile. |
595 | rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>( |
596 | insertOp, tileSlice, insertOp.getDest(), sliceIndex); |
597 | return success(); |
598 | } |
599 | }; |
600 | |
601 | /// Lowers `vector.print` of a tile into a loop over the rows of the tile, |
602 | /// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with |
603 | /// a 1D `vector.print`. |
604 | /// |
605 | /// BEFORE: |
606 | /// ```mlir |
607 | /// vector.print %tile : vector<[4]x[4]xf32> |
608 | /// ``` |
609 | /// AFTER: |
610 | /// ```mlir |
611 | /// %c0 = arith.constant 0 : index |
612 | /// %c1 = arith.constant 1 : index |
613 | /// %c4 = arith.constant 4 : index |
614 | /// %vscale = vector.vscale |
615 | /// %svl_s = arith.muli %c4, %vscale : index |
616 | /// scf.for %i = %c0 to %svl_s step %c1 { |
617 | /// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i] |
618 | /// : vector<[4]xf32> from vector<[4]x[4]xf32> |
619 | /// vector.print %tile_slice : vector<[4]xf32> |
620 | /// } |
621 | /// ``` |
622 | struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> { |
623 | using OpRewritePattern<vector::PrintOp>::OpRewritePattern; |
624 | |
625 | LogicalResult matchAndRewrite(vector::PrintOp printOp, |
626 | PatternRewriter &rewriter) const override { |
627 | if (!printOp.getSource()) |
628 | return failure(); |
629 | |
630 | VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); |
631 | if (!vectorType || !arm_sme::isValidSMETileVectorType(vType: vectorType)) |
632 | return failure(); |
633 | |
634 | auto loc = printOp.getLoc(); |
635 | |
636 | // Create a loop over the rows of the tile. |
637 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
638 | auto minTileRows = |
639 | rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0)); |
640 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
641 | auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale); |
642 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
643 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
644 | { |
645 | // Loop body. |
646 | rewriter.setInsertionPointToStart(forOp.getBody()); |
647 | // Extract the current row from the tile. |
648 | Value rowIndex = forOp.getInductionVar(); |
649 | auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>( |
650 | loc, printOp.getSource(), rowIndex); |
651 | // Print the row with a 1D vector.print. |
652 | rewriter.create<vector::PrintOp>(loc, tileSlice, |
653 | printOp.getPunctuation()); |
654 | } |
655 | |
656 | rewriter.eraseOp(op: printOp); |
657 | return success(); |
658 | } |
659 | }; |
660 | |
661 | } // namespace |
662 | |
663 | void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, |
664 | MLIRContext &ctx) { |
665 | patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, |
666 | TransferReadToArmSMELowering, TransferWriteToArmSMELowering, |
667 | TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, |
668 | VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, |
669 | VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, |
670 | VectorPrintToArmSMELowering>(arg: &ctx); |
671 | } |
672 | |