1//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to LLVM intrinsics.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
14
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/Pattern.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
20#include "mlir/Dialect/ArmSME/Utils/Utils.h"
21#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22#include "mlir/Dialect/Func/IR/FuncOps.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "llvm/ADT/ScopeExit.h"
29
30namespace mlir {
31#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
32#include "mlir/Conversion/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36
37namespace {
38
/// Name of the discardable attribute used to tag spill allocas with the
/// in-memory tile ID they back (see getOrCreateAllocaForTile).
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
40
41/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
42static Operation *createLoadTileSliceIntrinsic(
43 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
44 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
45 IntegerAttr tileId, Value tileSliceI32) {
46 if (layout == arm_sme::TileSliceLayout::Horizontal) {
47 switch (type) {
48 case arm_sme::ArmSMETileType::ZAB:
49 return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAH:
52 return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAS:
55 return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAD:
58 return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
60 case arm_sme::ArmSMETileType::ZAQ:
61 return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
62 loc, maskOp, ptr, tileId, tileSliceI32);
63 }
64 } else {
65 switch (type) {
66 case arm_sme::ArmSMETileType::ZAB:
67 return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAH:
70 return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAS:
73 return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAD:
76 return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
78 case arm_sme::ArmSMETileType::ZAQ:
79 return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
80 loc, maskOp, ptr, tileId, tileSliceI32);
81 break;
82 }
83 }
84 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
85}
86
87/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
88static Operation *createStoreTileSliceIntrinsic(
89 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
90 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
91 IntegerAttr tileId, Value tileSliceI32) {
92 if (layout == arm_sme::TileSliceLayout::Horizontal) {
93 switch (type) {
94 case arm_sme::ArmSMETileType::ZAB:
95 return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
96 loc, maskOp, ptr, tileId, tileSliceI32);
97 case arm_sme::ArmSMETileType::ZAH:
98 return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
99 loc, maskOp, ptr, tileId, tileSliceI32);
100 case arm_sme::ArmSMETileType::ZAS:
101 return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
102 loc, maskOp, ptr, tileId, tileSliceI32);
103 case arm_sme::ArmSMETileType::ZAD:
104 return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
105 loc, maskOp, ptr, tileId, tileSliceI32);
106 case arm_sme::ArmSMETileType::ZAQ:
107 return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
108 loc, maskOp, ptr, tileId, tileSliceI32);
109 }
110 } else {
111 switch (type) {
112 case arm_sme::ArmSMETileType::ZAB:
113 return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
114 loc, maskOp, ptr, tileId, tileSliceI32);
115 case arm_sme::ArmSMETileType::ZAH:
116 return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
117 loc, maskOp, ptr, tileId, tileSliceI32);
118 case arm_sme::ArmSMETileType::ZAS:
119 return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
120 loc, maskOp, ptr, tileId, tileSliceI32);
121 case arm_sme::ArmSMETileType::ZAD:
122 return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
123 loc, maskOp, ptr, tileId, tileSliceI32);
124 case arm_sme::ArmSMETileType::ZAQ:
125 return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
126 loc, maskOp, ptr, tileId, tileSliceI32);
127 }
128 }
129 llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
130}
131
132IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
133 auto tileId = op.getTileId();
134 if (!tileId)
135 op.emitOpError(
136 "expected tile ID to be allocated before conversion to LLVM");
137 return tileId;
138}
139
140/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
141/// placed in the first block of the function.
142static memref::AllocaOp
143createAllocaForTile(RewriterBase &rewriter, Location loc,
144 FunctionOpInterface func,
145 arm_sme::ArmSMETileOpInterface tileOp) {
146 RewriterBase::InsertionGuard g(rewriter);
147 // Move to the first operation in the function.
148 rewriter.setInsertionPointToStart(&func.getBlocks().front());
149 // Create an alloca matching the tile size of the `tileOp`.
150 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
151 auto tileElementType = tileOp.getTileType().getElementType();
152 auto memrefType = MemRefType::get(
153 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
154 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(type: tileElementType);
155 auto minElementsOp =
156 rewriter.create<arith::ConstantIndexOp>(loc, minElements);
157 auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
158 auto alloca = rewriter.create<memref::AllocaOp>(
159 loc, memrefType, ValueRange{vectorLen, vectorLen});
160 return alloca;
161}
162
163/// Finds or creates an alloca for a spill of a tile.
164static memref::AllocaOp getOrCreateAllocaForTile(
165 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
166 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
167 // Find an alloca at the top of the function tagged with a
168 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
169 for (auto &op : func.getBlocks().front()) {
170 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
171 if (!alloca)
172 continue;
173 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
174 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
175 if (!inMemoryTileId)
176 continue;
177 if (inMemoryTileId.getInt() == tileId)
178 return alloca;
179 }
180 // Otherwise, create a new alloca:
181 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
182 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
183 rewriter.getI32IntegerAttr(tileId));
184 return alloca;
185}
186
187/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
188/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
189/// the op to tile 0, then emitting a full tile swap between ZA and memory
190/// before + after the tile op.
191///
192/// Example:
193///
194/// // Note: <IN MEMORY TILE> = tile ID >= 16.
195/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
196///
197/// is converted to:
198/// // At function entry:
199/// %spill = memref.alloca ... : memref<?x?xty>
200///
201/// // Around op:
202/// scf.for %slice_idx {
203/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
204/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
205/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
206/// }
207/// arm_sme.tile_op { tile_id = 0 }
208/// scf.for %slice_idx {
209/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
210/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
211/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
212/// }
213///
214/// Note that these spills/fills are not inserted earlier as concept of a
215/// register, and the need to swap the contents, can't really be represented
216/// correctly at a high level in MLIR.
217///
218/// TODO: Reduce the spills/reloads to single slices where possible (and omit
219/// redundant reloads). This could be done via a method on the
220/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
221///
222/// `tileOp.getZaUsage()` could return:
223///
224/// struct ArmSMEOpZAUsage {
225/// enum class Kind {
226/// TileRead, // Omit store after tile operation.
227/// TileWrite, // Omit load before tile operation.
228/// TileReadWrite, // Needs both tile load and store.
229/// SliceRead, // Spill single slice and omit store after operation.
230/// SliceWrite, // Spill single slice and omit load before operation.
231/// SliceReadWrite // Spill single slice.
232/// };
233/// Value sliceIndex {};
234/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
235/// };
236///
237struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
238
239 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
240 const LLVMTypeConverter &typeConverter,
241 PatternBenefit benefit)
242 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
243 typeConverter, benefit) {}
244
245 LogicalResult
246 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
247 ConversionPatternRewriter &rewriter) const override {
248 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
249 // Tile has a real (hardware) tile. No spills/reloads required.
250 if (!tileOp.isInMemoryTile())
251 return failure();
252
253 tileOp->emitWarning(
254 "failed to allocate SME virtual tile to operation, tile value will go "
255 "through memory, expect degraded performance");
256
257 // Step 1. Create an alloca for the tile at the top of the function (if one
258 // does not already exist).
259 auto loc = tileOp.getLoc();
260 auto func = tileOp->getParentOfType<FunctionOpInterface>();
261 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
262 tileOp.getTileId().getInt());
263
264 // Step 2. Assign the op a real tile ID.
265 // For simplicity, we always use tile 0 (which always exists).
266 auto zeroTileId = rewriter.getI32IntegerAttr(0);
267 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
268
269 VectorType tileVectorType = tileOp.getTileType();
270 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
271 auto swapInMemoryTileWithSMETileZero = [&] {
272 emitFullTileSwap(rewriter, loc, tileAlloca,
273 *arm_sme::getSMETileType(tileVectorType), sliceType,
274 zeroTileId);
275 };
276
277 // Step 3. Emit tile swaps before and after the op.
278 // TODO: Reduce the amount spilled to the amount of data the `tileOp`
279 // touches (i.e. a single tile slice).
280 {
281 rewriter.setInsertionPoint(op);
282 // Swap the contents of ZA and the in-memory tile before the op.
283 swapInMemoryTileWithSMETileZero();
284 rewriter.setInsertionPointAfter(op);
285 // Swap the tile back out to memory again after the op.
286 swapInMemoryTileWithSMETileZero();
287 }
288
289 return success();
290 }
291
292 /// Extracts a pointer to a slice of an in-memory tile.
293 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
294 Value tileMemory, Value sliceIndex) const {
295 auto llvmType = getTypeConverter()->convertType(t: tileMemory.getType());
296 auto descriptor =
297 rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
298 auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
299 auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
300 loc, rewriter.getI64Type(), sliceIndex);
301 return getStridedElementPtr(
302 static_cast<ConversionPatternRewriter &>(rewriter), loc,
303 llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
304 {sliceIndexI64, zero});
305 }
306
307 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
308 /// tile-sized memref (`tileAlloca`).
309 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
310 arm_sme::ArmSMETileType tileType, VectorType sliceType,
311 IntegerAttr tileId, Value sliceIndex) const {
312 // Cast the slice index to an i32.
313 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
314 loc, rewriter.getI32Type(), sliceIndex);
315 // Create an all-true predicate for the slice.
316 auto predicateType = sliceType.clone(rewriter.getI1Type());
317 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
318 loc, DenseElementsAttr::get(predicateType, true));
319 // Create padding vector (never used due to all-true predicate).
320 auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
321 // Get a pointer to the current slice.
322 auto slicePtr =
323 getInMemoryTileSlicePtr(rewriter, loc, tileMemory: tileAlloca, sliceIndex);
324 // Read the value of the current slice from ZA.
325 auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
326 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
327 // Load the new tile slice back from memory into ZA.
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
331 // Store the current tile slice to memory.
332 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
333 rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
334 ValueRange{sliceIndex, zero});
335 }
336
337 /// Emits a full in-place swap of the contents of a tile in ZA and a
338 /// tile-sized memref (`tileAlloca`).
339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId) const {
342 RewriterBase::InsertionGuard guard(rewriter);
343 // Create an scf.for over all tile slices.
344 auto minNumElts =
345 rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
346 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
347 auto upperBound = rewriter.create<arith::MulIOp>(
348 loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
349 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
350 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
351 // Emit a swap for each tile slice.
352 rewriter.setInsertionPointToStart(forOp.getBody());
353 auto sliceIndex = forOp.getInductionVar();
354 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
355 sliceIndex);
356 }
357};
358
359enum class RequiresSpillsAndFills { Yes, No };
360
361/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
362/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
363/// disabled by setting the `requiresSpillsAndFills` template parameter to
364/// `RequiresSpillsAndFills::No`.
365template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
366 RequiresSpillsAndFills::Yes>
367struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
368 using ArmSMEOp = SourceOp;
369 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
370
371 static constexpr bool requiresSpillsAndFillsConversion() {
372 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
373 }
374};
375
376template <typename Pattern>
377static void addArmSMEConversionPattern(RewritePatternSet &patterns,
378 LLVMTypeConverter const &typeConverter) {
379 // Register spills/fills for ops that implement the
380 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
381 // `RequiresSpillsAndFills::Yes`.
382 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
383 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
384 typename Pattern::ArmSMEOp>,
385 typename Pattern::ArmSMEOp>) {
386 // Add spill/fill conversions with a very high benefit to ensure
387 // they are lowered first.
388 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
389 Pattern::ArmSMEOp::getOperationName(), typeConverter,
390 /*benefit=*/1337);
391 }
392 patterns.add<Pattern>(typeConverter);
393}
394
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  // Fold expression: registers each pattern (plus, where applicable, its
  // spill/fill companion — see addArmSMEConversionPattern) in turn.
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
402
403/// Lower 'arm_sme.zero' to SME intrinsics.
404///
405/// BEFORE:
406/// ```mlir
407/// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
408/// ```
409///
410/// AFTER:
411/// ```mlir
412/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
413/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
414/// ```
415///
416/// The 'arm_sme.get_tile' (which models the return) will fold away once all
417/// ArmSME ops have been converted to LLVM intrinsics.
418struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
419 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
420
421 LogicalResult
422 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const override {
424 auto loc = zero.getLoc();
425
426 auto tileId = getTileIdOrError(zero);
427 if (!tileId)
428 return failure();
429
430 // Get the base mask for tile based on the element size.
431 // The base mask is just the mask to zero the first tile (of a size).
432 // These masks are derived from:
433 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
434 arm_sme::ArmSMETileType tileType =
435 *arm_sme::getSMETileType(zero.getTileType());
436 auto baseMaskForSize = [&] {
437 switch (tileType) {
438 case arm_sme::ArmSMETileType::ZAB:
439 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
440 // 64-bit element tiles named ZA0.D to ZA7.D.
441 return 0b1111'1111;
442 case arm_sme::ArmSMETileType::ZAH:
443 // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
444 // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
445 // once for ZA1.H.
446 return 0b0101'0101;
447 case arm_sme::ArmSMETileType::ZAS:
448 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
449 // element tiles named ZA0.D and ZA4.D.
450 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
451 return 0b0001'0001;
452 case arm_sme::ArmSMETileType::ZAD:
453 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
454 // setting the bit for that tile.
455 return 0b0000'0001;
456 default:
457 llvm_unreachable("bad element size");
458 }
459 }();
460
461 // The actual mask is just the base mask shifted by the tile ID.
462 // This will be folded to a constant after tile allocation.
463 //
464 // The shift is just derived from the layout of the tiles, and that the tile
465 // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
466 // tiles:
467 //
468 // ZA0.S = ZA0.D and ZA4.D
469 // * Tile ID -> 0
470 // * Mask -> 00010001 = (00010001 << 0)
471 // ZA1.S = ZA1.D and ZA5.D
472 // * Tile ID -> 1
473 // * Mask -> 00100010 = (00010001 << 1)
474 // ZA2.S = ZA2.D and ZA6.D
475 // * Tile ID -> 2
476 // * Mask -> 01000100 = (00010001 << 2)
477 // ZA3.S = ZA3.D and ZA7.D
478 // * Tile ID -> 3
479 // * Mask -> 10001000 = (00010001 << 3)
480 //
481 // This holds for all tile sizes.
482 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
483 rewriter.create<arm_sme::aarch64_sme_zero>(
484 loc, rewriter.getI32IntegerAttr(zeroMask));
485
486 // Create a placeholder op to preserve dataflow.
487 // Note: Place the `get_tile` op at the start of the block. This ensures
488 // that if there are multiple `zero` ops the intrinsics will be consecutive.
489 rewriter.setInsertionPointToStart(zero->getBlock());
490 rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
491
492 return success();
493 }
494};
495
496/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
497struct LoadTileSliceConversion
498 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
499 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
500
501 LogicalResult
502 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
503 arm_sme::LoadTileSliceOp::Adaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505 auto loc = loadTileSliceOp.getLoc();
506 auto tileId = getTileIdOrError(loadTileSliceOp);
507 if (!tileId)
508 return failure();
509
510 Value ptr = this->getStridedElementPtr(
511 rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
512 adaptor.getIndices());
513
514 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515
516 // Cast tile slice to i32 for intrinsic.
517 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
518 loc, rewriter.getI32Type(), tileSlice);
519
520 // Create all active predicate mask.
521 auto maskOp = loadTileSliceOp.getMask();
522
523 auto tileVectorType = loadTileSliceOp.getVectorType();
524 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
525 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
526
527 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
528 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
529 tileId, tileSliceI32);
530
531 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
532 // the input tile to preserve dataflow.
533 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
534
535 return success();
536 }
537};
538
539/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
540struct StoreTileSliceConversion
541 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
542 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
543
544 LogicalResult
545 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
546 arm_sme::StoreTileSliceOp::Adaptor adaptor,
547 ConversionPatternRewriter &rewriter) const override {
548 auto loc = storeTileSliceOp.getLoc();
549 auto tileVectorType = storeTileSliceOp.getVectorType();
550
551 auto tileId = getTileIdOrError(storeTileSliceOp);
552 if (!tileId)
553 return failure();
554
555 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556 Value ptr = this->getStridedElementPtr(
557 rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558 adaptor.getIndices());
559
560 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561
562 // Cast tile slice to i32 for intrinsic.
563 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
564 loc, rewriter.getI32Type(), tileSlice);
565
566 auto maskOp = storeTileSliceOp.getMask();
567
568 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
569 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
570
571 rewriter.replaceOp(storeTileSliceOp,
572 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
573 layout, maskOp, ptr,
574 tileId, tileSliceI32));
575
576 return success();
577 }
578};
579
580/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
581struct InsertTileSliceConversion
582 : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
583 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
584
585 LogicalResult
586 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
587 arm_sme::InsertTileSliceOp::Adaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 auto loc = insertTileSliceOp.getLoc();
590 auto tileType = insertTileSliceOp.getTileType();
591
592 auto tileId = getTileIdOrError(insertTileSliceOp);
593 if (!tileId)
594 return failure();
595
596 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
597
598 // Cast tile slice from index to i32 for intrinsic.
599 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
600 loc, rewriter.getI32Type(), tileSlice);
601
602 // Create all active predicate mask.
603 auto one = rewriter.create<arith::ConstantOp>(
604 loc, rewriter.getI1Type(),
605 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
606 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
607 /*scalableDims=*/{true});
608 auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
609
610 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
611 switch (insertTileSliceOp.getLayout()) {
612 case arm_sme::TileSliceLayout::Horizontal:
613 rewriter.create<arm_sme::aarch64_sme_write_horiz>(
614 loc, tileId, tileSliceI32, allActiveMask,
615 insertTileSliceOp.getVector());
616 break;
617 case arm_sme::TileSliceLayout::Vertical:
618 rewriter.create<arm_sme::aarch64_sme_write_vert>(
619 loc, tileId, tileSliceI32, allActiveMask,
620 insertTileSliceOp.getVector());
621 break;
622 }
623
624 // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
625 // the input tile to preserve dataflow.
626 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
627
628 return success();
629 }
630};
631
632/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
633struct ExtractTileSliceConversion
634 : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
635 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
636
637 LogicalResult
638 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 auto loc = extractTileSlice.getLoc();
641 auto sliceType = extractTileSlice.getSliceType();
642 auto sliceIndex = extractTileSlice.getTileSliceIndex();
643
644 auto tileId = getTileIdOrError(extractTileSlice);
645 if (!tileId)
646 return failure();
647
648 // Create an 'all true' predicate for the tile slice.
649 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
650 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
651 loc, DenseElementsAttr::get(predicateType, true));
652
653 // Zero destination/fallback for tile slice extraction.
654 auto zeroVector = rewriter.create<arith::ConstantOp>(
655 loc, sliceType, rewriter.getZeroAttr(sliceType));
656
657 // Cast tile slice from index to i32 for intrinsic.
658 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
659 loc, rewriter.getI32Type(), sliceIndex);
660
661 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
662 switch (extractTileSlice.getLayout()) {
663 case arm_sme::TileSliceLayout::Horizontal:
664 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
665 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
666 sliceIndexI32);
667 break;
668 case arm_sme::TileSliceLayout::Vertical:
669 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
670 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
671 sliceIndexI32);
672 break;
673 }
674
675 return success();
676 }
677};
678
679/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
680///
681/// Example:
682///
683/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
684/// : vector<[4]xf32>, vector<[4]xf32>
685///
686/// is converted to:
687///
688/// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
689/// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
690/// vector<[4]xf32>) -> ()
691///
692/// Currently only supports FMOPA and BFMOPA (non-widening).
693struct OuterProductOpConversion
694 : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
695 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
696
697 LogicalResult
698 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
699 arm_sme::OuterProductOp::Adaptor adaptor,
700 ConversionPatternRewriter &rewriter) const override {
701 auto tileId = getTileIdOrError(outerProductOp);
702 if (!tileId)
703 return failure();
704
705 auto isSupportedType = [](VectorType vectorType) {
706 // TODO: the FP outer product instruction variants are predicated on
707 // different features [1]:
708 //
709 // * FMOPA (non-widening)
710 // * half-precision - +sme2p1,+sme-f16f16
711 // * single-precision - +sme
712 // * double-precision - +sme-f64f64
713 // * BFMOPA
714 // * half-precision - +sme2p1,+b16b16
715 //
716 // It should be possible to control lowering based on target features.
717 // [1]
718 // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
719 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
720 return false;
721
722 auto elementType = vectorType.getElementType();
723
724 if (!elementType.isF16() && !elementType.isBF16() &&
725 !elementType.isF32() && !elementType.isF64())
726 return false;
727
728 unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
729 vectorType.getElementTypeBitWidth();
730 return vectorType.getShape() ==
731 ArrayRef<int64_t>({minNumElts, minNumElts});
732 };
733
734 // TODO: Support CombiningKind::Sub for outer products.
735 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
736 return outerProductOp.emitError("unsupported kind");
737
738 auto resultVectorType = outerProductOp.getResultType();
739 if (!isSupportedType(resultVectorType))
740 return outerProductOp.emitError("unsupported type");
741
742 auto loc = outerProductOp.getLoc();
743
744 Value acc = outerProductOp.getAcc();
745 if (!acc) {
746 // Initalize accumulator with zero.
747 auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
748 zero.setTileId(tileId);
749 acc = zero;
750 }
751
752 Value lhsMask = outerProductOp.getLhsMask();
753 Value rhsMask = outerProductOp.getRhsMask();
754
755 if (!lhsMask || !rhsMask) {
756 auto predTy =
757 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
758 Value allActiveMask = rewriter.create<arith::ConstantOp>(
759 loc, DenseElementsAttr::get(predTy, true));
760 lhsMask = allActiveMask;
761 rhsMask = allActiveMask;
762 }
763
764 // Create 'arm_sme.intr.mopa' outer product intrinsic.
765 rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
766 outerProductOp.getLhs(),
767 outerProductOp.getRhs());
768
769 // The outerproduct intrinsics have no result, replace
770 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
771 rewriter.replaceOp(outerProductOp, acc);
772
773 return success();
774 }
775};
776
777/// Lower 2-way and 4-way widening outer products to intrinsics.
778template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
779struct OuterProductWideningOpConversion
780 : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
781 using ConvertArmSMEOpToLLVMPattern<
782 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
783
784 LogicalResult
785 matchAndRewrite(OuterProductWideningOp op,
786 typename OuterProductWideningOp::Adaptor adaptor,
787 ConversionPatternRewriter &rewriter) const override {
788 auto tileId = getTileIdOrError(op);
789 if (!tileId)
790 return failure();
791
792 auto loc = op.getLoc();
793 Value acc = op.getAcc();
794 if (!acc) {
795 // Initalize accumulator with zero.
796 auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
797 zero.setTileId(tileId);
798 acc = zero;
799 }
800
801 Value lhsMask = op.getLhsMask();
802 Value rhsMask = op.getRhsMask();
803 if (!lhsMask || !rhsMask) {
804 auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
805 Value allActiveMask = rewriter.create<arith::ConstantOp>(
806 loc, DenseElementsAttr::get(predTy, true));
807 lhsMask = allActiveMask;
808 rhsMask = allActiveMask;
809 }
810
811 rewriter.create<OuterProductWideningIntrOp>(
812 loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
813
814 // The outerproduct intrinsics have no result, replace
815 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
816 rewriter.replaceOp(op, acc);
817
818 return success();
819 }
820};
821
822/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
823///
824/// Example:
825///
826/// %0 = arm_sme.streaming_vl <half>
827///
828/// is converted to:
829///
830/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
831/// %0 = arith.index_cast %cnt : i64 to index
832///
833struct StreamingVLOpConversion
834 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835 RequiresSpillsAndFills::No> {
836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
837
838 LogicalResult
839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840 arm_sme::StreamingVLOp::Adaptor adaptor,
841 ConversionPatternRewriter &rewriter) const override {
842 auto loc = streamingVlOp.getLoc();
843 auto i64Type = rewriter.getI64Type();
844 auto *intrOp = [&]() -> Operation * {
845 switch (streamingVlOp.getTypeSize()) {
846 case arm_sme::TypeSize::Byte:
847 return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848 case arm_sme::TypeSize::Half:
849 return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850 case arm_sme::TypeSize::Word:
851 return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852 case arm_sme::TypeSize::Double:
853 return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
854 }
855 llvm_unreachable("unknown type size in StreamingVLOpConversion");
856 }();
857 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
858 streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
859 return success();
860 }
861};
862
863/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
864/// or-ing the zero masks. Note: In future the backend _should_ handle this.
865static void mergeConsecutiveTileZerosInBlock(Block *block) {
866 uint32_t mergedZeroMask = 0;
867 SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
868 auto replaceMergedZeroOps = [&] {
869 auto cleanup = llvm::make_scope_exit([&] {
870 mergedZeroMask = 0;
871 zeroOpsToMerge.clear();
872 });
873 if (zeroOpsToMerge.size() <= 1)
874 return;
875 IRRewriter rewriter(zeroOpsToMerge.front());
876 rewriter.create<arm_sme::aarch64_sme_zero>(
877 zeroOpsToMerge.front().getLoc(),
878 rewriter.getI32IntegerAttr(mergedZeroMask));
879 for (auto zeroOp : zeroOpsToMerge)
880 rewriter.eraseOp(zeroOp);
881 };
882 for (Operation &op : *block) {
883 if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
884 mergedZeroMask |= zeroOp.getTileMask();
885 zeroOpsToMerge.push_back(zeroOp);
886 } else {
887 replaceMergedZeroOps();
888 }
889 }
890 replaceMergedZeroOps();
891}
892
893} // namespace
894
895namespace {
896
897struct ConvertArmSMEToLLVMPass
898 : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
899 ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
900 this->dumpTileLiveRanges = dumpTileLiveRanges;
901 }
902 void runOnOperation() override {
903 auto function = getOperation();
904
905 if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
906 return signalPassFailure();
907
908 LLVMConversionTarget target(getContext());
909 RewritePatternSet patterns(&getContext());
910 LLVMTypeConverter converter(&getContext());
911 configureArmSMEToLLVMConversionLegality(target);
912 populateArmSMEToLLVMConversionPatterns(converter, patterns);
913
914 if (failed(applyPartialConversion(function, target, std::move(patterns))))
915 signalPassFailure();
916
917 function->walk(mergeConsecutiveTileZerosInBlock);
918
919 // Walk the function and fail if there are unexpected operations on SME
920 // tile types after conversion.
921 function->walk([&](Operation *op) {
922 // These ops are legal post conversion, skip these.
923 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
924 !op->isRegistered())
925 return;
926 auto isSMETileType = [](Type type) {
927 return arm_sme::isValidSMETileVectorType(type);
928 };
929 if (llvm::any_of(Range: op->getResultTypes(), P: isSMETileType) ||
930 llvm::any_of(Range: op->getOperandTypes(), P: isSMETileType)) {
931 op->emitOpError(message: "unexpected operation with SME tile type after "
932 "conversion to LLVM");
933 signalPassFailure();
934 }
935 });
936 }
937};
938
939} // namespace
940
941void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
942 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
943 target.addLegalOp<
944 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
945 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
946 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
947 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
948 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
949 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
950 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
951 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
952 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
953 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
954 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
955 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
956 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
957 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
958 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
959 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
960 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
961 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
962 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
963 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
964 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
965 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
966 arm_sme::aarch64_sme_cntsd>();
967 target.addLegalDialect<arith::ArithDialect,
968 /* The following are used to lower tile spills/fills */
969 vector::VectorDialect, scf::SCFDialect,
970 memref::MemRefDialect>();
971 // Pseudo operations. These cannot be code-generated but may exist in the
972 // input IR, or be generated during the conversion. They need to be eliminated
973 // before the final conversion to LLVM IR (and likely will be due to DCE).
974 target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
975 UnrealizedConversionCastOp>();
976}
977
978void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
979 RewritePatternSet &patterns) {
980 converter.addConversion(callback: [&](VectorType type) -> std::optional<Type> {
981 // There's no LLVM type for SME tiles, but after lowering to intrinsics all
982 // SME vector types should be eliminated.
983 if (arm_sme::isValidSMETileVectorType(type))
984 return type;
985 return std::nullopt;
986 });
987
988 addArmSMEConversionPatterns<
989 LoadTileSliceConversion, ExtractTileSliceConversion,
990 InsertTileSliceConversion, StoreTileSliceConversion,
991 StreamingVLOpConversion, OuterProductOpConversion,
992 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
993 arm_sme::aarch64_sme_mopa_wide>,
994 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
995 arm_sme::aarch64_sme_mops_wide>,
996 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
997 arm_sme::aarch64_sme_smopa_za32>,
998 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
999 arm_sme::aarch64_sme_smops_za32>,
1000 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1001 arm_sme::aarch64_sme_umopa_za32>,
1002 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1003 arm_sme::aarch64_sme_umops_za32>,
1004 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1005 arm_sme::aarch64_sme_smopa_wide>,
1006 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1007 arm_sme::aarch64_sme_smops_wide>,
1008 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1009 arm_sme::aarch64_sme_umopa_wide>,
1010 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1011 arm_sme::aarch64_sme_umops_wide>,
1012 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1013 arm_sme::aarch64_sme_sumopa_wide>,
1014 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1015 arm_sme::aarch64_sme_sumops_wide>,
1016 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1017 arm_sme::aarch64_sme_usmopa_wide>,
1018 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1019 arm_sme::aarch64_sme_usmops_wide>,
1020 ZeroOpConversion>(patterns, converter);
1021}
1022
1023std::unique_ptr<Pass>
1024mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1025 return std::make_unique<ConvertArmSMEToLLVMPass>(args&: dumpTileLiveRanges);
1026}
1027
