1//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to LLVM intrinsics.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
14
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19#include "mlir/Dialect/ArmSME/Utils/Utils.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34namespace {
35
36static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
37
38/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
39static Operation *createLoadTileSliceIntrinsic(
40 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
41 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
42 IntegerAttr tileId, Value tileSliceI32) {
43 if (layout == arm_sme::TileSliceLayout::Horizontal) {
44 switch (type) {
45 case arm_sme::ArmSMETileType::ZAB:
46 return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
47 loc, maskOp, ptr, tileId, tileSliceI32);
48 case arm_sme::ArmSMETileType::ZAH:
49 return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAS:
52 return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAD:
55 return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAQ:
58 return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
60 }
61 } else {
62 switch (type) {
63 case arm_sme::ArmSMETileType::ZAB:
64 return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
65 loc, maskOp, ptr, tileId, tileSliceI32);
66 case arm_sme::ArmSMETileType::ZAH:
67 return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAS:
70 return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAD:
73 return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAQ:
76 return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
78 break;
79 }
80 }
81}
82
83/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
84static Operation *createStoreTileSliceIntrinsic(
85 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
86 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
87 IntegerAttr tileId, Value tileSliceI32) {
88 if (layout == arm_sme::TileSliceLayout::Horizontal) {
89 switch (type) {
90 case arm_sme::ArmSMETileType::ZAB:
91 return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
92 loc, maskOp, ptr, tileId, tileSliceI32);
93 case arm_sme::ArmSMETileType::ZAH:
94 return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
95 loc, maskOp, ptr, tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAS:
97 return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
98 loc, maskOp, ptr, tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAD:
100 return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
101 loc, maskOp, ptr, tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAQ:
103 return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
104 loc, maskOp, ptr, tileId, tileSliceI32);
105 }
106 } else {
107 switch (type) {
108 case arm_sme::ArmSMETileType::ZAB:
109 return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
110 loc, maskOp, ptr, tileId, tileSliceI32);
111 case arm_sme::ArmSMETileType::ZAH:
112 return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
113 loc, maskOp, ptr, tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAS:
115 return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
116 loc, maskOp, ptr, tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAD:
118 return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
119 loc, maskOp, ptr, tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAQ:
121 return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
122 loc, maskOp, ptr, tileId, tileSliceI32);
123 }
124 }
125}
126
127IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
128 auto tileId = op.getTileId();
129 if (!tileId)
130 op.emitOpError(
131 "expected tile ID to be allocated before conversion to LLVM");
132 return tileId;
133}
134
135/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
136/// placed in the first block of the function.
137static memref::AllocaOp
138createAllocaForTile(RewriterBase &rewriter, Location loc,
139 FunctionOpInterface func,
140 arm_sme::ArmSMETileOpInterface tileOp) {
141 RewriterBase::InsertionGuard g(rewriter);
142 // Move to the first operation in the function.
143 rewriter.setInsertionPointToStart(&func.getBlocks().front());
144 // Create an alloca matching the tile size of the `tileOp`.
145 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
146 auto tileElementType = tileOp.getTileType().getElementType();
147 auto memrefType = MemRefType::get(
148 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
149 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(type: tileElementType);
150 auto minElementsOp =
151 rewriter.create<arith::ConstantIndexOp>(loc, minElements);
152 auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
153 auto alloca = rewriter.create<memref::AllocaOp>(
154 loc, memrefType, ValueRange{vectorLen, vectorLen});
155 return alloca;
156}
157
158/// Finds or creates an alloca for a spill of a tile.
159static memref::AllocaOp getOrCreateAllocaForTile(
160 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
161 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
162 // Find an alloca at the top of the function tagged with a
163 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
164 for (auto &op : func.getBlocks().front()) {
165 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
166 if (!alloca)
167 continue;
168 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
169 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
170 if (!inMemoryTileId)
171 continue;
172 if (inMemoryTileId.getInt() == tileId)
173 return alloca;
174 }
175 // Otherwise, create a new alloca:
176 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
177 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
178 rewriter.getI32IntegerAttr(tileId));
179 return alloca;
180}
181
182/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
183/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
184/// the op to tile 0, then emitting a full tile swap between ZA and memory
185/// before + after the tile op.
186///
187/// Example:
188///
189/// // Note: <IN MEMORY TILE> = tile ID >= 16.
190/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
191///
192/// is converted to:
193/// // At function entry:
194/// %spill = memref.alloca ... : memref<?x?xty>
195///
196/// // Around op:
197/// scf.for %slice_idx {
198/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
199/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
200/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
201/// }
202/// arm_sme.tile_op { tile_id = 0 }
203/// scf.for %slice_idx {
204/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
205/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
206/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
207/// }
208///
209/// Note that these spills/fills are not inserted earlier as concept of a
210/// register, and the need to swap the contents, can't really be represented
211/// correctly at a high level in MLIR.
212///
213/// TODO: Reduce the spills/reloads to single slices where possible (and omit
214/// redundant reloads). This could be done via a method on the
215/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
216///
217/// `tileOp.getZaUsage()` could return:
218///
219/// struct ArmSMEOpZAUsage {
220/// enum class Kind {
221/// TileRead, // Omit store after tile operation.
222/// TileWrite, // Omit load before tile operation.
223/// TileReadWrite, // Needs both tile load and store.
224/// SliceRead, // Spill single slice and omit store after operation.
225/// SliceWrite, // Spill single slice and omit load before operation.
226/// SliceReadWrite // Spill single slice.
227/// };
228/// Value sliceIndex {};
229/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
230/// };
231///
232struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
233
234 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
235 const LLVMTypeConverter &typeConverter,
236 PatternBenefit benefit)
237 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
238 typeConverter, benefit) {}
239
240 LogicalResult
241 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
242 ConversionPatternRewriter &rewriter) const override {
243 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
244 // Tile has a real (hardware) tile. No spills/reloads required.
245 if (!tileOp.isInMemoryTile())
246 return failure();
247
248 // Step 1. Create an alloca for the tile at the top of the function (if one
249 // does not already exist).
250 auto loc = tileOp.getLoc();
251 auto func = tileOp->getParentOfType<FunctionOpInterface>();
252 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
253 tileOp.getTileId().getInt());
254
255 // Step 2. Assign the op a real tile ID.
256 // For simplicity, we always use tile 0 (which always exists).
257 auto zeroTileId = rewriter.getI32IntegerAttr(0);
258 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
259
260 VectorType tileVectorType = tileOp.getTileType();
261 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
262 auto swapInMemoryTileWithSMETileZero = [&] {
263 emitFullTileSwap(rewriter, loc, tileAlloca,
264 *arm_sme::getSMETileType(tileVectorType), sliceType,
265 zeroTileId);
266 };
267
268 // Step 3. Emit tile swaps before and after the op.
269 // TODO: Reduce the amount spilled to the amount of data the `tileOp`
270 // touches (i.e. a single tile slice).
271 {
272 rewriter.setInsertionPoint(op);
273 // Swap the contents of ZA and the in-memory tile before the op.
274 swapInMemoryTileWithSMETileZero();
275 rewriter.setInsertionPointAfter(op);
276 // Swap the tile back out to memory again after the op.
277 swapInMemoryTileWithSMETileZero();
278 }
279
280 return success();
281 }
282
283 /// Extracts a pointer to a slice of an in-memory tile.
284 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
285 Value tileMemory, Value sliceIndex) const {
286 auto llvmType = getTypeConverter()->convertType(t: tileMemory.getType());
287 auto descriptor =
288 rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
289 auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
290 auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
291 loc, rewriter.getI64Type(), sliceIndex);
292 return getStridedElementPtr(
293 loc, type: llvm::cast<MemRefType>(tileMemory.getType()),
294 memRefDesc: descriptor.getResult(0), indices: {sliceIndexI64, zero},
295 rewriter&: static_cast<ConversionPatternRewriter &>(rewriter));
296 }
297
298 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
299 /// tile-sized memref (`tileAlloca`).
300 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
301 arm_sme::ArmSMETileType tileType, VectorType sliceType,
302 IntegerAttr tileId, Value sliceIndex) const {
303 // Cast the slice index to an i32.
304 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
305 loc, rewriter.getI32Type(), sliceIndex);
306 // Create an all-true predicate for the slice.
307 auto predicateType = sliceType.clone(rewriter.getI1Type());
308 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
309 loc, DenseElementsAttr::get(predicateType, true));
310 // Create padding vector (never used due to all-true predicate).
311 auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
312 // Get a pointer to the current slice.
313 auto slicePtr =
314 getInMemoryTileSlicePtr(rewriter, loc, tileMemory: tileAlloca, sliceIndex);
315 // Read the value of the current slice from ZA.
316 auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
317 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
318 // Load the new tile slice back from memory into ZA.
319 createLoadTileSliceIntrinsic(
320 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
321 allTruePredicate, slicePtr, tileId, sliceIndexI32);
322 // Store the current tile slice to memory.
323 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
324 rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
325 ValueRange{sliceIndex, zero});
326 }
327
328 /// Emits a full in-place swap of the contents of a tile in ZA and a
329 /// tile-sized memref (`tileAlloca`).
330 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
331 arm_sme::ArmSMETileType tileType, VectorType sliceType,
332 IntegerAttr tileId) const {
333 RewriterBase::InsertionGuard guard(rewriter);
334 // Create an scf.for over all tile slices.
335 auto minNumElts =
336 rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
337 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
338 auto upperBound = rewriter.create<arith::MulIOp>(
339 loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
340 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
341 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
342 // Emit a swap for each tile slice.
343 rewriter.setInsertionPointToStart(forOp.getBody());
344 auto sliceIndex = forOp.getInductionVar();
345 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
346 sliceIndex);
347 }
348};
349
350enum class RequiresSpillsAndFills { Yes, No };
351
352/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
353/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
354/// disabled by setting the `requiresSpillsAndFills` template parameter to
355/// `RequiresSpillsAndFills::No`.
356template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
357 RequiresSpillsAndFills::Yes>
358struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
359 using ArmSMEOp = SourceOp;
360 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
361
362 static constexpr bool requiresSpillsAndFillsConversion() {
363 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
364 }
365};
366
367template <typename Pattern>
368static void addArmSMEConversionPattern(RewritePatternSet &patterns,
369 LLVMTypeConverter const &typeConverter) {
370 // Register spills/fills for ops that implement the
371 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
372 // `RequiresSpillsAndFills::Yes`.
373 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
374 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
375 typename Pattern::ArmSMEOp>,
376 typename Pattern::ArmSMEOp>) {
377 // Add spill/fill conversions with a very high benefit to ensure
378 // they are lowered first.
379 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
380 Pattern::ArmSMEOp::getOperationName(), typeConverter,
381 /*benefit=*/1337);
382 }
383 patterns.add<Pattern>(typeConverter);
384}
385
386/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
387template <typename... Patterns>
388static void
389addArmSMEConversionPatterns(RewritePatternSet &patterns,
390 LLVMTypeConverter const &typeConverter) {
391 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
392}
393
394struct GetTileConversion
395 : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
396 RequiresSpillsAndFills::No> {
397 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
398
399 LogicalResult
400 matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
401 ConversionPatternRewriter &rewriter) const override {
402 rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
403 getTile, getTile.getTileType());
404 return success();
405 }
406};
407
408/// Lower 'arm_sme.zero' to SME intrinsics.
409///
410/// BEFORE:
411/// ```mlir
412/// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
413/// ```
414///
415/// AFTER:
416/// ```mlir
417/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
418/// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
419/// ```
420///
421/// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
422/// once all ArmSME ops have been converted to LLVM intrinsics.
423struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
424 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
425
426 LogicalResult
427 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const override {
429 auto loc = zero.getLoc();
430
431 auto tileId = getTileIdOrError(zero);
432 if (!tileId)
433 return failure();
434
435 // Get the base mask for tile based on the element size.
436 // The base mask is just the mask to zero the first tile (of a size).
437 // These masks are derived from:
438 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
439 arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
440 auto baseMaskForSize = [&] {
441 switch (tileType) {
442 case arm_sme::ArmSMETileType::ZAB:
443 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
444 // 64-bit element tiles named ZA0.D to ZA7.D.
445 return 0b1111'1111;
446 case arm_sme::ArmSMETileType::ZAH:
447 // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
448 // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
449 // once for ZA1.H.
450 return 0b0101'0101;
451 case arm_sme::ArmSMETileType::ZAS:
452 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
453 // element tiles named ZA0.D and ZA4.D.
454 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
455 return 0b0001'0001;
456 case arm_sme::ArmSMETileType::ZAD:
457 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
458 // setting the bit for that tile.
459 return 0b0000'0001;
460 default:
461 llvm_unreachable("bad element size");
462 }
463 }();
464
465 // The actual mask is just the base mask shifted by the tile ID.
466 // This will be folded to a constant after tile allocation.
467 //
468 // The shift is just derived from the layout of the tiles, and that the tile
469 // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
470 // tiles:
471 //
472 // ZA0.S = ZA0.D and ZA4.D
473 // * Tile ID -> 0
474 // * Mask -> 00010001 = (00010001 << 0)
475 // ZA1.S = ZA1.D and ZA5.D
476 // * Tile ID -> 1
477 // * Mask -> 00100010 = (00010001 << 1)
478 // ZA2.S = ZA2.D and ZA6.D
479 // * Tile ID -> 2
480 // * Mask -> 01000100 = (00010001 << 2)
481 // ZA3.S = ZA3.D and ZA7.D
482 // * Tile ID -> 3
483 // * Mask -> 10001000 = (00010001 << 3)
484 //
485 // This holds for all tile sizes.
486 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
487 rewriter.create<arm_sme::aarch64_sme_zero>(
488 loc, rewriter.getI32IntegerAttr(zeroMask));
489
490 // Create a placeholder op to preserve dataflow.
491 rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
492 zero, zero.getVectorType());
493
494 return success();
495 }
496};
497
498/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
499struct LoadTileSliceConversion
500 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502
503 LogicalResult
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
509 if (!tileId)
510 return failure();
511
512 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
513 adaptor.getBase(),
514 adaptor.getIndices(), rewriter);
515
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
517
518 // Cast tile slice to i32 for intrinsic.
519 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
520 loc, rewriter.getI32Type(), tileSlice);
521
522 // Create all active predicate mask.
523 auto maskOp = loadTileSliceOp.getMask();
524
525 auto tileVectorType = loadTileSliceOp.getVectorType();
526 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528
529 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
532
533 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
534 // the input tile to preserve dataflow.
535 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
536
537 return success();
538 }
539};
540
541/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
542struct StoreTileSliceConversion
543 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545
546 LogicalResult
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
549 ConversionPatternRewriter &rewriter) const override {
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
552
553 auto tileId = getTileIdOrError(storeTileSliceOp);
554 if (!tileId)
555 return failure();
556
557 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
558 Value ptr = this->getStridedElementPtr(
559 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices(), rewriter);
561
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
563
564 // Cast tile slice to i32 for intrinsic.
565 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
566 loc, rewriter.getI32Type(), tileSlice);
567
568 auto maskOp = storeTileSliceOp.getMask();
569
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
571 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
572
573 rewriter.replaceOp(storeTileSliceOp,
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
575 layout, maskOp, ptr,
576 tileId, tileSliceI32));
577
578 return success();
579 }
580};
581
582/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
583struct MoveVectorToTileSliceConversion
584 : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
586
587 LogicalResult
588 matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
589 arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
590 ConversionPatternRewriter &rewriter) const override {
591 auto loc = moveVectorToTileSliceOp.getLoc();
592 auto tileType = moveVectorToTileSliceOp.getTileType();
593
594 auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
595 if (!tileId)
596 return failure();
597
598 auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
599
600 // Cast tile slice from index to i32 for intrinsic.
601 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
602 loc, rewriter.getI32Type(), tileSlice);
603
604 // Create all active predicate mask.
605 auto one = rewriter.create<arith::ConstantOp>(
606 loc, rewriter.getI1Type(),
607 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
609 /*scalableDims=*/{true});
610 auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
611
612 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
613 switch (moveVectorToTileSliceOp.getLayout()) {
614 case arm_sme::TileSliceLayout::Horizontal:
615 rewriter.create<arm_sme::aarch64_sme_write_horiz>(
616 loc, tileId, tileSliceI32, allActiveMask,
617 moveVectorToTileSliceOp.getVector());
618 break;
619 case arm_sme::TileSliceLayout::Vertical:
620 rewriter.create<arm_sme::aarch64_sme_write_vert>(
621 loc, tileId, tileSliceI32, allActiveMask,
622 moveVectorToTileSliceOp.getVector());
623 break;
624 }
625
626 // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
627 // the input tile to preserve dataflow.
628 rewriter.replaceOp(moveVectorToTileSliceOp,
629 moveVectorToTileSliceOp.getTile());
630
631 return success();
632 }
633};
634
635/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
636struct MoveTileSliceToVectorConversion
637 : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
639
640 LogicalResult
641 matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
642 OpAdaptor,
643 ConversionPatternRewriter &rewriter) const override {
644 auto loc = moveTileSliceToVector.getLoc();
645 auto sliceType = moveTileSliceToVector.getSliceType();
646 auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
647
648 auto tileId = getTileIdOrError(moveTileSliceToVector);
649 if (!tileId)
650 return failure();
651
652 // Create an 'all true' predicate for the tile slice.
653 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
654 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
655 loc, DenseElementsAttr::get(predicateType, true));
656
657 // Zero destination/fallback for tile slice extraction.
658 auto zeroVector = rewriter.create<arith::ConstantOp>(
659 loc, sliceType, rewriter.getZeroAttr(sliceType));
660
661 // Cast tile slice from index to i32 for intrinsic.
662 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
663 loc, rewriter.getI32Type(), sliceIndex);
664
665 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
666 switch (moveTileSliceToVector.getLayout()) {
667 case arm_sme::TileSliceLayout::Horizontal:
668 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
669 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
670 tileId, sliceIndexI32);
671 break;
672 case arm_sme::TileSliceLayout::Vertical:
673 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
674 moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
675 tileId, sliceIndexI32);
676 break;
677 }
678
679 return success();
680 }
681};
682
683/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
684///
685/// Example:
686///
687/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
688/// : vector<[4]xf32>, vector<[4]xf32>
689///
690/// is converted to:
691///
692/// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
693/// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
694/// vector<[4]xf32>) -> ()
695///
696/// Currently only supports FMOPA and BFMOPA (non-widening).
697struct OuterProductOpConversion
698 : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
699 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
700
701 LogicalResult
702 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
703 arm_sme::OuterProductOp::Adaptor adaptor,
704 ConversionPatternRewriter &rewriter) const override {
705 auto tileId = getTileIdOrError(outerProductOp);
706 if (!tileId)
707 return failure();
708
709 auto isSupportedType = [](VectorType vectorType) {
710 // TODO: the FP outer product instruction variants are predicated on
711 // different features [1]:
712 //
713 // * FMOPA (non-widening)
714 // * half-precision - +sme2p1,+sme-f16f16
715 // * single-precision - +sme
716 // * double-precision - +sme-f64f64
717 // * BFMOPA
718 // * half-precision - +sme2p1,+b16b16
719 //
720 // It should be possible to control lowering based on target features.
721 // [1]
722 // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
723 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
724 return false;
725
726 auto elementType = vectorType.getElementType();
727
728 if (!elementType.isF16() && !elementType.isBF16() &&
729 !elementType.isF32() && !elementType.isF64())
730 return false;
731
732 unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
733 vectorType.getElementTypeBitWidth();
734 return vectorType.getShape() ==
735 ArrayRef<int64_t>({minNumElts, minNumElts});
736 };
737
738 // TODO: Support CombiningKind::Sub for outer products.
739 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
740 return outerProductOp.emitError("unsupported kind");
741
742 auto resultVectorType = outerProductOp.getResultType();
743 if (!isSupportedType(resultVectorType))
744 return outerProductOp.emitError("unsupported type");
745
746 auto loc = outerProductOp.getLoc();
747
748 Value acc = outerProductOp.getAcc();
749 if (!acc)
750 // Initalize accumulator with zero.
751 acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
752 rewriter, loc, resultVectorType);
753
754 Value lhsMask = outerProductOp.getLhsMask();
755 Value rhsMask = outerProductOp.getRhsMask();
756
757 if (!lhsMask || !rhsMask) {
758 auto predTy =
759 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
760 Value allActiveMask = rewriter.create<arith::ConstantOp>(
761 loc, DenseElementsAttr::get(predTy, true));
762 lhsMask = allActiveMask;
763 rhsMask = allActiveMask;
764 }
765
766 // Create 'arm_sme.intr.mopa' outer product intrinsic.
767 rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
768 outerProductOp.getLhs(),
769 outerProductOp.getRhs());
770
771 // The outerproduct intrinsics have no result, replace
772 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
773 rewriter.replaceOp(outerProductOp, acc);
774
775 return success();
776 }
777};
778
779/// Lower 2-way and 4-way widening outer products to intrinsics.
780template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
781struct OuterProductWideningOpConversion
782 : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
783 using ConvertArmSMEOpToLLVMPattern<
784 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
785
786 LogicalResult
787 matchAndRewrite(OuterProductWideningOp op,
788 typename OuterProductWideningOp::Adaptor adaptor,
789 ConversionPatternRewriter &rewriter) const override {
790 auto tileId = getTileIdOrError(op);
791 if (!tileId)
792 return failure();
793
794 Value acc = op.getAcc();
795 if (!acc)
796 // Initalize accumulator with zero.
797 acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
798 rewriter, op.getLoc(), op.getResultType());
799
800 Value lhsMask = op.getLhsMask();
801 Value rhsMask = op.getRhsMask();
802 if (!lhsMask || !rhsMask) {
803 auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
804 Value allActiveMask = rewriter.create<arith::ConstantOp>(
805 op.getLoc(), DenseElementsAttr::get(predTy, true));
806 lhsMask = allActiveMask;
807 rhsMask = allActiveMask;
808 }
809
810 rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
811 rhsMask, adaptor.getLhs(),
812 adaptor.getRhs());
813
814 // The outerproduct intrinsics have no result, replace
815 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
816 rewriter.replaceOp(op, acc);
817
818 return success();
819 }
820};
821
822/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
823///
824/// Example:
825///
826/// %0 = arm_sme.streaming_vl <half>
827///
828/// is converted to:
829///
830/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
831/// %0 = arith.index_cast %cnt : i64 to index
832///
833struct StreamingVLOpConversion
834 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835 RequiresSpillsAndFills::No> {
836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
837
838 LogicalResult
839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840 arm_sme::StreamingVLOp::Adaptor adaptor,
841 ConversionPatternRewriter &rewriter) const override {
842 auto loc = streamingVlOp.getLoc();
843 auto i64Type = rewriter.getI64Type();
844 auto *intrOp = [&]() -> Operation * {
845 switch (streamingVlOp.getTypeSize()) {
846 case arm_sme::TypeSize::Byte:
847 return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848 case arm_sme::TypeSize::Half:
849 return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850 case arm_sme::TypeSize::Word:
851 return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852 case arm_sme::TypeSize::Double:
853 return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
854 }
855 }();
856 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
857 streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
858 return success();
859 }
860};
861
862} // namespace
863
864namespace {
865
866struct ConvertArmSMEToLLVMPass
867 : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
868 void runOnOperation() override {
869 LLVMConversionTarget target(getContext());
870 RewritePatternSet patterns(&getContext());
871 LLVMTypeConverter converter(&getContext());
872 configureArmSMEToLLVMConversionLegality(target);
873 populateArmSMEToLLVMConversionPatterns(converter, patterns);
874
875 if (failed(applyPartialConversion(getOperation(), target,
876 std::move(patterns))))
877 signalPassFailure();
878 }
879};
880
881} // namespace
882
883void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
884 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
885 target.addLegalOp<
886 arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
887 arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
888 arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
889 arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
890 arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
891 arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
892 arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
893 arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
894 arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
895 arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
896 arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
897 arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
898 arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
899 arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
900 arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
901 arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
902 arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
903 arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
904 arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
905 arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
906 arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
907 arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
908 arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
909 target.addLegalDialect<arith::ArithDialect,
910 /* The following are used to lower tile spills/fills */
911 vector::VectorDialect, scf::SCFDialect,
912 memref::MemRefDialect>();
913 target.addLegalOp<UnrealizedConversionCastOp>();
914}
915
916void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
917 RewritePatternSet &patterns) {
918 converter.addConversion(callback: [&](VectorType type) -> std::optional<Type> {
919 // There's no LLVM type for SME tiles, but after lowering to intrinsics all
920 // SME vector types should be eliminated.
921 if (arm_sme::isValidSMETileVectorType(vType: type))
922 return type;
923 return std::nullopt;
924 });
925
926 addArmSMEConversionPatterns<
927 LoadTileSliceConversion, MoveTileSliceToVectorConversion,
928 MoveVectorToTileSliceConversion, StoreTileSliceConversion,
929 StreamingVLOpConversion, OuterProductOpConversion,
930 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
931 arm_sme::aarch64_sme_mopa_wide>,
932 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
933 arm_sme::aarch64_sme_mops_wide>,
934 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
935 arm_sme::aarch64_sme_smopa_za32>,
936 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
937 arm_sme::aarch64_sme_smops_za32>,
938 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
939 arm_sme::aarch64_sme_umopa_za32>,
940 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
941 arm_sme::aarch64_sme_umops_za32>,
942 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
943 arm_sme::aarch64_sme_smopa_wide>,
944 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
945 arm_sme::aarch64_sme_smops_wide>,
946 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
947 arm_sme::aarch64_sme_umopa_wide>,
948 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
949 arm_sme::aarch64_sme_umops_wide>,
950 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
951 arm_sme::aarch64_sme_sumopa_wide>,
952 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
953 arm_sme::aarch64_sme_sumops_wide>,
954 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
955 arm_sme::aarch64_sme_usmopa_wide>,
956 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
957 arm_sme::aarch64_sme_usmops_wide>,
958 ZeroOpConversion, GetTileConversion>(patterns, converter);
959}
960
961std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
962 return std::make_unique<ConvertArmSMEToLLVMPass>();
963}
964

source code of mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp