1 | //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements lowering of ArmSME operations to LLVM intrinsics. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" |
14 | |
15 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
16 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
18 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
19 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
22 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/Pass/Pass.h" |
25 | #include "mlir/Transforms/DialectConversion.h" |
26 | |
27 | namespace mlir { |
28 | #define GEN_PASS_DEF_CONVERTARMSMETOLLVM |
29 | #include "mlir/Conversion/Passes.h.inc" |
30 | } // namespace mlir |
31 | |
32 | using namespace mlir; |
33 | |
34 | namespace { |
35 | |
36 | static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id" ); |
37 | |
38 | /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. |
39 | static Operation *createLoadTileSliceIntrinsic( |
40 | RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, |
41 | arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, |
42 | IntegerAttr tileId, Value tileSliceI32) { |
43 | if (layout == arm_sme::TileSliceLayout::Horizontal) { |
44 | switch (type) { |
45 | case arm_sme::ArmSMETileType::ZAB: |
46 | return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( |
47 | loc, maskOp, ptr, tileId, tileSliceI32); |
48 | case arm_sme::ArmSMETileType::ZAH: |
49 | return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( |
50 | loc, maskOp, ptr, tileId, tileSliceI32); |
51 | case arm_sme::ArmSMETileType::ZAS: |
52 | return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( |
53 | loc, maskOp, ptr, tileId, tileSliceI32); |
54 | case arm_sme::ArmSMETileType::ZAD: |
55 | return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( |
56 | loc, maskOp, ptr, tileId, tileSliceI32); |
57 | case arm_sme::ArmSMETileType::ZAQ: |
58 | return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( |
59 | loc, maskOp, ptr, tileId, tileSliceI32); |
60 | } |
61 | } else { |
62 | switch (type) { |
63 | case arm_sme::ArmSMETileType::ZAB: |
64 | return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( |
65 | loc, maskOp, ptr, tileId, tileSliceI32); |
66 | case arm_sme::ArmSMETileType::ZAH: |
67 | return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( |
68 | loc, maskOp, ptr, tileId, tileSliceI32); |
69 | case arm_sme::ArmSMETileType::ZAS: |
70 | return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( |
71 | loc, maskOp, ptr, tileId, tileSliceI32); |
72 | case arm_sme::ArmSMETileType::ZAD: |
73 | return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( |
74 | loc, maskOp, ptr, tileId, tileSliceI32); |
75 | case arm_sme::ArmSMETileType::ZAQ: |
76 | return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( |
77 | loc, maskOp, ptr, tileId, tileSliceI32); |
78 | break; |
79 | } |
80 | } |
81 | } |
82 | |
83 | /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic. |
84 | static Operation *createStoreTileSliceIntrinsic( |
85 | RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, |
86 | arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, |
87 | IntegerAttr tileId, Value tileSliceI32) { |
88 | if (layout == arm_sme::TileSliceLayout::Horizontal) { |
89 | switch (type) { |
90 | case arm_sme::ArmSMETileType::ZAB: |
91 | return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>( |
92 | loc, maskOp, ptr, tileId, tileSliceI32); |
93 | case arm_sme::ArmSMETileType::ZAH: |
94 | return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>( |
95 | loc, maskOp, ptr, tileId, tileSliceI32); |
96 | case arm_sme::ArmSMETileType::ZAS: |
97 | return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>( |
98 | loc, maskOp, ptr, tileId, tileSliceI32); |
99 | case arm_sme::ArmSMETileType::ZAD: |
100 | return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>( |
101 | loc, maskOp, ptr, tileId, tileSliceI32); |
102 | case arm_sme::ArmSMETileType::ZAQ: |
103 | return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>( |
104 | loc, maskOp, ptr, tileId, tileSliceI32); |
105 | } |
106 | } else { |
107 | switch (type) { |
108 | case arm_sme::ArmSMETileType::ZAB: |
109 | return rewriter.create<arm_sme::aarch64_sme_st1b_vert>( |
110 | loc, maskOp, ptr, tileId, tileSliceI32); |
111 | case arm_sme::ArmSMETileType::ZAH: |
112 | return rewriter.create<arm_sme::aarch64_sme_st1h_vert>( |
113 | loc, maskOp, ptr, tileId, tileSliceI32); |
114 | case arm_sme::ArmSMETileType::ZAS: |
115 | return rewriter.create<arm_sme::aarch64_sme_st1w_vert>( |
116 | loc, maskOp, ptr, tileId, tileSliceI32); |
117 | case arm_sme::ArmSMETileType::ZAD: |
118 | return rewriter.create<arm_sme::aarch64_sme_st1d_vert>( |
119 | loc, maskOp, ptr, tileId, tileSliceI32); |
120 | case arm_sme::ArmSMETileType::ZAQ: |
121 | return rewriter.create<arm_sme::aarch64_sme_st1q_vert>( |
122 | loc, maskOp, ptr, tileId, tileSliceI32); |
123 | } |
124 | } |
125 | } |
126 | |
127 | IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { |
128 | auto tileId = op.getTileId(); |
129 | if (!tileId) |
130 | op.emitOpError( |
131 | "expected tile ID to be allocated before conversion to LLVM" ); |
132 | return tileId; |
133 | } |
134 | |
135 | /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is |
136 | /// placed in the first block of the function. |
137 | static memref::AllocaOp |
138 | createAllocaForTile(RewriterBase &rewriter, Location loc, |
139 | FunctionOpInterface func, |
140 | arm_sme::ArmSMETileOpInterface tileOp) { |
141 | RewriterBase::InsertionGuard g(rewriter); |
142 | // Move to the first operation in the function. |
143 | rewriter.setInsertionPointToStart(&func.getBlocks().front()); |
144 | // Create an alloca matching the tile size of the `tileOp`. |
145 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
146 | auto tileElementType = tileOp.getTileType().getElementType(); |
147 | auto memrefType = MemRefType::get( |
148 | {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); |
149 | unsigned minElements = arm_sme::getSMETileSliceMinNumElts(type: tileElementType); |
150 | auto minElementsOp = |
151 | rewriter.create<arith::ConstantIndexOp>(loc, minElements); |
152 | auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp); |
153 | auto alloca = rewriter.create<memref::AllocaOp>( |
154 | loc, memrefType, ValueRange{vectorLen, vectorLen}); |
155 | return alloca; |
156 | } |
157 | |
158 | /// Finds or creates an alloca for a spill of a tile. |
159 | static memref::AllocaOp getOrCreateAllocaForTile( |
160 | RewriterBase &rewriter, Location loc, FunctionOpInterface func, |
161 | arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { |
162 | // Find an alloca at the top of the function tagged with a |
163 | // 'arm_sme.in_memory_tile_id' that matches `tileId`. |
164 | for (auto &op : func.getBlocks().front()) { |
165 | auto alloca = llvm::dyn_cast<memref::AllocaOp>(op); |
166 | if (!alloca) |
167 | continue; |
168 | auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>( |
169 | alloca->getDiscardableAttr(kInMemoryTileIdAttr)); |
170 | if (!inMemoryTileId) |
171 | continue; |
172 | if (inMemoryTileId.getInt() == tileId) |
173 | return alloca; |
174 | } |
175 | // Otherwise, create a new alloca: |
176 | auto alloca = createAllocaForTile(rewriter, loc, func, tileOp); |
177 | alloca->setDiscardableAttr(kInMemoryTileIdAttr, |
178 | rewriter.getI32IntegerAttr(tileId)); |
179 | return alloca; |
180 | } |
181 | |
182 | /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a |
183 | /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning |
184 | /// the op to tile 0, then emitting a full tile swap between ZA and memory |
185 | /// before + after the tile op. |
186 | /// |
187 | /// Example: |
188 | /// |
189 | /// // Note: <IN MEMORY TILE> = tile ID >= 16. |
190 | /// arm_sme.tile_op { tile_id = <IN MEMORY TILE> } |
191 | /// |
192 | /// is converted to: |
193 | /// // At function entry: |
194 | /// %spill = memref.alloca ... : memref<?x?xty> |
195 | /// |
196 | /// // Around op: |
197 | /// scf.for %slice_idx { |
198 | /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> |
199 | /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> |
200 | /// vector.store %slice_to_save, %spill[%slice_idx, %c0] |
201 | /// } |
202 | /// arm_sme.tile_op { tile_id = 0 } |
203 | /// scf.for %slice_idx { |
204 | /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> |
205 | /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> |
206 | /// vector.store %slice_to_save, %spill[%slice_idx, %c0] |
207 | /// } |
208 | /// |
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register (and the need to swap its contents) can't really be represented
/// correctly at a high level in MLIR.
212 | /// |
213 | /// TODO: Reduce the spills/reloads to single slices where possible (and omit |
214 | /// redundant reloads). This could be done via a method on the |
215 | /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.: |
216 | /// |
217 | /// `tileOp.getZaUsage()` could return: |
218 | /// |
219 | /// struct ArmSMEOpZAUsage { |
220 | /// enum class Kind { |
221 | /// TileRead, // Omit store after tile operation. |
222 | /// TileWrite, // Omit load before tile operation. |
223 | /// TileReadWrite, // Needs both tile load and store. |
224 | /// SliceRead, // Spill single slice and omit store after operation. |
225 | /// SliceWrite, // Spill single slice and omit load before operation. |
226 | /// SliceReadWrite // Spill single slice. |
227 | /// }; |
228 | /// Value sliceIndex {}; |
229 | /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal }; |
230 | /// }; |
231 | /// |
232 | struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { |
233 | |
234 | ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, |
235 | const LLVMTypeConverter &typeConverter, |
236 | PatternBenefit benefit) |
237 | : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(), |
238 | typeConverter, benefit) {} |
239 | |
240 | LogicalResult |
241 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
242 | ConversionPatternRewriter &rewriter) const override { |
243 | auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op); |
244 | // Tile has a real (hardware) tile. No spills/reloads required. |
245 | if (!tileOp.isInMemoryTile()) |
246 | return failure(); |
247 | |
248 | // Step 1. Create an alloca for the tile at the top of the function (if one |
249 | // does not already exist). |
250 | auto loc = tileOp.getLoc(); |
251 | auto func = tileOp->getParentOfType<FunctionOpInterface>(); |
252 | auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp, |
253 | tileOp.getTileId().getInt()); |
254 | |
255 | // Step 2. Assign the op a real tile ID. |
256 | // For simplicity, we always use tile 0 (which always exists). |
257 | auto zeroTileId = rewriter.getI32IntegerAttr(0); |
258 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); |
259 | |
260 | VectorType tileVectorType = tileOp.getTileType(); |
261 | auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); |
262 | auto swapInMemoryTileWithSMETileZero = [&] { |
263 | emitFullTileSwap(rewriter, loc, tileAlloca, |
264 | *arm_sme::getSMETileType(tileVectorType), sliceType, |
265 | zeroTileId); |
266 | }; |
267 | |
268 | // Step 3. Emit tile swaps before and after the op. |
269 | // TODO: Reduce the amount spilled to the amount of data the `tileOp` |
270 | // touches (i.e. a single tile slice). |
271 | { |
272 | rewriter.setInsertionPoint(op); |
273 | // Swap the contents of ZA and the in-memory tile before the op. |
274 | swapInMemoryTileWithSMETileZero(); |
275 | rewriter.setInsertionPointAfter(op); |
276 | // Swap the tile back out to memory again after the op. |
277 | swapInMemoryTileWithSMETileZero(); |
278 | } |
279 | |
280 | return success(); |
281 | } |
282 | |
283 | /// Extracts a pointer to a slice of an in-memory tile. |
284 | Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc, |
285 | Value tileMemory, Value sliceIndex) const { |
286 | auto llvmType = getTypeConverter()->convertType(t: tileMemory.getType()); |
287 | auto descriptor = |
288 | rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory); |
289 | auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64); |
290 | auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>( |
291 | loc, rewriter.getI64Type(), sliceIndex); |
292 | return getStridedElementPtr( |
293 | loc, type: llvm::cast<MemRefType>(tileMemory.getType()), |
294 | memRefDesc: descriptor.getResult(0), indices: {sliceIndexI64, zero}, |
295 | rewriter&: static_cast<ConversionPatternRewriter &>(rewriter)); |
296 | } |
297 | |
298 | /// Emits an in-place swap of a slice of a tile in ZA and a slice of a |
299 | /// tile-sized memref (`tileAlloca`). |
300 | void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, |
301 | arm_sme::ArmSMETileType tileType, VectorType sliceType, |
302 | IntegerAttr tileId, Value sliceIndex) const { |
303 | // Cast the slice index to an i32. |
304 | auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( |
305 | loc, rewriter.getI32Type(), sliceIndex); |
306 | // Create an all-true predicate for the slice. |
307 | auto predicateType = sliceType.clone(rewriter.getI1Type()); |
308 | auto allTruePredicate = rewriter.create<arith::ConstantOp>( |
309 | loc, DenseElementsAttr::get(predicateType, true)); |
310 | // Create padding vector (never used due to all-true predicate). |
311 | auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType); |
312 | // Get a pointer to the current slice. |
313 | auto slicePtr = |
314 | getInMemoryTileSlicePtr(rewriter, loc, tileMemory: tileAlloca, sliceIndex); |
315 | // Read the value of the current slice from ZA. |
316 | auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>( |
317 | loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); |
318 | // Load the new tile slice back from memory into ZA. |
319 | createLoadTileSliceIntrinsic( |
320 | rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, |
321 | allTruePredicate, slicePtr, tileId, sliceIndexI32); |
322 | // Store the current tile slice to memory. |
323 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
324 | rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca, |
325 | ValueRange{sliceIndex, zero}); |
326 | } |
327 | |
328 | /// Emits a full in-place swap of the contents of a tile in ZA and a |
329 | /// tile-sized memref (`tileAlloca`). |
330 | void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, |
331 | arm_sme::ArmSMETileType tileType, VectorType sliceType, |
332 | IntegerAttr tileId) const { |
333 | RewriterBase::InsertionGuard guard(rewriter); |
334 | // Create an scf.for over all tile slices. |
335 | auto minNumElts = |
336 | rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0)); |
337 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
338 | auto upperBound = rewriter.create<arith::MulIOp>( |
339 | loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc)); |
340 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
341 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
342 | // Emit a swap for each tile slice. |
343 | rewriter.setInsertionPointToStart(forOp.getBody()); |
344 | auto sliceIndex = forOp.getInductionVar(); |
345 | emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId, |
346 | sliceIndex); |
347 | } |
348 | }; |
349 | |
/// Whether a conversion pattern needs spill/fill handling registered for ops
/// that were assigned in-memory tile IDs (see `ConvertArmSMEOpToLLVMPattern`).
enum class RequiresSpillsAndFills { Yes, No };
351 | |
/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  // Exposed so `addArmSMEConversionPattern` can inspect the source op type.
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns true if spill/fill conversions should be registered alongside
  /// this pattern.
  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
366 | |
367 | template <typename Pattern> |
368 | static void addArmSMEConversionPattern(RewritePatternSet &patterns, |
369 | LLVMTypeConverter const &typeConverter) { |
370 | // Register spills/fills for ops that implement the |
371 | // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to |
372 | // `RequiresSpillsAndFills::Yes`. |
373 | if constexpr (Pattern::requiresSpillsAndFillsConversion() && |
374 | std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait< |
375 | typename Pattern::ArmSMEOp>, |
376 | typename Pattern::ArmSMEOp>) { |
377 | // Add spill/fill conversions with a very high benefit to ensure |
378 | // they are lowered first. |
379 | patterns.add<ConvertArmSMESpillsAndFillsToLLVM>( |
380 | Pattern::ArmSMEOp::getOperationName(), typeConverter, |
381 | /*benefit=*/1337); |
382 | } |
383 | patterns.add<Pattern>(typeConverter); |
384 | } |
385 | |
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  // Fold expression: registers each pattern in order (along with its
  // spill/fill companion pattern, where required).
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
393 | |
394 | struct GetTileConversion |
395 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp, |
396 | RequiresSpillsAndFills::No> { |
397 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
398 | |
399 | LogicalResult |
400 | matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor, |
401 | ConversionPatternRewriter &rewriter) const override { |
402 | rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>( |
403 | getTile, getTile.getTileType()); |
404 | return success(); |
405 | } |
406 | }; |
407 | |
408 | /// Lower 'arm_sme.zero' to SME intrinsics. |
409 | /// |
410 | /// BEFORE: |
411 | /// ```mlir |
412 | /// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32> |
413 | /// ``` |
414 | /// |
415 | /// AFTER: |
416 | /// ```mlir |
417 | /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> () |
418 | /// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32> |
419 | /// ``` |
420 | /// |
421 | /// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away |
422 | /// once all ArmSME ops have been converted to LLVM intrinsics. |
423 | struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { |
424 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
425 | |
426 | LogicalResult |
427 | matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor, |
428 | ConversionPatternRewriter &rewriter) const override { |
429 | auto loc = zero.getLoc(); |
430 | |
431 | auto tileId = getTileIdOrError(zero); |
432 | if (!tileId) |
433 | return failure(); |
434 | |
435 | // Get the base mask for tile based on the element size. |
436 | // The base mask is just the mask to zero the first tile (of a size). |
437 | // These masks are derived from: |
438 | // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles- |
439 | arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType(); |
440 | auto baseMaskForSize = [&] { |
441 | switch (tileType) { |
442 | case arm_sme::ArmSMETileType::ZAB: |
443 | // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight |
444 | // 64-bit element tiles named ZA0.D to ZA7.D. |
445 | return 0b1111'1111; |
446 | case arm_sme::ArmSMETileType::ZAH: |
447 | // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit |
448 | // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left |
449 | // once for ZA1.H. |
450 | return 0b0101'0101; |
451 | case arm_sme::ArmSMETileType::ZAS: |
452 | // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit |
453 | // element tiles named ZA0.D and ZA4.D. |
454 | // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S. |
455 | return 0b0001'0001; |
456 | case arm_sme::ArmSMETileType::ZAD: |
457 | // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires |
458 | // setting the bit for that tile. |
459 | return 0b0000'0001; |
460 | default: |
461 | llvm_unreachable("bad element size" ); |
462 | } |
463 | }(); |
464 | |
465 | // The actual mask is just the base mask shifted by the tile ID. |
466 | // This will be folded to a constant after tile allocation. |
467 | // |
468 | // The shift is just derived from the layout of the tiles, and that the tile |
469 | // ID is the index of the tile. For example, looking at the 32-bit ZAx.S |
470 | // tiles: |
471 | // |
472 | // ZA0.S = ZA0.D and ZA4.D |
473 | // * Tile ID -> 0 |
474 | // * Mask -> 00010001 = (00010001 << 0) |
475 | // ZA1.S = ZA1.D and ZA5.D |
476 | // * Tile ID -> 1 |
477 | // * Mask -> 00100010 = (00010001 << 1) |
478 | // ZA2.S = ZA2.D and ZA6.D |
479 | // * Tile ID -> 2 |
480 | // * Mask -> 01000100 = (00010001 << 2) |
481 | // ZA3.S = ZA3.D and ZA7.D |
482 | // * Tile ID -> 3 |
483 | // * Mask -> 10001000 = (00010001 << 3) |
484 | // |
485 | // This holds for all tile sizes. |
486 | int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); |
487 | rewriter.create<arm_sme::aarch64_sme_zero>( |
488 | loc, rewriter.getI32IntegerAttr(zeroMask)); |
489 | |
490 | // Create a placeholder op to preserve dataflow. |
491 | rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>( |
492 | zero, zero.getVectorType()); |
493 | |
494 | return success(); |
495 | } |
496 | }; |
497 | |
498 | /// Lower `arm_sme.load_tile_slice` to SME intrinsics. |
499 | struct LoadTileSliceConversion |
500 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> { |
501 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
502 | |
503 | LogicalResult |
504 | matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp, |
505 | arm_sme::LoadTileSliceOp::Adaptor adaptor, |
506 | ConversionPatternRewriter &rewriter) const override { |
507 | auto loc = loadTileSliceOp.getLoc(); |
508 | auto tileId = getTileIdOrError(loadTileSliceOp); |
509 | if (!tileId) |
510 | return failure(); |
511 | |
512 | Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(), |
513 | adaptor.getBase(), |
514 | adaptor.getIndices(), rewriter); |
515 | |
516 | auto tileSlice = loadTileSliceOp.getTileSliceIndex(); |
517 | |
518 | // Cast tile slice to i32 for intrinsic. |
519 | auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( |
520 | loc, rewriter.getI32Type(), tileSlice); |
521 | |
522 | // Create all active predicate mask. |
523 | auto maskOp = loadTileSliceOp.getMask(); |
524 | |
525 | auto tileVectorType = loadTileSliceOp.getVectorType(); |
526 | arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); |
527 | arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout(); |
528 | |
529 | // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice. |
530 | createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr, |
531 | tileId, tileSliceI32); |
532 | |
533 | // The load intrinsics have no result, replace 'arm_sme.tile_load' with |
534 | // the input tile to preserve dataflow. |
535 | rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile()); |
536 | |
537 | return success(); |
538 | } |
539 | }; |
540 | |
541 | /// Lower for `arm_sme.store_tile_slice` to SME intrinsics. |
542 | struct StoreTileSliceConversion |
543 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> { |
544 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
545 | |
546 | LogicalResult |
547 | matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp, |
548 | arm_sme::StoreTileSliceOp::Adaptor adaptor, |
549 | ConversionPatternRewriter &rewriter) const override { |
550 | auto loc = storeTileSliceOp.getLoc(); |
551 | auto tileVectorType = storeTileSliceOp.getVectorType(); |
552 | |
553 | auto tileId = getTileIdOrError(storeTileSliceOp); |
554 | if (!tileId) |
555 | return failure(); |
556 | |
557 | // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice. |
558 | Value ptr = this->getStridedElementPtr( |
559 | loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(), |
560 | adaptor.getIndices(), rewriter); |
561 | |
562 | auto tileSlice = storeTileSliceOp.getTileSliceIndex(); |
563 | |
564 | // Cast tile slice to i32 for intrinsic. |
565 | auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( |
566 | loc, rewriter.getI32Type(), tileSlice); |
567 | |
568 | auto maskOp = storeTileSliceOp.getMask(); |
569 | |
570 | arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout(); |
571 | arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); |
572 | |
573 | rewriter.replaceOp(storeTileSliceOp, |
574 | createStoreTileSliceIntrinsic(rewriter, loc, tileType, |
575 | layout, maskOp, ptr, |
576 | tileId, tileSliceI32)); |
577 | |
578 | return success(); |
579 | } |
580 | }; |
581 | |
582 | /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. |
583 | struct MoveVectorToTileSliceConversion |
584 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> { |
585 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
586 | |
587 | LogicalResult |
588 | matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp, |
589 | arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor, |
590 | ConversionPatternRewriter &rewriter) const override { |
591 | auto loc = moveVectorToTileSliceOp.getLoc(); |
592 | auto tileType = moveVectorToTileSliceOp.getTileType(); |
593 | |
594 | auto tileId = getTileIdOrError(moveVectorToTileSliceOp); |
595 | if (!tileId) |
596 | return failure(); |
597 | |
598 | auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex(); |
599 | |
600 | // Cast tile slice from index to i32 for intrinsic. |
601 | auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( |
602 | loc, rewriter.getI32Type(), tileSlice); |
603 | |
604 | // Create all active predicate mask. |
605 | auto one = rewriter.create<arith::ConstantOp>( |
606 | loc, rewriter.getI1Type(), |
607 | rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); |
608 | auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), |
609 | /*scalableDims=*/{true}); |
610 | auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); |
611 | |
612 | // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. |
613 | switch (moveVectorToTileSliceOp.getLayout()) { |
614 | case arm_sme::TileSliceLayout::Horizontal: |
615 | rewriter.create<arm_sme::aarch64_sme_write_horiz>( |
616 | loc, tileId, tileSliceI32, allActiveMask, |
617 | moveVectorToTileSliceOp.getVector()); |
618 | break; |
619 | case arm_sme::TileSliceLayout::Vertical: |
620 | rewriter.create<arm_sme::aarch64_sme_write_vert>( |
621 | loc, tileId, tileSliceI32, allActiveMask, |
622 | moveVectorToTileSliceOp.getVector()); |
623 | break; |
624 | } |
625 | |
626 | // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with |
627 | // the input tile to preserve dataflow. |
628 | rewriter.replaceOp(moveVectorToTileSliceOp, |
629 | moveVectorToTileSliceOp.getTile()); |
630 | |
631 | return success(); |
632 | } |
633 | }; |
634 | |
635 | /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. |
636 | struct MoveTileSliceToVectorConversion |
637 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> { |
638 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
639 | |
640 | LogicalResult |
641 | matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector, |
642 | OpAdaptor, |
643 | ConversionPatternRewriter &rewriter) const override { |
644 | auto loc = moveTileSliceToVector.getLoc(); |
645 | auto sliceType = moveTileSliceToVector.getSliceType(); |
646 | auto sliceIndex = moveTileSliceToVector.getTileSliceIndex(); |
647 | |
648 | auto tileId = getTileIdOrError(moveTileSliceToVector); |
649 | if (!tileId) |
650 | return failure(); |
651 | |
652 | // Create an 'all true' predicate for the tile slice. |
653 | auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); |
654 | auto allTruePredicate = rewriter.create<arith::ConstantOp>( |
655 | loc, DenseElementsAttr::get(predicateType, true)); |
656 | |
657 | // Zero destination/fallback for tile slice extraction. |
658 | auto zeroVector = rewriter.create<arith::ConstantOp>( |
659 | loc, sliceType, rewriter.getZeroAttr(sliceType)); |
660 | |
661 | // Cast tile slice from index to i32 for intrinsic. |
662 | auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( |
663 | loc, rewriter.getI32Type(), sliceIndex); |
664 | |
665 | // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. |
666 | switch (moveTileSliceToVector.getLayout()) { |
667 | case arm_sme::TileSliceLayout::Horizontal: |
668 | rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>( |
669 | moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, |
670 | tileId, sliceIndexI32); |
671 | break; |
672 | case arm_sme::TileSliceLayout::Vertical: |
673 | rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>( |
674 | moveTileSliceToVector, sliceType, zeroVector, allTruePredicate, |
675 | tileId, sliceIndexI32); |
676 | break; |
677 | } |
678 | |
679 | return success(); |
680 | } |
681 | }; |
682 | |
683 | /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics. |
684 | /// |
685 | /// Example: |
686 | /// |
687 | /// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) |
688 | /// : vector<[4]xf32>, vector<[4]xf32> |
689 | /// |
690 | /// is converted to: |
691 | /// |
692 | /// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}> |
693 | /// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, |
694 | /// vector<[4]xf32>) -> () |
695 | /// |
696 | /// Currently only supports FMOPA and BFMOPA (non-widening). |
697 | struct OuterProductOpConversion |
698 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> { |
699 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
700 | |
701 | LogicalResult |
702 | matchAndRewrite(arm_sme::OuterProductOp outerProductOp, |
703 | arm_sme::OuterProductOp::Adaptor adaptor, |
704 | ConversionPatternRewriter &rewriter) const override { |
705 | auto tileId = getTileIdOrError(outerProductOp); |
706 | if (!tileId) |
707 | return failure(); |
708 | |
709 | auto isSupportedType = [](VectorType vectorType) { |
710 | // TODO: the FP outer product instruction variants are predicated on |
711 | // different features [1]: |
712 | // |
713 | // * FMOPA (non-widening) |
714 | // * half-precision - +sme2p1,+sme-f16f16 |
715 | // * single-precision - +sme |
716 | // * double-precision - +sme-f64f64 |
717 | // * BFMOPA |
718 | // * half-precision - +sme2p1,+b16b16 |
719 | // |
720 | // It should be possible to control lowering based on target features. |
721 | // [1] |
722 | // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile |
723 | if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable()) |
724 | return false; |
725 | |
726 | auto elementType = vectorType.getElementType(); |
727 | |
728 | if (!elementType.isF16() && !elementType.isBF16() && |
729 | !elementType.isF32() && !elementType.isF64()) |
730 | return false; |
731 | |
732 | unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits / |
733 | vectorType.getElementTypeBitWidth(); |
734 | return vectorType.getShape() == |
735 | ArrayRef<int64_t>({minNumElts, minNumElts}); |
736 | }; |
737 | |
738 | // TODO: Support CombiningKind::Sub for outer products. |
739 | if (outerProductOp.getKind() != arm_sme::CombiningKind::Add) |
740 | return outerProductOp.emitError("unsupported kind" ); |
741 | |
742 | auto resultVectorType = outerProductOp.getResultType(); |
743 | if (!isSupportedType(resultVectorType)) |
744 | return outerProductOp.emitError("unsupported type" ); |
745 | |
746 | auto loc = outerProductOp.getLoc(); |
747 | |
748 | Value acc = outerProductOp.getAcc(); |
749 | if (!acc) |
750 | // Initalize accumulator with zero. |
751 | acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>( |
752 | rewriter, loc, resultVectorType); |
753 | |
754 | Value lhsMask = outerProductOp.getLhsMask(); |
755 | Value rhsMask = outerProductOp.getRhsMask(); |
756 | |
757 | if (!lhsMask || !rhsMask) { |
758 | auto predTy = |
759 | outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); |
760 | Value allActiveMask = rewriter.create<arith::ConstantOp>( |
761 | loc, DenseElementsAttr::get(predTy, true)); |
762 | lhsMask = allActiveMask; |
763 | rhsMask = allActiveMask; |
764 | } |
765 | |
766 | // Create 'arm_sme.intr.mopa' outer product intrinsic. |
767 | rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask, |
768 | outerProductOp.getLhs(), |
769 | outerProductOp.getRhs()); |
770 | |
771 | // The outerproduct intrinsics have no result, replace |
772 | // 'arm_sme.outerproduct' with the input tile to preserve dataflow. |
773 | rewriter.replaceOp(outerProductOp, acc); |
774 | |
775 | return success(); |
776 | } |
777 | }; |
778 | |
779 | /// Lower 2-way and 4-way widening outer products to intrinsics. |
780 | template <class OuterProductWideningOp, class OuterProductWideningIntrOp> |
781 | struct OuterProductWideningOpConversion |
782 | : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> { |
783 | using ConvertArmSMEOpToLLVMPattern< |
784 | OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern; |
785 | |
786 | LogicalResult |
787 | matchAndRewrite(OuterProductWideningOp op, |
788 | typename OuterProductWideningOp::Adaptor adaptor, |
789 | ConversionPatternRewriter &rewriter) const override { |
790 | auto tileId = getTileIdOrError(op); |
791 | if (!tileId) |
792 | return failure(); |
793 | |
794 | Value acc = op.getAcc(); |
795 | if (!acc) |
796 | // Initalize accumulator with zero. |
797 | acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>( |
798 | rewriter, op.getLoc(), op.getResultType()); |
799 | |
800 | Value lhsMask = op.getLhsMask(); |
801 | Value rhsMask = op.getRhsMask(); |
802 | if (!lhsMask || !rhsMask) { |
803 | auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); |
804 | Value allActiveMask = rewriter.create<arith::ConstantOp>( |
805 | op.getLoc(), DenseElementsAttr::get(predTy, true)); |
806 | lhsMask = allActiveMask; |
807 | rhsMask = allActiveMask; |
808 | } |
809 | |
810 | rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask, |
811 | rhsMask, adaptor.getLhs(), |
812 | adaptor.getRhs()); |
813 | |
814 | // The outerproduct intrinsics have no result, replace |
815 | // 'arm_sme.outerproduct' with the input tile to preserve dataflow. |
816 | rewriter.replaceOp(op, acc); |
817 | |
818 | return success(); |
819 | } |
820 | }; |
821 | |
822 | /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics. |
823 | /// |
824 | /// Example: |
825 | /// |
826 | /// %0 = arm_sme.streaming_vl <half> |
827 | /// |
828 | /// is converted to: |
829 | /// |
830 | /// %cnt = "arm_sme.intr.cntsh"() : () -> i64 |
831 | /// %0 = arith.index_cast %cnt : i64 to index |
832 | /// |
833 | struct StreamingVLOpConversion |
834 | : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp, |
835 | RequiresSpillsAndFills::No> { |
836 | using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; |
837 | |
838 | LogicalResult |
839 | matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp, |
840 | arm_sme::StreamingVLOp::Adaptor adaptor, |
841 | ConversionPatternRewriter &rewriter) const override { |
842 | auto loc = streamingVlOp.getLoc(); |
843 | auto i64Type = rewriter.getI64Type(); |
844 | auto *intrOp = [&]() -> Operation * { |
845 | switch (streamingVlOp.getTypeSize()) { |
846 | case arm_sme::TypeSize::Byte: |
847 | return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type); |
848 | case arm_sme::TypeSize::Half: |
849 | return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type); |
850 | case arm_sme::TypeSize::Word: |
851 | return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type); |
852 | case arm_sme::TypeSize::Double: |
853 | return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type); |
854 | } |
855 | }(); |
856 | rewriter.replaceOpWithNewOp<arith::IndexCastOp>( |
857 | streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0)); |
858 | return success(); |
859 | } |
860 | }; |
861 | |
862 | } // namespace |
863 | |
864 | namespace { |
865 | |
866 | struct ConvertArmSMEToLLVMPass |
867 | : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> { |
868 | void runOnOperation() override { |
869 | LLVMConversionTarget target(getContext()); |
870 | RewritePatternSet patterns(&getContext()); |
871 | LLVMTypeConverter converter(&getContext()); |
872 | configureArmSMEToLLVMConversionLegality(target); |
873 | populateArmSMEToLLVMConversionPatterns(converter, patterns); |
874 | |
875 | if (failed(applyPartialConversion(getOperation(), target, |
876 | std::move(patterns)))) |
877 | signalPassFailure(); |
878 | } |
879 | }; |
880 | |
881 | } // namespace |
882 | |
883 | void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { |
884 | target.addIllegalDialect<arm_sme::ArmSMEDialect>(); |
885 | target.addLegalOp< |
886 | arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero, |
887 | arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz, |
888 | arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz, |
889 | arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz, |
890 | arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, |
891 | arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, |
892 | arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert, |
893 | arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert, |
894 | arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert, |
895 | arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert, |
896 | arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert, |
897 | arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, |
898 | arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, |
899 | arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, |
900 | arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide, |
901 | arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide, |
902 | arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide, |
903 | arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32, |
904 | arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32, |
905 | arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide, |
906 | arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide, |
907 | arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh, |
908 | arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>(); |
909 | target.addLegalDialect<arith::ArithDialect, |
910 | /* The following are used to lower tile spills/fills */ |
911 | vector::VectorDialect, scf::SCFDialect, |
912 | memref::MemRefDialect>(); |
913 | target.addLegalOp<UnrealizedConversionCastOp>(); |
914 | } |
915 | |
916 | void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, |
917 | RewritePatternSet &patterns) { |
918 | converter.addConversion(callback: [&](VectorType type) -> std::optional<Type> { |
919 | // There's no LLVM type for SME tiles, but after lowering to intrinsics all |
920 | // SME vector types should be eliminated. |
921 | if (arm_sme::isValidSMETileVectorType(vType: type)) |
922 | return type; |
923 | return std::nullopt; |
924 | }); |
925 | |
926 | addArmSMEConversionPatterns< |
927 | LoadTileSliceConversion, MoveTileSliceToVectorConversion, |
928 | MoveVectorToTileSliceConversion, StoreTileSliceConversion, |
929 | StreamingVLOpConversion, OuterProductOpConversion, |
930 | OuterProductWideningOpConversion<arm_sme::FMopa2WayOp, |
931 | arm_sme::aarch64_sme_mopa_wide>, |
932 | OuterProductWideningOpConversion<arm_sme::FMops2WayOp, |
933 | arm_sme::aarch64_sme_mops_wide>, |
934 | OuterProductWideningOpConversion<arm_sme::SMopa2WayOp, |
935 | arm_sme::aarch64_sme_smopa_za32>, |
936 | OuterProductWideningOpConversion<arm_sme::SMops2WayOp, |
937 | arm_sme::aarch64_sme_smops_za32>, |
938 | OuterProductWideningOpConversion<arm_sme::UMopa2WayOp, |
939 | arm_sme::aarch64_sme_umopa_za32>, |
940 | OuterProductWideningOpConversion<arm_sme::UMops2WayOp, |
941 | arm_sme::aarch64_sme_umops_za32>, |
942 | OuterProductWideningOpConversion<arm_sme::SMopa4WayOp, |
943 | arm_sme::aarch64_sme_smopa_wide>, |
944 | OuterProductWideningOpConversion<arm_sme::SMops4WayOp, |
945 | arm_sme::aarch64_sme_smops_wide>, |
946 | OuterProductWideningOpConversion<arm_sme::UMopa4WayOp, |
947 | arm_sme::aarch64_sme_umopa_wide>, |
948 | OuterProductWideningOpConversion<arm_sme::UMops4WayOp, |
949 | arm_sme::aarch64_sme_umops_wide>, |
950 | OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp, |
951 | arm_sme::aarch64_sme_sumopa_wide>, |
952 | OuterProductWideningOpConversion<arm_sme::SuMops4WayOp, |
953 | arm_sme::aarch64_sme_sumops_wide>, |
954 | OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp, |
955 | arm_sme::aarch64_sme_usmopa_wide>, |
956 | OuterProductWideningOpConversion<arm_sme::UsMops4WayOp, |
957 | arm_sme::aarch64_sme_usmops_wide>, |
958 | ZeroOpConversion, GetTileConversion>(patterns, converter); |
959 | } |
960 | |
961 | std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() { |
962 | return std::make_unique<ConvertArmSMEToLLVMPass>(); |
963 | } |
964 | |