//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
}

/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
}

IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  auto tileId = op.getTileId();
  if (!tileId)
    op.emitOpError(
        "expected tile ID to be allocated before conversion to LLVM");
  return tileId;
}

/// Creates an alloca matching the size of the tile used by `tileOp`. The
/// alloca is placed in the first block of the function.
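///
/// A minimal sketch of the IR this generates for an f32 tile (both dynamic
/// sizes are `vscale * 4`, i.e. the minimum number of 32-bit elements in a
/// tile slice, scaled by the runtime vector length multiple):
///
///   %vscale = vector.vscale
///   %c4 = arith.constant 4 : index
///   %dim = arith.muli %vscale, %c4 : index
///   %spill = memref.alloca(%dim, %dim) : memref<?x?xf32>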
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard g(rewriter);
  // Move to the first operation in the function.
  rewriter.setInsertionPointToStart(&func.getBlocks().front());
  // Create an alloca matching the tile size of the `tileOp`.
  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
  auto tileElementType = tileOp.getTileType().getElementType();
  auto memrefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
  auto minElementsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
  auto alloca = rewriter.create<memref::AllocaOp>(
      loc, memrefType, ValueRange{vectorLen, vectorLen});
  return alloca;
}

/// Finds or creates an alloca for a spill of a tile.
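/// Reuse is keyed off the 'arm_sme.in_memory_tile_id' attribute, so repeated
/// spills of the same in-memory tile share a single slot. The tagged alloca
/// looks like (a sketch, for in-memory tile ID 16):
///
///   %spill = memref.alloca(%dim, %dim)
///     {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>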
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  // Find an alloca at the top of the function tagged with a
  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
  for (auto &op : func.getBlocks().front()) {
    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
    if (!alloca)
      continue;
    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
    if (!inMemoryTileId)
      continue;
    if (inMemoryTileId.getInt() == tileId)
      return alloca;
  }
  // Otherwise, create a new alloca:
  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
                             rewriter.getI32IntegerAttr(tileId));
  return alloca;
}

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///   // Note: <IN MEMORY TILE> = tile ID >= 16.
///   arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///   // At function entry:
///   %spill = memref.alloca ... : memref<?x?xty>
///
///   // Around op:
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///   arm_sme.tile_op { tile_id = 0 }
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register, and the need to swap its contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,      // Omit store after tile operation.
///     TileWrite,     // Omit load before tile operation.
///     TileReadWrite, // Needs both tile load and store.
///     SliceRead,     // Spill single slice and omit store after operation.
///     SliceWrite,    // Spill single slice and omit load before operation.
///     SliceReadWrite // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    return getStridedElementPtr(
        static_cast<ConversionPatternRewriter &>(rewriter), loc,
        llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
        {sliceIndexI64, zero});
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
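  ///
  /// The swap is three steps: read the live slice out of ZA
  /// ('arm_sme.intr.read.horiz'), fill that ZA slice from `tileAlloca`
  /// ('arm_sme.intr.ld1*.horiz'), then store the value read in the first step
  /// back to `tileAlloca` ('vector.store').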
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA.
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};

enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
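///
/// For example, `StreamingVLOpConversion` below opts out with
/// `RequiresSpillsAndFills::No`, since 'arm_sme.streaming_vl' only queries the
/// streaming vector length and never touches a ZA tile.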
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};

template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
///   %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
///   "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///   %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for the tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the
    // tile ID is the index of the tile. For example, looking at the 32-bit
    // ZAx.S tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
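    //
    // E.g. for the ZAS tile 0 in the BEFORE/AFTER example above:
    // 0b0001'0001 << 0 = 17, which is the `tile_mask` on the emitted
    // 'arm_sme.intr.zero'.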
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    Value ptr = this->getStridedElementPtr(
        rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices());

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // The op's predicate mask.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result; replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store ZA tile
    // slice.
    Value ptr = this->getStridedElementPtr(
        rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices());

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};

/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result; replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};

/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outer product intrinsics have no result; replace
    // 'arm_sme.outerproduct' with the accumulator to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};

/// Lower 2-way and 4-way widening outer products to intrinsics.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result; replace the widening outer
    // product op with the accumulator to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};

/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
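///
/// For example (a sketch; the masks are or-ed exactly as in the loop below):
///
///   "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///   "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
///
/// becomes:
///
///   "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()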
static void mergeConsecutiveTileZerosInBlock(Block *block) {
  uint32_t mergedZeroMask = 0;
  SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
  auto replaceMergedZeroOps = [&] {
    auto cleanup = llvm::make_scope_exit([&] {
      mergedZeroMask = 0;
      zeroOpsToMerge.clear();
    });
    if (zeroOpsToMerge.size() <= 1)
      return;
    IRRewriter rewriter(zeroOpsToMerge.front());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        zeroOpsToMerge.front().getLoc(),
        rewriter.getI32IntegerAttr(mergedZeroMask));
    for (auto zeroOp : zeroOpsToMerge)
      rewriter.eraseOp(zeroOp);
  };
  for (Operation &op : *block) {
    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
      mergedZeroMask |= zeroOp.getTileMask();
      zeroOpsToMerge.push_back(zeroOp);
    } else {
      replaceMergedZeroOps();
    }
  }
  replaceMergedZeroOps();
}

} // namespace

namespace {

struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these.
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered())
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};

} // namespace

void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be eliminated
  // before the final conversion to LLVM IR (and likely will be due to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}

void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}

std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}