1 | //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This transform allocates SME tiles at the 'func.func' op level for ArmSME |
10 | // operations. It roughly implements a linear scan register allocator, similar |
11 | // to the one outlined in [1], but with simplifications and assumptions made for |
12 | // our use case. Note that this is a greedy allocator (so it may not always find |
13 | // the most optimal allocation of tiles). |
14 | // |
15 | // The allocator operates at the CF dialect level. It is the responsibility of |
16 | // users to ensure the IR has been lowered to CF before invoking the tile |
17 | // allocator. |
18 | // |
19 | // The 128-bit tiles overlap with other element tiles as follows (see section |
20 | // B2.3.2 of SME spec [2]): |
21 | // |
22 | // Tile Overlaps |
23 | // --------------------------------------------------------------------------- |
24 | // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q, |
25 | // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q |
26 | // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q |
27 | // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q |
28 | // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q |
29 | // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q |
30 | // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q |
31 | // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q |
32 | // ZA0.D ZA0.Q, ZA8.Q |
33 | // ZA1.D ZA1.Q, ZA9.Q |
34 | // ZA2.D ZA2.Q, ZA10.Q |
35 | // ZA3.D ZA3.Q, ZA11.Q |
36 | // ZA4.D ZA4.Q, ZA12.Q |
37 | // ZA5.D ZA5.Q, ZA13.Q |
38 | // ZA6.D ZA6.Q, ZA14.Q |
39 | // ZA7.D ZA7.Q, ZA15.Q |
40 | // |
41 | // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register |
42 | // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer) |
43 | // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf |
44 | // [2] https://developer.arm.com/documentation/ddi0616/aa |
45 | // |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | #include "mlir/Analysis/Liveness.h" |
49 | #include "mlir/Analysis/TopologicalSortUtils.h" |
50 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
51 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
52 | #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" |
53 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
54 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
55 | #include "mlir/Transforms/RegionUtils.h" |
56 | #include "llvm/ADT/IntervalMap.h" |
57 | #include "llvm/ADT/TypeSwitch.h" |
58 | #include <algorithm> |
59 | |
60 | namespace mlir::arm_sme { |
61 | #define GEN_PASS_DEF_TESTTILEALLOCATION |
62 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
63 | } // namespace mlir::arm_sme |
64 | |
65 | using namespace mlir; |
66 | using namespace mlir::arm_sme; |
67 | |
68 | namespace { |
69 | |
/// Bitmask over the 16 128-bit (ZA*.Q) sub-tiles of the ZA array. Bit 15
/// corresponds to ZA0.Q and bit 0 to ZA15.Q, so the mask for a
/// larger-element tile sets the bit of every ZA*.Q tile it overlaps with
/// (see the overlap table at the top of this file). Two tiles are free of
/// conflicts iff the AND of their masks is kNone.
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
113 | |
114 | /// Returns the set of masks relevant for the given type. |
115 | static ArrayRef<TileMask> getMasks(ArmSMETileType type) { |
116 | static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B}; |
117 | static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H}; |
118 | static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S, |
119 | TileMask::kZA2S, TileMask::kZA3S}; |
120 | static constexpr std::array ZA_D_MASKS = { |
121 | TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D, |
122 | TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D}; |
123 | static constexpr std::array ZA_Q_MASKS = { |
124 | TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q, |
125 | TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q, |
126 | TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q, |
127 | TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q}; |
128 | switch (type) { |
129 | case ArmSMETileType::ZAB: |
130 | return ZA_B_MASKS; |
131 | case ArmSMETileType::ZAH: |
132 | return ZA_H_MASKS; |
133 | case ArmSMETileType::ZAS: |
134 | return ZA_S_MASKS; |
135 | case ArmSMETileType::ZAD: |
136 | return ZA_D_MASKS; |
137 | case ArmSMETileType::ZAQ: |
138 | return ZA_Q_MASKS; |
139 | } |
140 | llvm_unreachable("unknown type in getMasks"); |
141 | } |
142 | |
143 | class TileAllocator { |
144 | public: |
145 | /// Allocates and returns a tile ID. Fails if there are no tiles left. |
146 | FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) { |
147 | auto masks = getMasks(tileType); |
148 | for (auto [tileId, tileMask] : llvm::enumerate(masks)) { |
149 | if ((tilesInUse & tileMask) == TileMask::kNone) { |
150 | tilesInUse |= tileMask; |
151 | return tileId; |
152 | } |
153 | } |
154 | return failure(); |
155 | } |
156 | |
157 | /// Acquires a specific tile ID. Asserts the tile is initially free. |
158 | void acquireTileId(ArmSMETileType tileType, unsigned tileId) { |
159 | TileMask tileMask = getMasks(tileType)[tileId]; |
160 | assert((tilesInUse & tileMask) == TileMask::kNone && |
161 | "cannot acquire allocated tile!"); |
162 | tilesInUse |= tileMask; |
163 | } |
164 | |
165 | /// Releases a previously allocated tile ID. |
166 | void releaseTileId(ArmSMETileType tileType, unsigned tileId) { |
167 | TileMask tileMask = getMasks(tileType)[tileId]; |
168 | assert((tilesInUse & tileMask) == tileMask && |
169 | "cannot release unallocated tile!"); |
170 | tilesInUse ^= tileMask; |
171 | } |
172 | |
173 | /// Allocates an in-memory tile ID. |
174 | unsigned allocateInMemoryTileId() { |
175 | // Note: We never release in-memory tile IDs. We could, which may allow |
176 | // reusing an allocation, but as we _never_ want to spill an SME tile this |
177 | // is not optimized. |
178 | return nextInMemoryTileId++; |
179 | } |
180 | |
181 | private: |
182 | TileMask tilesInUse = TileMask::kNone; |
183 | unsigned nextInMemoryTileId = kInMemoryTileIdBase; |
184 | }; |
185 | |
/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  // Collect the branches to rewrite up front; rewriting during the walk
  // would mutate the IR being walked.
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  // Appends an unconditional branch to `dest` (forwarding `args`) at the end
  // of `source`.
  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    // Splitting at `block->end()` creates a new empty block after `block`;
    // doing it twice yields the two intermediate `_copy` blocks.
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    // Each intermediate block jumps to the original destination with the
    // original destination operands.
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    // Retarget the cond_br at the new blocks and drop its (now forwarded)
    // destination operands.
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}
235 | |
236 | /// Inserts tile copies at `cf.br` operations. |
237 | /// |
238 | /// BEFORE: |
239 | /// ```mlir |
240 | /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) |
241 | /// ``` |
242 | /// |
243 | /// AFTER: |
244 | /// ```mlir |
245 | /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32> |
246 | /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>) |
247 | /// ``` |
248 | void insertCopiesAtBranches(IRRewriter &rewriter, |
249 | FunctionOpInterface function) { |
250 | for (Block &block : function.getBlocks()) { |
251 | Operation *terminator = block.getTerminator(); |
252 | if (!isa<cf::BranchOp>(terminator)) |
253 | continue; |
254 | rewriter.setInsertionPoint(terminator); |
255 | for (OpOperand &operand : terminator->getOpOperands()) { |
256 | if (isValidSMETileVectorType(operand.get().getType())) { |
257 | auto copy = |
258 | rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get()); |
259 | rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); }); |
260 | } |
261 | } |
262 | } |
263 | } |
264 | |
/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
///
/// Note: The order matters — splitting first guarantees every tile-carrying
/// branch is a plain `cf.br`, which is the only terminator
/// `insertCopiesAtBranches` handles.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
278 | |
/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    // IntervalMapOverlaps starts at the first overlapping pair of intervals
    // (if any); `valid()` is false iff there is no overlap at all.
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    // lookup() returns the map's default value (0) for points outside every
    // interval, so comparing against kValidLiveRange is a membership test.
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    // An empty interval still registers `value` as part of this range.
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  /// Returns the tile type of this live range. All values in a (coalesced)
  /// live range are expected to share a tile type; only values[0] is checked.
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
342 | |
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME ops have been converted to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      // Tile ops nested in regions would not be numbered (and so could not
      // have live ranges) — reject them in debug builds.
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
367 | |
/// Gather live ranges for SME tiles from the MLIR liveness analysis.
///
/// Returns a map from each SME tile value to its (per-value) live range; the
/// intervals are expressed in the numbering from `operationToIndexMap`.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may update
  /// an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
  /// the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    // For values live at block entry, start one slot earlier: this is the
    // block's own number (the numbering reserves one index per block before
    // its first operation), so the `-1` on an unsigned index cannot wrap.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}
414 | |
415 | /// Iterate over all predecessor tile values to a (tile) block argument. |
416 | static void forEachPredecessorTileValue(BlockArgument blockArg, |
417 | function_ref<void(Value)> callback) { |
418 | Block *block = blockArg.getOwner(); |
419 | unsigned argNumber = blockArg.getArgNumber(); |
420 | for (Block *pred : block->getPredecessors()) { |
421 | TypeSwitch<Operation *>(pred->getTerminator()) |
422 | .Case<cf::BranchOp>([&](auto branch) { |
423 | Value predecessorOperand = branch.getDestOperands()[argNumber]; |
424 | callback(predecessorOperand); |
425 | }) |
426 | .Case<cf::CondBranchOp>([&](auto condBranch) { |
427 | if (condBranch.getFalseDest() == block) { |
428 | Value predecessorOperand = |
429 | condBranch.getFalseDestOperands()[argNumber]; |
430 | callback(predecessorOperand); |
431 | } |
432 | if (condBranch.getTrueDest() == block) { |
433 | Value predecessorOperand = |
434 | condBranch.getTrueDestOperands()[argNumber]; |
435 | callback(predecessorOperand); |
436 | } |
437 | }); |
438 | } |
439 | } |
440 | |
441 | /// Coalesce live ranges where it would prevent unnecessary tile moves. |
442 | SmallVector<LiveRange *> |
443 | coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) { |
444 | DenseMap<Value, LiveRange *> liveRanges; |
445 | for (auto &[value, liveRange] : initialLiveRanges) { |
446 | liveRanges.insert(KV: {value, &liveRange}); |
447 | } |
448 | |
449 | // Merge the live ranges of values `a` and `b` into one (if they do not |
450 | // overlap). After this, the values `a` and `b` will both point to the same |
451 | // live range (which will contain multiple values). |
452 | auto mergeValuesIfNonOverlapping = [&](Value a, Value b) { |
453 | LiveRange *aLiveRange = liveRanges.at(Val: a); |
454 | LiveRange *bLiveRange = liveRanges.at(Val: b); |
455 | if (aLiveRange != bLiveRange && !aLiveRange->overlaps(otherRange: *bLiveRange)) { |
456 | aLiveRange->unionWith(otherRange: *bLiveRange); |
457 | for (Value value : bLiveRange->values) |
458 | liveRanges[value] = aLiveRange; |
459 | } |
460 | }; |
461 | |
462 | // Merge the live ranges of new definitions with their tile operands. |
463 | auto unifyDefinitionsWithOperands = [&](Value value) { |
464 | auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
465 | if (!armSMEOp) |
466 | return; |
467 | for (auto operand : armSMEOp->getOperands()) { |
468 | if (isValidSMETileVectorType(operand.getType())) |
469 | mergeValuesIfNonOverlapping(value, operand); |
470 | } |
471 | }; |
472 | |
473 | // Merge the live ranges of block arguments with their predecessors. |
474 | auto unifyBlockArgumentsWithPredecessors = [&](Value value) { |
475 | auto blockArg = dyn_cast<BlockArgument>(Val&: value); |
476 | if (!blockArg) |
477 | return; |
478 | forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) { |
479 | mergeValuesIfNonOverlapping(blockArg, predecessorTile); |
480 | }); |
481 | }; |
482 | |
483 | auto applyRule = [&](auto rule) { |
484 | llvm::for_each(llvm::make_first_range(c&: initialLiveRanges), rule); |
485 | }; |
486 | |
487 | // Unify as many live ranges as we can. This prevents unnecessary moves. |
488 | applyRule(unifyBlockArgumentsWithPredecessors); |
489 | applyRule(unifyDefinitionsWithOperands); |
490 | |
491 | // Remove duplicate live range entries. |
492 | SetVector<LiveRange *> uniqueLiveRanges; |
493 | for (auto [_, liveRange] : liveRanges) { |
494 | if (!liveRange->empty()) |
495 | uniqueLiveRanges.insert(X: liveRange); |
496 | } |
497 | |
498 | // Sort the new live ranges by starting point (ready for tile allocation). |
499 | auto coalescedLiveRanges = uniqueLiveRanges.takeVector(); |
500 | llvm::sort(C&: coalescedLiveRanges, |
501 | Comp: [](LiveRange *a, LiveRange *b) { return *a < *b; }); |
502 | return std::move(coalescedLiveRanges); |
503 | } |
504 | |
/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
/// Note: assumes `overlappingRanges` is non-empty — allocation only fails when
/// at least one allocated range overlaps the new one (TODO confirm at call
/// sites).
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  // A range qualifies when its (single) defining op can simply be re-created
  // and its tile type is at least as large as the new range's (so the freed
  // tile can actually hold the new range).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  // This comparator orders ranges so that the maximum element is a
  // compatible-tile-type range with the latest end point.
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
536 | |
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
///
/// `liveRangesSortedByStartPoint` must be sorted by increasing start point
/// (as produced by `coalesceTileLiveRanges`). On return every range has a
/// `tileId` set — either a real tile or an in-memory (spill) ID
/// (>= kInMemoryTileIdBase).
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      // (Their tile is released so it can be reused during the 'hole', but
      // the range is remembered in `inactiveRanges`.)
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      // (They keep their previously assigned tile ID, which is re-reserved.)
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      // The spilled range (which may be `nextRange` itself) lives in memory.
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
640 | |
641 | /// Assigns a tile ID to an MLIR value. |
642 | void assignTileIdToValue(IRRewriter &rewriter, Value value, |
643 | IntegerAttr tileIdAttr) { |
644 | if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) |
645 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); |
646 | for (Operation *user : value.getUsers()) { |
647 | if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) { |
648 | // Ensure ArmSME ops that don't produce a value still get a tile ID. |
649 | if (!hasTileResult(tileOp)) |
650 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); |
651 | } |
652 | } |
653 | } |
654 | |
655 | /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts. |
656 | LogicalResult assignTileIdsAndResolveTrivialConflicts( |
657 | IRRewriter &rewriter, FunctionOpInterface function, |
658 | ArrayRef<LiveRange *> allocatedLiveRanges) { |
659 | for (LiveRange const *liveRange : allocatedLiveRanges) { |
660 | auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId); |
661 | auto isAllocatedToSameTile = [&](Value value) { |
662 | if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
663 | tileOp && tileOp.getTileId() == tileIdAttr) |
664 | return true; |
665 | return liveRange->values.contains(key: value); |
666 | }; |
667 | |
668 | /// Eliminates copies where the operand has the same tile ID. |
669 | auto foldRedundantCopies = [&](Value value) -> LogicalResult { |
670 | auto copyOp = value.getDefiningOp<CopyTileOp>(); |
671 | if (!copyOp || !isAllocatedToSameTile(copyOp.getTile())) |
672 | return failure(); |
673 | rewriter.replaceAllUsesWith(copyOp, copyOp.getTile()); |
674 | return success(); |
675 | }; |
676 | |
677 | /// Validates each predecessor to a tile block argument has been assigned |
678 | /// the same tile ID. |
679 | auto validateBlockArguments = [&](Value value) { |
680 | auto blockArg = dyn_cast<BlockArgument>(Val&: value); |
681 | if (!blockArg) { |
682 | // Not a block argument (nothing to validate). |
683 | return success(); |
684 | } |
685 | bool tileMismatch = false; |
686 | forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) { |
687 | if (tileMismatch) |
688 | return; |
689 | if (!isAllocatedToSameTile(predecessorTile)) { |
690 | blockArg.getOwner()->getParentOp()->emitOpError( |
691 | message: "block argument not allocated to the same SME virtial tile as " |
692 | "predecessors"); |
693 | tileMismatch = true; |
694 | } |
695 | }); |
696 | return success(/*isSuccess=*/IsSuccess: !tileMismatch); |
697 | }; |
698 | |
699 | /// Attempts to resolve (trivial) tile ID conflicts. |
700 | auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult { |
701 | auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
702 | OpOperand *tileOperand = getTileOpOperand(tileOp); |
703 | if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) { |
704 | // Operand already allocated to the correct tile. |
705 | // No conflict to resolve. |
706 | return success(); |
707 | } |
708 | auto operandTileOp = |
709 | tileOperand->get().getDefiningOp<ArmSMETileOpInterface>(); |
710 | if (!isTriviallyCloneableTileOp(operandTileOp)) { |
711 | auto error = |
712 | tileOp.emitOpError("tile operand allocated to different SME " |
713 | "virtial tile (move required)"); |
714 | error.attachNote(tileOperand->get().getLoc()) |
715 | << "tile operand is: "<< tileOperand->get(); |
716 | return error; |
717 | } |
718 | // Cloning prevents a move/spill (though may require recomputation). |
719 | rewriter.setInsertionPoint(tileOp); |
720 | auto clonedOp = operandTileOp.clone(); |
721 | rewriter.modifyOpInPlace(clonedOp, |
722 | [&] { clonedOp.setTileId(tileOp.getTileId()); }); |
723 | rewriter.insert(op: clonedOp); |
724 | if (isa<CopyTileOp>(tileOp)) { |
725 | rewriter.replaceAllUsesWith(tileOp->getResult(0), |
726 | clonedOp->getResult(0)); |
727 | } else { |
728 | rewriter.modifyOpInPlace( |
729 | tileOp, [&] { tileOperand->assign(value: clonedOp->getResult(0)); }); |
730 | } |
731 | return success(); |
732 | }; |
733 | |
734 | for (Value value : liveRange->values) { |
735 | // 1. Assign the tile ID to the value. |
736 | assignTileIdToValue(rewriter, value, tileIdAttr); |
737 | |
738 | // 2. Attempt to eliminate redundant tile copies. |
739 | if (succeeded(Result: foldRedundantCopies(value))) |
740 | continue; |
741 | |
742 | // 3. Validate tile block arguments. |
743 | if (failed(Result: validateBlockArguments(value))) |
744 | return failure(); |
745 | |
746 | // 4. Attempt to resolve (trivial) tile ID conflicts. |
747 | if (failed(Result: resolveTrivialTileConflicts(value))) |
748 | return failure(); |
749 | } |
750 | } |
751 | return success(); |
752 | } |
753 | |
754 | /// Prints live ranges alongside operation names for debugging. |
755 | void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, |
756 | ArrayRef<LiveRange const *> liveRanges, |
757 | FunctionOpInterface function) { |
758 | llvm::errs() << "SME Tile Liveness: @"<< function.getName() |
759 | << "\nKey:\nS - Start\nE - End\n| - Live\n"; |
760 | for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) { |
761 | llvm::errs() << "^bb"<< blockIdx << ":\n"; |
762 | for (Operation &op : block.getOperations()) { |
763 | unsigned operationIndex = operationToIndexMap.at(&op); |
764 | for (LiveRange const *range : liveRanges) { |
765 | char liveness = ' '; |
766 | for (auto it = range->ranges->begin(); it != range->ranges->end(); |
767 | ++it) { |
768 | if (it.start() == operationIndex) |
769 | liveness = (liveness == 'E' ? '|' : 'S'); |
770 | else if (it.stop() == operationIndex) |
771 | liveness = (liveness == 'S' ? '|' : 'E'); |
772 | else if (operationIndex >= it.start() && operationIndex < it.stop()) |
773 | liveness = '|'; |
774 | } |
775 | llvm::errs() << liveness; |
776 | } |
777 | llvm::errs() << ' ' << op.getName() << '\n'; |
778 | } |
779 | } |
780 | llvm::errs() << "==========\n"; |
781 | } |
782 | |
783 | struct TestTileAllocationPass |
784 | : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> { |
785 | using TestTileAllocationBase::TestTileAllocationBase; |
786 | void runOnOperation() override { |
787 | FunctionOpInterface function = getOperation(); |
788 | if (preprocessOnly) { |
789 | IRRewriter rewriter(function); |
790 | return preprocessForTileAllocation(rewriter, function); |
791 | } |
792 | if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges))) |
793 | signalPassFailure(); |
794 | } |
795 | }; |
796 | } // namespace |
797 | |
798 | LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function, |
799 | bool dumpRanges) { |
800 | if (function.empty()) { |
801 | // TODO: Also return early if the function contains no ArmSME ops? |
802 | return success(); |
803 | } |
804 | |
805 | LiveRange::Allocator liveRangeAllocator; |
806 | IRRewriter rewriter(function.getContext()); |
807 | |
808 | // 1. Preprocess the IR for tile allocation. |
809 | preprocessForTileAllocation(rewriter, function); |
810 | |
811 | // 2. Gather live ranges for each ArmSME tile within the function. |
812 | Liveness liveness(function); |
813 | auto operationToIndexMap = generateOperationNumbering(function); |
814 | auto initialLiveRanges = gatherTileLiveRanges( |
815 | operationToIndexMap, liveRangeAllocator, liveness, function); |
816 | if (initialLiveRanges.empty()) |
817 | return success(); |
818 | |
819 | if (dumpRanges) { |
820 | // Wrangle initial live ranges into a form suitable for printing. |
821 | auto nonEmpty = llvm::make_filter_range( |
822 | llvm::make_second_range(initialLiveRanges), |
823 | [&](LiveRange const &liveRange) { return !liveRange.empty(); }); |
824 | auto initialRanges = llvm::to_vector(llvm::map_range( |
825 | nonEmpty, [](LiveRange const &liveRange) { return &liveRange; })); |
826 | llvm::sort(initialRanges, |
827 | [](LiveRange const *a, LiveRange const *b) { return *a < *b; }); |
828 | llvm::errs() << "\n========== Initial Live Ranges:\n"; |
829 | dumpLiveRanges(operationToIndexMap, initialRanges, function); |
830 | } |
831 | |
832 | // 3. Coalesce (non-overlapping) live ranges where it would be beneficial |
833 | // for tile allocation. E.g. Unify the result of an operation with its |
834 | // operands. |
835 | auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges); |
836 | |
837 | if (dumpRanges) { |
838 | llvm::errs() << "\n========== Coalesced Live Ranges:\n"; |
839 | dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function); |
840 | } |
841 | |
842 | // 4. Allocate tile IDs to live ranges. |
843 | allocateTilesToLiveRanges(coalescedLiveRanges); |
844 | |
845 | // 5. Assign the tile IDs back to the ArmSME operations. |
846 | if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function, |
847 | coalescedLiveRanges))) { |
848 | return failure(); |
849 | } |
850 | |
851 | // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no |
852 | // users). This prevents the LLVM conversion needlessly inserting spills. |
853 | eraseTriviallyDeadTileOps(rewriter, function); |
854 | return success(); |
855 | } |
856 |
Definitions
- TileMask
- getMasks
- TileAllocator
- allocateTileId
- acquireTileId
- releaseTileId
- allocateInMemoryTileId
- splitCondBranches
- insertCopiesAtBranches
- preprocessForTileAllocation
- LiveRange
- kValidLiveRange
- LiveRange
- overlaps
- overlaps
- unionWith
- insert
- empty
- start
- end
- operator<
- getTileType
- generateOperationNumbering
- gatherTileLiveRanges
- forEachPredecessorTileValue
- coalesceTileLiveRanges
- chooseSpillUsingHeuristics
- allocateTilesToLiveRanges
- assignTileIdToValue
- assignTileIdsAndResolveTrivialConflicts
- dumpLiveRanges
- TestTileAllocationPass
- runOnOperation
Improve your Profiling and Debugging skills
Find out more