1//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to LLVM intrinsics.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
14
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
20#include "mlir/Dialect/ArmSME/Utils/Utils.h"
21#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/ScopeExit.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
/// Discardable attribute used to tag spill allocas at function entry with the
/// in-memory tile ID they back, so later lookups can reuse the same slot.
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
39
40/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
41static Operation *createLoadTileSliceIntrinsic(
42 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
43 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
44 IntegerAttr tileId, Value tileSliceI32) {
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
46 switch (type) {
47 case arm_sme::ArmSMETileType::ZAB:
48 return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
49 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
52 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
55 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
58 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
61 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
62 }
63 } else {
64 switch (type) {
65 case arm_sme::ArmSMETileType::ZAB:
66 return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
67 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
70 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
73 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
76 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
79 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
80 break;
81 }
82 }
83 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
84}
85
86/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
87static Operation *createStoreTileSliceIntrinsic(
88 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
89 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
90 IntegerAttr tileId, Value tileSliceI32) {
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
92 switch (type) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
95 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
98 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
101 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
104 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
107 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
108 }
109 } else {
110 switch (type) {
111 case arm_sme::ArmSMETileType::ZAB:
112 return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
113 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
116 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
119 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
122 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
125 location: loc, args&: maskOp, args&: ptr, args&: tileId, args&: tileSliceI32);
126 }
127 }
128 llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
129}
130
131IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
132 auto tileId = op.getTileId();
133 if (!tileId)
134 op.emitOpError(
135 message: "expected tile ID to be allocated before conversion to LLVM");
136 return tileId;
137}
138
139/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
140/// placed in the first block of the function.
141static memref::AllocaOp
142createAllocaForTile(RewriterBase &rewriter, Location loc,
143 FunctionOpInterface func,
144 arm_sme::ArmSMETileOpInterface tileOp) {
145 RewriterBase::InsertionGuard g(rewriter);
146 // Move to the first operation in the function.
147 rewriter.setInsertionPointToStart(&func.getBlocks().front());
148 // Create an alloca matching the tile size of the `tileOp`.
149 auto vscale = rewriter.create<vector::VectorScaleOp>(location: loc);
150 auto tileElementType = tileOp.getTileType().getElementType();
151 auto memrefType = MemRefType::get(
152 shape: {ShapedType::kDynamic, ShapedType::kDynamic}, elementType: tileElementType);
153 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(type: tileElementType);
154 auto minElementsOp =
155 rewriter.create<arith::ConstantIndexOp>(location: loc, args&: minElements);
156 auto vectorLen = rewriter.create<arith::MulIOp>(location: loc, args&: vscale, args&: minElementsOp);
157 auto alloca = rewriter.create<memref::AllocaOp>(
158 location: loc, args&: memrefType, args: ValueRange{vectorLen, vectorLen});
159 return alloca;
160}
161
162/// Finds or creates an alloca for a spill of a tile.
163static memref::AllocaOp getOrCreateAllocaForTile(
164 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
165 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
166 // Find an alloca at the top of the function tagged with a
167 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
168 for (auto &op : func.getBlocks().front()) {
169 auto alloca = llvm::dyn_cast<memref::AllocaOp>(Val&: op);
170 if (!alloca)
171 continue;
172 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173 Val: alloca->getDiscardableAttr(name: kInMemoryTileIdAttr));
174 if (!inMemoryTileId)
175 continue;
176 if (inMemoryTileId.getInt() == tileId)
177 return alloca;
178 }
179 // Otherwise, create a new alloca:
180 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
181 alloca->setDiscardableAttr(name: kInMemoryTileIdAttr,
182 value: rewriter.getI32IntegerAttr(value: tileId));
183 return alloca;
184}
185
/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///    // Note: <IN MEMORY TILE> = tile ID >= 16.
///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///     // At function entry:
///     %spill = memref.alloca ... : memref<?x?xty>
///
///     // Around op:
///     scf.for %slice_idx {
///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
///     }
///     arm_sme.tile_op { tile_id = 0 }
///     scf.for %slice_idx {
///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
///     }
///
/// Note that these spills/fills are not inserted earlier as the concept of a
/// register, and the need to swap the contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,        // Omit store after tile operation.
///     TileWrite,       // Omit load before tile operation.
///     TileReadWrite,   // Needs both tile load and store.
///     SliceRead,       // Spill single slice and omit store after operation.
///     SliceWrite,      // Spill single slice and omit load before operation.
///     SliceReadWrite   // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    // A slice is the tile type with the leading (slice-count) dim dropped.
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  /// `tileMemory` is the (memref-typed) spill alloca; `sliceIndex` selects the
  /// row within it.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    // Convert the memref value to an LLVM memref descriptor so standard
    // strided-pointer arithmetic can be used.
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    // Pointer to element [sliceIndex, 0], i.e. the start of the slice.
    return getStridedElementPtr(
        static_cast<ConversionPatternRewriter &>(rewriter), loc,
        llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
        {sliceIndexI64, zero});
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32 (the form the SME intrinsics expect).
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA (before it is overwritten
    // by the load below).
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory, completing the swap.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices: [0, minNumElts * vscale).
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};
357
/// Controls whether spill/fill patterns are registered for a conversion (see
/// `ConvertArmSMEOpToLLVMPattern` below).
enum class RequiresSpillsAndFills { Yes, No };
359
/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  // Expose the source op type so `addArmSMEConversionPattern` can inspect it.
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// True when registering this pattern should also register the spill/fill
  /// pattern for its source op.
  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
374
/// Registers a single ArmSME conversion pattern, plus (when applicable) the
/// spill/fill pattern for its source op.
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}
393
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
/// Expands to one `addArmSMEConversionPattern` call per pattern type.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
401
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
///    %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
///    "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///    %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        // ZAQ (128-bit) tiles have no encoding in the ZERO instruction mask.
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the tile
    // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
    // tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};
494
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    // Compute the address of the slice within the source memref.
    Value ptr = this->getStridedElementPtr(
        rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices());

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Get the predicate mask for the active lanes of the tile slice.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};
537
/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Compute the destination address of the slice within the memref.
    Value ptr = this->getStridedElementPtr(
        rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices());

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Predicate mask for the active lanes of the tile slice.
    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store the ZA tile
    // slice.
    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};
578
/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask (the op writes the whole slice).
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};
630
/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice (the whole slice is
    // extracted).
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction (unused in practice
    // since the predicate is all-true).
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};
677
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    // Checks the result is a 2-D scalable vector of a supported FP element
    // type whose dims match the minimum streaming vector length.
    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    // When no masks are given, use an all-active predicate for both operands.
    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};
775
/// Lower 2-way and 4-way widening outer products to intrinsics.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Requires a tile ID from the tile allocator (diagnostic emitted by
    // getTileIdOrError if missing).
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero: the widening MOPA/MOPS intrinsics
      // accumulate into the tile.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      // The intrinsics are always predicated; synthesize one all-active mask
      // (i1 elements, same shape as the LHS) shared by both operands.
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Note: operands come from the adaptor (already converted values).
    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};
820
/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    // Select the CNTS intrinsic matching the requested element size
    // (byte/half/word/double); each returns the element count as an i64.
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    // The op produces an index; cast the intrinsic's i64 result to it.
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};
861
862/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
863/// or-ing the zero masks. Note: In future the backend _should_ handle this.
864static void mergeConsecutiveTileZerosInBlock(Block *block) {
865 uint32_t mergedZeroMask = 0;
866 SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
867 auto replaceMergedZeroOps = [&] {
868 auto cleanup = llvm::make_scope_exit(F: [&] {
869 mergedZeroMask = 0;
870 zeroOpsToMerge.clear();
871 });
872 if (zeroOpsToMerge.size() <= 1)
873 return;
874 IRRewriter rewriter(zeroOpsToMerge.front());
875 rewriter.create<arm_sme::aarch64_sme_zero>(
876 location: zeroOpsToMerge.front().getLoc(),
877 args: rewriter.getI32IntegerAttr(value: mergedZeroMask));
878 for (auto zeroOp : zeroOpsToMerge)
879 rewriter.eraseOp(op: zeroOp);
880 };
881 for (Operation &op : *block) {
882 if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(Val&: op)) {
883 mergedZeroMask |= zeroOp.getTileMask();
884 zeroOpsToMerge.push_back(Elt: zeroOp);
885 } else {
886 replaceMergedZeroOps();
887 }
888 }
889 replaceMergedZeroOps();
890}
891
892} // namespace
893
894namespace {
895
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    // Assign ZA tile IDs first; the conversion patterns below depend on the
    // IDs set by the allocator (see getTileIdOrError in the patterns).
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    // Post-conversion cleanup: fold runs of adjacent 'arm_sme.intr.zero'
    // ops into a single intrinsic with a merged mask.
    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these.
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered())
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};
937
938} // namespace
939
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  // All ArmSME *dialect* ops must be converted away ...
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  // ... except the LLVM-intrinsic wrapper ops the patterns produce, which
  // must remain legal (load/store slices, moves, outer products, counters).
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be eliminated
  // before the final conversion to LLVM IR (and likely will be due to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}
976
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // Identity-convert SME tile vector types so ops carrying them can be
  // processed by the conversion framework.
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  // Register one pattern per ArmSME op; the widening outer products map
  // each op flavor to its corresponding LLVM intrinsic wrapper.
  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}
1021
1022std::unique_ptr<Pass>
1023mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1024 return std::make_unique<ConvertArmSMEToLLVMPass>(args&: dumpTileLiveRanges);
1025}
1026

source code of mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp