1 | //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" |
10 | |
11 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
12 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
13 | #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" |
14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
15 | #include "mlir/IR/BuiltinTypes.h" |
16 | #include "llvm/Support/Casting.h" |
17 | |
18 | using namespace mlir; |
19 | |
20 | namespace { |
21 | |
22 | /// Conversion pattern for vector.transfer_read. |
23 | /// |
24 | /// --- |
25 | /// |
26 | /// Example 1: op with identity permutation map to horizontal |
27 | /// arm_sme.tile_load: |
28 | /// |
29 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) |
30 | /// |
31 | /// is converted to: |
32 | /// |
33 | /// arm_sme.tile_load ... |
34 | /// |
35 | /// --- |
36 | /// |
37 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load |
38 | /// (in-flight transpose): |
39 | /// |
40 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) |
41 | /// |
42 | /// is converted to: |
43 | /// |
44 | /// arm_sme.tile_load ... layout<vertical> |
45 | struct TransferReadToArmSMELowering |
46 | : public OpRewritePattern<vector::TransferReadOp> { |
47 | using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
48 | |
49 | LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, |
50 | PatternRewriter &rewriter) const final { |
51 | // The permutation map must have two results. |
52 | if (transferReadOp.getTransferRank() != 2) |
53 | return rewriter.notifyMatchFailure(transferReadOp, |
54 | "not a 2 result permutation map"); |
55 | |
56 | auto vectorType = transferReadOp.getVectorType(); |
57 | if (!arm_sme::isValidSMETileVectorType(vectorType)) |
58 | return rewriter.notifyMatchFailure(transferReadOp, |
59 | "not a valid vector type for SME"); |
60 | |
61 | if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType())) |
62 | return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); |
63 | |
64 | // Out-of-bounds dims are not supported. |
65 | if (transferReadOp.hasOutOfBoundsDim()) |
66 | return rewriter.notifyMatchFailure(transferReadOp, |
67 | "not inbounds transfer read"); |
68 | |
69 | AffineMap map = transferReadOp.getPermutationMap(); |
70 | if (!map.isPermutation()) |
71 | return rewriter.notifyMatchFailure(transferReadOp, |
72 | "unsupported permutation map"); |
73 | |
74 | // Note: For 2D vector types the only non-identity permutation is a simple |
75 | // transpose [1, 0]. |
76 | bool transposed = !map.isIdentity(); |
77 | arm_sme::TileSliceLayout layout = |
78 | transposed ? arm_sme::TileSliceLayout::Vertical |
79 | : arm_sme::TileSliceLayout::Horizontal; |
80 | |
81 | // Padding isn't optional for transfer_read, but is only used in the case |
82 | // of out-of-bounds accesses (not supported here) and/or masking. Mask is |
83 | // optional, if it's not present don't pass padding. |
84 | auto mask = transferReadOp.getMask(); |
85 | auto padding = mask ? transferReadOp.getPadding() : nullptr; |
86 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
87 | transferReadOp, vectorType, transferReadOp.getBase(), |
88 | transferReadOp.getIndices(), padding, mask, layout); |
89 | |
90 | return success(); |
91 | } |
92 | }; |
93 | |
94 | /// Conversion pattern for vector.transfer_write. |
95 | /// |
96 | /// --- |
97 | /// |
98 | /// Example 1: op with identity permutation map to horizontal |
99 | /// arm_sme.tile_store: |
100 | /// |
101 | /// vector.transfer_write %vector, %source[%c0, %c0] |
102 | /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
103 | /// |
104 | /// is converted to: |
105 | /// |
106 | /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, |
107 | /// vector<[16]x[16]xi8> |
108 | /// --- |
109 | /// |
110 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store |
111 | /// (in-flight transpose): |
112 | /// |
113 | /// vector.transfer_write %vector, %source[%c0, %c0] |
114 | /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, |
115 | /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
116 | /// |
117 | /// is converted to: |
118 | /// |
119 | /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> |
120 | /// : memref<?x?xi8>, vector<[16]x[16]xi8> |
121 | struct TransferWriteToArmSMELowering |
122 | : public OpRewritePattern<vector::TransferWriteOp> { |
123 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
124 | |
125 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
126 | PatternRewriter &rewriter) const final { |
127 | auto vType = writeOp.getVectorType(); |
128 | if (!arm_sme::isValidSMETileVectorType(vType)) |
129 | return failure(); |
130 | |
131 | if (!llvm::isa<MemRefType>(writeOp.getBase().getType())) |
132 | return failure(); |
133 | |
134 | // Out-of-bounds dims are not supported. |
135 | if (writeOp.hasOutOfBoundsDim()) |
136 | return rewriter.notifyMatchFailure(writeOp, |
137 | "not inbounds transfer write"); |
138 | |
139 | AffineMap map = writeOp.getPermutationMap(); |
140 | if (!map.isPermutation()) |
141 | return rewriter.notifyMatchFailure(writeOp, |
142 | "unsupported permutation map"); |
143 | |
144 | // Note: For 2D vector types the only non-identity permutation is a simple |
145 | // transpose [1, 0]. |
146 | bool transposed = !map.isIdentity(); |
147 | arm_sme::TileSliceLayout layout = |
148 | transposed ? arm_sme::TileSliceLayout::Vertical |
149 | : arm_sme::TileSliceLayout::Horizontal; |
150 | |
151 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
152 | writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(), |
153 | writeOp.getMask(), layout); |
154 | return success(); |
155 | } |
156 | }; |
157 | |
158 | /// Conversion pattern for vector.load. |
159 | struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> { |
160 | using OpRewritePattern<vector::LoadOp>::OpRewritePattern; |
161 | |
162 | LogicalResult matchAndRewrite(vector::LoadOp load, |
163 | PatternRewriter &rewriter) const override { |
164 | if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) |
165 | return failure(); |
166 | |
167 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
168 | load, load.getVectorType(), load.getBase(), load.getIndices()); |
169 | |
170 | return success(); |
171 | } |
172 | }; |
173 | |
174 | /// Conversion pattern for vector.store. |
175 | struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> { |
176 | using OpRewritePattern<vector::StoreOp>::OpRewritePattern; |
177 | |
178 | LogicalResult matchAndRewrite(vector::StoreOp store, |
179 | PatternRewriter &rewriter) const override { |
180 | if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) |
181 | return failure(); |
182 | |
183 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
184 | store, store.getValueToStore(), store.getBase(), store.getIndices()); |
185 | |
186 | return success(); |
187 | } |
188 | }; |
189 | |
190 | /// Conversion pattern for vector.broadcast. |
191 | /// |
192 | /// Example: |
193 | /// |
194 | /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> |
195 | /// |
196 | /// is converted to: |
197 | /// |
198 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
199 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
200 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
201 | /// { |
202 | /// %tile_update = arm_sme.insert_tile_slice |
203 | /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : |
204 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
205 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
206 | /// } |
207 | /// |
208 | /// Supports scalar, 0-d vector, and 1-d vector broadcasts. |
209 | struct BroadcastOpToArmSMELowering |
210 | : public OpRewritePattern<vector::BroadcastOp> { |
211 | using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; |
212 | |
213 | LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, |
214 | PatternRewriter &rewriter) const final { |
215 | auto tileType = broadcastOp.getResultVectorType(); |
216 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
217 | return failure(); |
218 | |
219 | auto loc = broadcastOp.getLoc(); |
220 | |
221 | auto srcType = broadcastOp.getSourceType(); |
222 | auto srcVectorType = dyn_cast<VectorType>(srcType); |
223 | |
224 | Value broadcastOp1D; |
225 | if (srcType.isIntOrFloat() || |
226 | (srcVectorType && (srcVectorType.getRank() == 0))) { |
227 | // Broadcast scalar or 0-d vector to 1-d vector. |
228 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
229 | broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
230 | loc, tileSliceType, broadcastOp.getSource()); |
231 | } else if (srcVectorType && (srcVectorType.getRank() == 1)) |
232 | // Value to broadcast is already a 1-d vector, nothing to do. |
233 | broadcastOp1D = broadcastOp.getSource(); |
234 | else |
235 | return failure(); |
236 | |
237 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
238 | |
239 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
240 | Value currentTile) { |
241 | // Create 'arm_sme.insert_tile_slice' to broadcast the value |
242 | // to each tile slice. |
243 | auto nextTile = b.create<arm_sme::InsertTileSliceOp>( |
244 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
245 | return nextTile.getResult(); |
246 | }; |
247 | |
248 | // Create a loop over ZA tile slices. |
249 | auto forOp = |
250 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
251 | |
252 | rewriter.replaceOp(broadcastOp, forOp.getResult(0)); |
253 | |
254 | return success(); |
255 | } |
256 | }; |
257 | |
258 | /// Conversion pattern for vector.splat. |
259 | /// |
260 | /// Example: |
261 | /// |
262 | /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> |
263 | /// |
264 | /// is converted to: |
265 | /// |
266 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
267 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
268 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
269 | /// { |
270 | /// %tile_update = arm_sme.insert_tile_slice |
271 | /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : |
272 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
273 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
274 | /// } |
275 | /// |
276 | /// This is identical to vector.broadcast of a scalar. |
277 | struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { |
278 | using OpRewritePattern<vector::SplatOp>::OpRewritePattern; |
279 | |
280 | LogicalResult matchAndRewrite(vector::SplatOp splatOp, |
281 | PatternRewriter &rewriter) const final { |
282 | auto tileType = splatOp.getResult().getType(); |
283 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
284 | return failure(); |
285 | |
286 | auto loc = splatOp.getLoc(); |
287 | auto srcType = splatOp.getOperand().getType(); |
288 | |
289 | assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); |
290 | // Avoid unused-variable warning when building without assertions. |
291 | (void)srcType; |
292 | |
293 | // First, broadcast the scalar to a 1-d vector. |
294 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
295 | Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
296 | loc, tileSliceType, splatOp.getInput()); |
297 | |
298 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
299 | |
300 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
301 | Value currentTile) { |
302 | auto nextTile = b.create<arm_sme::InsertTileSliceOp>( |
303 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
304 | return nextTile.getResult(); |
305 | }; |
306 | |
307 | // Next, create a loop over ZA tile slices and "move" the generated 1-d |
308 | // vector to each slice. |
309 | auto forOp = |
310 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
311 | |
312 | rewriter.replaceOp(splatOp, forOp.getResult(0)); |
313 | |
314 | return success(); |
315 | } |
316 | }; |
317 | |
318 | /// Conversion pattern for vector.transpose. |
319 | /// |
320 | /// Stores the input tile to memory and reloads vertically. |
321 | /// |
322 | /// Example: |
323 | /// |
324 | /// %transposed_src = vector.transpose %src, [1, 0] |
325 | /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> |
326 | /// |
327 | /// is converted to: |
328 | /// |
329 | /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32> |
330 | /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0] |
331 | /// : memref<?x?xi32>, vector<[4]x[4]xi32> |
332 | /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] |
333 | /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> |
334 | /// |
335 | /// NOTE: Transposing via memory is obviously expensive, the current intention |
336 | /// is to avoid the transpose if possible, this is therefore intended as a |
337 | /// fallback and to provide base support for Vector ops. If it turns out |
338 | /// transposes can't be avoided then this should be replaced with a more optimal |
339 | /// implementation, perhaps with tile <-> vector (MOVA) ops. |
340 | struct TransposeOpToArmSMELowering |
341 | : public OpRewritePattern<vector::TransposeOp> { |
342 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
343 | |
344 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
345 | PatternRewriter &rewriter) const final { |
346 | auto tileType = transposeOp.getResultVectorType(); |
347 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
348 | return failure(); |
349 | |
350 | // Bail unless this is a true 2-D matrix transpose. |
351 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
352 | if (permutation[0] != 1 || permutation[1] != 0) |
353 | return failure(); |
354 | |
355 | auto loc = transposeOp.getLoc(); |
356 | Value input = transposeOp.getVector(); |
357 | |
358 | if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>(); |
359 | xferOp && xferOp->hasOneUse()) { |
360 | // Fold transpose into transfer_read to enable in-flight transpose when |
361 | // converting to arm_sme.tile_load. |
362 | rewriter.modifyOpInPlace(xferOp, [&]() { |
363 | xferOp->setAttr(xferOp.getPermutationMapAttrName(), |
364 | AffineMapAttr::get(AffineMap::getPermutationMap( |
365 | permutation, transposeOp.getContext()))); |
366 | }); |
367 | rewriter.replaceOp(transposeOp, xferOp); |
368 | return success(); |
369 | } |
370 | |
371 | // Allocate buffer to store input tile to. |
372 | Value vscale = |
373 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
374 | Value minTileSlices = rewriter.create<arith::ConstantOp>( |
375 | loc, rewriter.getIndexAttr(tileType.getDimSize(0))); |
376 | Value c0 = |
377 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); |
378 | Value numTileSlices = |
379 | rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); |
380 | auto bufferType = |
381 | MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, |
382 | tileType.getElementType()); |
383 | auto buffer = rewriter.create<memref::AllocaOp>( |
384 | loc, bufferType, ValueRange{numTileSlices, numTileSlices}); |
385 | |
386 | // Store input tile. |
387 | auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( |
388 | loc, input, buffer, ValueRange{c0, c0}); |
389 | |
390 | // Reload input tile vertically. |
391 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
392 | transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), |
393 | arm_sme::TileSliceLayout::Vertical); |
394 | |
395 | return success(); |
396 | } |
397 | }; |
398 | |
399 | /// Conversion pattern for vector.outerproduct. |
400 | /// |
401 | /// If the vector.outerproduct is masked (and the mask is from a |
402 | /// vector.create_mask), then the mask is decomposed into two 1-D masks for the |
403 | /// operands. |
404 | /// |
405 | /// Example: |
406 | /// |
407 | /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> |
408 | /// %result = vector.mask %mask { |
409 | /// vector.outerproduct %vecA, %vecB |
410 | /// : vector<[4]xf32>, vector<[4]xf32> |
411 | /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> |
412 | /// |
413 | /// is converted to: |
414 | /// |
415 | /// %maskA = vector.create_mask %dimA : vector<[4]xi1> |
416 | /// %maskB = vector.create_mask %dimB : vector<[4]xi1> |
417 | /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) |
418 | /// : vector<[4]xf32>, vector<[4]xf32> |
419 | /// |
420 | /// Unmasked outerproducts can be directly replaced with the arm_sme op. |
421 | /// |
422 | /// Example: |
423 | /// |
424 | /// %result = vector.outerproduct %vecA, %vecB |
425 | /// : vector<[4]xf32>, vector<[4]xf32> |
426 | /// |
427 | /// is converted to: |
428 | /// |
429 | /// %result = arm_sme.outerproduct %vecA, %vecB |
430 | /// : vector<[4]xf32>, vector<[4]xf32> |
431 | /// |
432 | struct VectorOuterProductToArmSMELowering |
433 | : public OpRewritePattern<vector::OuterProductOp> { |
434 | |
435 | using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; |
436 | |
437 | LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, |
438 | PatternRewriter &rewriter) const override { |
439 | |
440 | // We don't yet support lowering AXPY operations to SME. These could be |
441 | // lowered by masking out all but the first element of the LHS. |
442 | if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) |
443 | return rewriter.notifyMatchFailure(outerProductOp, |
444 | "AXPY operations not supported"); |
445 | |
446 | if (!arm_sme::isValidSMETileVectorType( |
447 | outerProductOp.getResultVectorType())) |
448 | return rewriter.notifyMatchFailure( |
449 | outerProductOp, "outer product does not fit into SME tile"); |
450 | |
451 | auto kind = outerProductOp.getKind(); |
452 | if (kind != vector::CombiningKind::ADD) |
453 | return rewriter.notifyMatchFailure( |
454 | outerProductOp, |
455 | "unsupported kind (lowering to SME only supports ADD at the moment)"); |
456 | |
457 | Value lhsMask = {}; |
458 | Value rhsMask = {}; |
459 | Operation *rootOp = outerProductOp; |
460 | auto loc = outerProductOp.getLoc(); |
461 | if (outerProductOp.isMasked()) { |
462 | auto maskOp = outerProductOp.getMaskingOp(); |
463 | rewriter.setInsertionPoint(maskOp); |
464 | rootOp = maskOp; |
465 | auto operandMasks = decomposeResultMask(loc: loc, mask: maskOp.getMask(), rewriter); |
466 | if (failed(operandMasks)) |
467 | return failure(); |
468 | std::tie(args&: lhsMask, args&: rhsMask) = *operandMasks; |
469 | } |
470 | |
471 | rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>( |
472 | rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), |
473 | outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); |
474 | |
475 | return success(); |
476 | } |
477 | |
478 | static FailureOr<std::pair<Value, Value>> |
479 | decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { |
480 | // Attempt to extract masks from vector.create_mask. |
481 | // TODO: Add support for other mask sources. |
482 | auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>(); |
483 | if (!createMaskOp) |
484 | return failure(); |
485 | |
486 | auto maskType = createMaskOp.getVectorType(); |
487 | Value lhsMaskDim = createMaskOp.getOperand(0); |
488 | Value rhsMaskDim = createMaskOp.getOperand(1); |
489 | |
490 | VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); |
491 | Value lhsMask = |
492 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim); |
493 | Value rhsMask = |
494 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim); |
495 | |
496 | return std::make_pair(x&: lhsMask, y&: rhsMask); |
497 | } |
498 | }; |
499 | |
500 | /// Lower `vector.extract` using `arm_sme.extract_tile_slice`. |
501 | /// |
502 | /// Example: |
503 | /// ``` |
504 | /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> |
505 | /// ``` |
506 | /// Becomes: |
507 | /// ``` |
508 | /// %slice = arm_sme.extract_tile_slice %tile[%row] |
509 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
510 | /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> |
511 | /// ``` |
512 | struct VectorExtractToArmSMELowering |
513 | : public OpRewritePattern<vector::ExtractOp> { |
514 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
515 | |
516 | LogicalResult matchAndRewrite(vector::ExtractOp extractOp, |
517 | PatternRewriter &rewriter) const override { |
518 | VectorType sourceType = extractOp.getSourceVectorType(); |
519 | if (!arm_sme::isValidSMETileVectorType(sourceType)) |
520 | return failure(); |
521 | |
522 | auto loc = extractOp.getLoc(); |
523 | auto position = extractOp.getMixedPosition(); |
524 | |
525 | Value sourceVector = extractOp.getVector(); |
526 | |
527 | // Extract entire vector. Should be handled by folder, but just to be safe. |
528 | if (position.empty()) { |
529 | rewriter.replaceOp(extractOp, sourceVector); |
530 | return success(); |
531 | } |
532 | |
533 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
534 | auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( |
535 | loc, sourceVector, sliceIndex); |
536 | |
537 | if (position.size() == 1) { |
538 | // Single index case: Extracts a 1D slice. |
539 | rewriter.replaceOp(extractOp, extractTileSlice); |
540 | return success(); |
541 | } |
542 | |
543 | // Two indices case: Extracts a single element. |
544 | assert(position.size() == 2); |
545 | rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice, |
546 | position[1]); |
547 | |
548 | return success(); |
549 | } |
550 | }; |
551 | |
552 | /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and |
553 | /// `arm_sme.extract_tile_slice`. |
554 | /// |
555 | /// Example: |
556 | /// ``` |
557 | /// %new_tile = vector.insert %el, %tile[%row, %col] |
558 | /// : i32 into vector<[4]x[4]xi32> |
559 | /// ``` |
560 | /// Becomes: |
561 | /// ``` |
562 | /// %slice = arm_sme.extract_tile_slice %tile[%row] |
563 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
564 | /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> |
565 | /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row] |
566 | /// : vector<[4]xi32> into vector<[4]x[4]xi32> |
567 | /// ``` |
568 | struct VectorInsertToArmSMELowering |
569 | : public OpRewritePattern<vector::InsertOp> { |
570 | using OpRewritePattern<vector::InsertOp>::OpRewritePattern; |
571 | |
572 | LogicalResult matchAndRewrite(vector::InsertOp insertOp, |
573 | PatternRewriter &rewriter) const override { |
574 | VectorType resultType = insertOp.getResult().getType(); |
575 | |
576 | if (!arm_sme::isValidSMETileVectorType(resultType)) |
577 | return failure(); |
578 | |
579 | auto loc = insertOp.getLoc(); |
580 | auto position = insertOp.getMixedPosition(); |
581 | |
582 | Value source = insertOp.getValueToStore(); |
583 | |
584 | // Overwrite entire vector with value. Should be handled by folder, but |
585 | // just to be safe. |
586 | if (position.empty()) { |
587 | rewriter.replaceOp(insertOp, source); |
588 | return success(); |
589 | } |
590 | |
591 | Value tileSlice = source; |
592 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
593 | if (position.size() == 2) { |
594 | // Two indices case: Insert single element into tile. |
595 | // We need to first extract the existing slice and update the element. |
596 | tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( |
597 | loc, insertOp.getDest(), sliceIndex); |
598 | tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice, |
599 | position[1]); |
600 | } |
601 | |
602 | // Insert the slice into the destination tile. |
603 | rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>( |
604 | insertOp, tileSlice, insertOp.getDest(), sliceIndex); |
605 | return success(); |
606 | } |
607 | }; |
608 | |
609 | /// Lowers `vector.print` of a tile into a loop over the rows of the tile, |
610 | /// extracting them via `arm_sme.extract_tile_slice`, then printing with |
611 | /// a 1D `vector.print`. |
612 | /// |
613 | /// BEFORE: |
614 | /// ```mlir |
615 | /// vector.print %tile : vector<[4]x[4]xf32> |
616 | /// ``` |
617 | /// AFTER: |
618 | /// ```mlir |
619 | /// %c0 = arith.constant 0 : index |
620 | /// %c1 = arith.constant 1 : index |
621 | /// %c4 = arith.constant 4 : index |
622 | /// %vscale = vector.vscale |
623 | /// %svl_s = arith.muli %c4, %vscale : index |
624 | /// scf.for %i = %c0 to %svl_s step %c1 { |
625 | /// %tile_slice = arm_sme.extract_tile_slice %tile[%i] |
626 | /// : vector<[4]xf32> from vector<[4]x[4]xf32> |
627 | /// vector.print %tile_slice : vector<[4]xf32> |
628 | /// } |
629 | /// ``` |
630 | struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> { |
631 | using OpRewritePattern<vector::PrintOp>::OpRewritePattern; |
632 | |
633 | LogicalResult matchAndRewrite(vector::PrintOp printOp, |
634 | PatternRewriter &rewriter) const override { |
635 | if (!printOp.getSource()) |
636 | return failure(); |
637 | |
638 | VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); |
639 | if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType)) |
640 | return failure(); |
641 | |
642 | auto loc = printOp.getLoc(); |
643 | |
644 | // Create a loop over the rows of the tile. |
645 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
646 | auto minTileRows = |
647 | rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0)); |
648 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
649 | auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale); |
650 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
651 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
652 | { |
653 | // Loop body. |
654 | rewriter.setInsertionPointToStart(forOp.getBody()); |
655 | // Extract the current row from the tile. |
656 | Value rowIndex = forOp.getInductionVar(); |
657 | auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( |
658 | loc, printOp.getSource(), rowIndex); |
659 | // Print the row with a 1D vector.print. |
660 | rewriter.create<vector::PrintOp>(loc, tileSlice, |
661 | printOp.getPunctuation()); |
662 | } |
663 | |
664 | rewriter.eraseOp(op: printOp); |
665 | return success(); |
666 | } |
667 | }; |
668 | |
669 | /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp. |
670 | /// |
671 | /// BEFORE: |
672 | /// ```mlir |
673 | /// %slice = arm_sme.extract_tile_slice %tile[%index] |
674 | /// : vector<[4]xf32> from vector<[4]x[4]xf32> |
675 | /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]} |
676 | /// : vector<[4]xf32>, memref<?x?xf32> |
677 | /// ``` |
678 | /// AFTER: |
679 | /// ```mlir |
680 | /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j] |
681 | /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32> |
682 | /// ``` |
683 | struct FoldTransferWriteOfExtractTileSlice |
684 | : public OpRewritePattern<vector::TransferWriteOp> { |
685 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
686 | |
687 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
688 | PatternRewriter &rewriter) const final { |
689 | if (!isa<MemRefType>(writeOp.getBase().getType())) |
690 | return rewriter.notifyMatchFailure(writeOp, "destination not a memref"); |
691 | |
692 | if (writeOp.hasOutOfBoundsDim()) |
693 | return rewriter.notifyMatchFailure(writeOp, |
694 | "not inbounds transfer write"); |
695 | |
696 | auto extractTileSlice = |
697 | writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>(); |
698 | if (!extractTileSlice) |
699 | return rewriter.notifyMatchFailure( |
700 | writeOp, "vector to store not from ExtractTileSliceOp"); |
701 | |
702 | AffineMap map = writeOp.getPermutationMap(); |
703 | if (!map.isMinorIdentity()) |
704 | return rewriter.notifyMatchFailure(writeOp, |
705 | "unsupported permutation map"); |
706 | |
707 | Value mask = writeOp.getMask(); |
708 | if (!mask) { |
709 | auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); |
710 | mask = rewriter.create<arith::ConstantOp>( |
711 | writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); |
712 | } |
713 | |
714 | rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>( |
715 | writeOp, extractTileSlice.getTile(), |
716 | extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(), |
717 | writeOp.getIndices(), extractTileSlice.getLayout()); |
718 | return success(); |
719 | } |
720 | }; |
721 | |
722 | /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to |
723 | /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or |
724 | /// SVE 2.1), so this is currently the most logical place for this lowering. |
725 | /// |
726 | /// Example: |
727 | /// ```mlir |
728 | /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> |
729 | /// %slice = vector.extract %mask[%index] |
730 | /// : vector<[8]xi1> from vector<[4]x[8]xi1> |
731 | /// ``` |
732 | /// Becomes: |
733 | /// ``` |
734 | /// %mask_rows = vector.create_mask %a : vector<[4]xi1> |
735 | /// %mask_cols = vector.create_mask %b : vector<[8]xi1> |
736 | /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index] |
737 | /// : vector<[8]xi1>, vector<[4]xi1> |
738 | /// ``` |
739 | struct ExtractFromCreateMaskToPselLowering |
740 | : public OpRewritePattern<vector::ExtractOp> { |
741 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
742 | |
743 | LogicalResult matchAndRewrite(vector::ExtractOp extractOp, |
744 | PatternRewriter &rewriter) const override { |
745 | if (extractOp.getNumIndices() != 1) |
746 | return rewriter.notifyMatchFailure(extractOp, "not single extract index"); |
747 | |
748 | auto resultType = extractOp.getResult().getType(); |
749 | auto resultVectorType = dyn_cast<VectorType>(resultType); |
750 | if (!resultVectorType) |
751 | return rewriter.notifyMatchFailure(extractOp, "result not VectorType"); |
752 | |
753 | auto createMaskOp = |
754 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
755 | if (!createMaskOp) |
756 | return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp"); |
757 | |
758 | auto maskType = createMaskOp.getVectorType(); |
759 | if (maskType.getRank() != 2 || !maskType.allDimsScalable()) |
760 | return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask"); |
761 | |
762 | auto isSVEPredicateSize = [](int64_t size) { |
763 | return size > 0 && size <= 16 && llvm::isPowerOf2_32(Value: uint32_t(size)); |
764 | }; |
765 | |
766 | auto rowsBaseSize = maskType.getDimSize(0); |
767 | auto colsBaseSize = maskType.getDimSize(1); |
768 | if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize)) |
769 | return rewriter.notifyMatchFailure( |
770 | createMaskOp, "mask dimensions not SVE predicate-sized"); |
771 | |
772 | auto loc = extractOp.getLoc(); |
773 | VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1); |
774 | VectorType colMaskType = VectorType::Builder(maskType).dropDim(0); |
775 | |
776 | // Create the two 1-D masks at the location of the 2-D create_mask (which is |
777 | // usually outside a loop). This prevents the need for later hoisting. |
778 | rewriter.setInsertionPoint(createMaskOp); |
779 | auto rowMask = rewriter.create<vector::CreateMaskOp>( |
780 | loc, rowMaskType, createMaskOp.getOperand(0)); |
781 | auto colMask = rewriter.create<vector::CreateMaskOp>( |
782 | loc, colMaskType, createMaskOp.getOperand(1)); |
783 | |
784 | rewriter.setInsertionPoint(extractOp); |
785 | auto position = |
786 | vector::getAsValues(builder&: rewriter, loc: loc, foldResults: extractOp.getMixedPosition()); |
787 | rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask, |
788 | position[0]); |
789 | return success(); |
790 | } |
791 | }; |
792 | |
793 | } // namespace |
794 | |
795 | void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, |
796 | MLIRContext &ctx) { |
797 | patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, |
798 | TransferReadToArmSMELowering, TransferWriteToArmSMELowering, |
799 | TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, |
800 | VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, |
801 | VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, |
802 | VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, |
803 | ExtractFromCreateMaskToPselLowering>(arg: &ctx); |
804 | } |
805 |
Definitions
- TransferReadToArmSMELowering
- matchAndRewrite
- TransferWriteToArmSMELowering
- matchAndRewrite
- VectorLoadToArmSMELowering
- matchAndRewrite
- VectorStoreToArmSMELowering
- matchAndRewrite
- BroadcastOpToArmSMELowering
- matchAndRewrite
- SplatOpToArmSMELowering
- matchAndRewrite
- TransposeOpToArmSMELowering
- matchAndRewrite
- VectorOuterProductToArmSMELowering
- matchAndRewrite
- decomposeResultMask
- VectorExtractToArmSMELowering
- matchAndRewrite
- VectorInsertToArmSMELowering
- matchAndRewrite
- VectorPrintToArmSMELowering
- matchAndRewrite
- FoldTransferWriteOfExtractTileSlice
- matchAndRewrite
- ExtractFromCreateMaskToPselLowering
- matchAndRewrite
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more