1//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform allocates SME tiles at the 'func.func' op level for ArmSME
10// operations. It roughly implements a linear scan register allocator, similar
11// to the one outlined in [1], but with simplifications and assumptions made for
12// our use case. Note that this is a greedy allocator (so it may not always find
13// the most optimal allocation of tiles).
14//
15// The allocator operates at the CF dialect level. It is the responsibility of
16// users to ensure the IR has been lowered to CF before invoking the tile
17// allocator.
18//
19// The 128-bit tiles overlap with other element tiles as follows (see section
20// B2.3.2 of SME spec [2]):
21//
22// Tile Overlaps
23// ---------------------------------------------------------------------------
24// ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
25// ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
26// ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
27// ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
28// ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
29// ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
30// ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
31// ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
32// ZA0.D ZA0.Q, ZA8.Q
33// ZA1.D ZA1.Q, ZA9.Q
34// ZA2.D ZA2.Q, ZA10.Q
35// ZA3.D ZA3.Q, ZA11.Q
36// ZA4.D ZA4.Q, ZA12.Q
37// ZA5.D ZA5.Q, ZA13.Q
38// ZA6.D ZA6.Q, ZA14.Q
39// ZA7.D ZA7.Q, ZA15.Q
40//
41// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44// [2] https://developer.arm.com/documentation/ddi0616/aa
45//
46//===----------------------------------------------------------------------===//
47
48#include "mlir/Analysis/Liveness.h"
49#include "mlir/Analysis/TopologicalSortUtils.h"
50#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
51#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
52#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
53#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
54#include "mlir/Dialect/Func/IR/FuncOps.h"
55#include "llvm/ADT/IntervalMap.h"
56#include "llvm/ADT/TypeSwitch.h"
57
58namespace mlir::arm_sme {
59#define GEN_PASS_DEF_TESTTILEALLOCATION
60#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
61} // namespace mlir::arm_sme
62
63using namespace mlir;
64using namespace mlir::arm_sme;
65
66namespace {
67
/// Bitmask over the 16 128-bit ZA tiles (ZA0.Q–ZA15.Q). Each named constant
/// marks the set of 128-bit tiles a given virtual tile overlaps with (see the
/// overlap table at the top of this file). Bit 15 (MSB) is ZA0.Q and bit 0
/// (LSB) is ZA15.Q, so e.g. `kZA0B` (the whole ZA array) sets all 16 bits.
/// Two virtual tiles conflict iff their masks intersect.
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  // Enables the bitwise operators (|, &, ^, ~, …) for this enum; the argument
  // is the largest representable value.
  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
111
112/// Returns the set of masks relevant for the given type.
113static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
114 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117 TileMask::kZA2S, TileMask::kZA3S};
118 static constexpr std::array ZA_D_MASKS = {
119 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121 static constexpr std::array ZA_Q_MASKS = {
122 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
126 switch (type) {
127 case ArmSMETileType::ZAB:
128 return ZA_B_MASKS;
129 case ArmSMETileType::ZAH:
130 return ZA_H_MASKS;
131 case ArmSMETileType::ZAS:
132 return ZA_S_MASKS;
133 case ArmSMETileType::ZAD:
134 return ZA_D_MASKS;
135 case ArmSMETileType::ZAQ:
136 return ZA_Q_MASKS;
137 }
138 llvm_unreachable("unknown type in getMasks");
139}
140
141class TileAllocator {
142public:
143 /// Allocates and returns a tile ID. Fails if there are no tiles left.
144 FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
145 auto masks = getMasks(type: tileType);
146 for (auto [tileId, tileMask] : llvm::enumerate(First&: masks)) {
147 if ((tilesInUse & tileMask) == TileMask::kNone) {
148 tilesInUse |= tileMask;
149 return tileId;
150 }
151 }
152 return failure();
153 }
154
155 /// Acquires a specific tile ID. Asserts the tile is initially free.
156 void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
157 TileMask tileMask = getMasks(type: tileType)[tileId];
158 assert((tilesInUse & tileMask) == TileMask::kNone &&
159 "cannot acquire allocated tile!");
160 tilesInUse |= tileMask;
161 }
162
163 /// Releases a previously allocated tile ID.
164 void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
165 TileMask tileMask = getMasks(type: tileType)[tileId];
166 assert((tilesInUse & tileMask) == tileMask &&
167 "cannot release unallocated tile!");
168 tilesInUse ^= tileMask;
169 }
170
171 /// Allocates an in-memory tile ID.
172 unsigned allocateInMemoryTileId() {
173 // Note: We never release in-memory tile IDs. We could, which may allow
174 // reusing an allocation, but as we _never_ want to spill an SME tile this
175 // is not optimized.
176 return nextInMemoryTileId++;
177 }
178
179private:
180 TileMask tilesInUse = TileMask::kNone;
181 unsigned nextInMemoryTileId = kInMemoryTileIdBase;
182};
183
184/// Add new intermediate blocks for the true and false destinations of
185/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
186/// overlaps due to copies at branches.
187///
188/// BEFORE:
189/// ```mlir
190/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
191/// ```
192///
193/// AFTER:
194/// ```mlir
195/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
196/// ^bb1_copy:
197/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
198/// ^bb2_copy:
199/// cf.br ^bb2
200/// ```
201void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
202 SmallVector<cf::CondBranchOp> worklist;
203 function.walk(callback: [&](cf::CondBranchOp condBranch) {
204 if (llvm::any_of(Range: condBranch->getOperands(), P: [&](Value value) {
205 return isValidSMETileVectorType(type: value.getType());
206 })) {
207 worklist.push_back(Elt: condBranch);
208 }
209 });
210
211 auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
212 rewriter.setInsertionPointToEnd(source);
213 rewriter.create<cf::BranchOp>(loc, dest, args);
214 };
215
216 for (auto condBranch : worklist) {
217 auto loc = condBranch.getLoc();
218 Block *block = condBranch->getBlock();
219 auto newTrueBranch = rewriter.splitBlock(block, before: block->end());
220 auto newFalseBranch = rewriter.splitBlock(block, before: block->end());
221 insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
222 condBranch.getTrueDestOperands());
223 insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
224 condBranch.getFalseDestOperands());
225 rewriter.modifyOpInPlace(root: condBranch, callable: [&] {
226 condBranch.getFalseDestOperandsMutable().clear();
227 condBranch.getTrueDestOperandsMutable().clear();
228 condBranch.setSuccessor(block: newTrueBranch, i: 0);
229 condBranch.setSuccessor(block: newFalseBranch, i: 1);
230 });
231 }
232}
233
234/// Inserts tile copies at `cf.br` operations.
235///
236/// BEFORE:
237/// ```mlir
238/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
239/// ```
240///
241/// AFTER:
242/// ```mlir
243/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
244/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
245/// ```
246void insertCopiesAtBranches(IRRewriter &rewriter,
247 FunctionOpInterface function) {
248 for (Block &block : function.getBlocks()) {
249 Operation *terminator = block.getTerminator();
250 if (!isa<cf::BranchOp>(Val: terminator))
251 continue;
252 rewriter.setInsertionPoint(terminator);
253 for (OpOperand &operand : terminator->getOpOperands()) {
254 if (isValidSMETileVectorType(type: operand.get().getType())) {
255 auto copy =
256 rewriter.create<CopyTileOp>(location: terminator->getLoc(), args: operand.get());
257 rewriter.modifyOpInPlace(root: terminator, callable: [&] { operand.assign(value: copy); });
258 }
259 }
260 }
261}
262
/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // Split first: this moves a cond_br's destination operands onto new,
  // unconditional `cf.br`s, which the copy insertion below then handles.
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
276
/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  // `ranges` is heap-allocated (presumably because IntervalMap is not
  // movable); this keeps LiveRange itself movable, e.g. for use in DenseMap.
  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    // An empty interval (start == end) only records the value; it adds no
    // liveness, so such a range can still report `empty()`.
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  // Live ranges are ordered by their starting point (used to sort ranges
  // before the linear scan).
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  /// Returns the tile type shared by the values in this range.
  /// Note: requires `values` to be non-empty.
  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
340
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME ops have been converted to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      // Tile allocation requires a flat (CF-level) body: an ArmSME op found
      // nested in another op's region would not be numbered correctly.
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
365
/// Gather live ranges for SME tiles from the MLIR liveness analysis.
/// Returns a map from each SME tile value to its (per-value) live range; the
/// intervals use the indices from `operationToIndexMap`.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may update
  /// an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
  /// the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    // Values live at block entry start one index earlier — the index the
    // numbering reserves for the block's arguments (see
    // `generateOperationNumbering`).
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}
412
413/// Iterate over all predecessor tile values to a (tile) block argument.
414static void forEachPredecessorTileValue(BlockArgument blockArg,
415 function_ref<void(Value)> callback) {
416 Block *block = blockArg.getOwner();
417 unsigned argNumber = blockArg.getArgNumber();
418 for (Block *pred : block->getPredecessors()) {
419 TypeSwitch<Operation *>(pred->getTerminator())
420 .Case<cf::BranchOp>(caseFn: [&](auto branch) {
421 Value predecessorOperand = branch.getDestOperands()[argNumber];
422 callback(predecessorOperand);
423 })
424 .Case<cf::CondBranchOp>(caseFn: [&](auto condBranch) {
425 if (condBranch.getFalseDest() == block) {
426 Value predecessorOperand =
427 condBranch.getFalseDestOperands()[argNumber];
428 callback(predecessorOperand);
429 }
430 if (condBranch.getTrueDest() == block) {
431 Value predecessorOperand =
432 condBranch.getTrueDestOperands()[argNumber];
433 callback(predecessorOperand);
434 }
435 });
436 }
437}
438
439/// Coalesce live ranges where it would prevent unnecessary tile moves.
440SmallVector<LiveRange *>
441coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
442 DenseMap<Value, LiveRange *> liveRanges;
443 for (auto &[value, liveRange] : initialLiveRanges) {
444 liveRanges.insert(KV: {value, &liveRange});
445 }
446
447 // Merge the live ranges of values `a` and `b` into one (if they do not
448 // overlap). After this, the values `a` and `b` will both point to the same
449 // live range (which will contain multiple values).
450 auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
451 LiveRange *aLiveRange = liveRanges.at(Val: a);
452 LiveRange *bLiveRange = liveRanges.at(Val: b);
453 if (aLiveRange != bLiveRange && !aLiveRange->overlaps(otherRange: *bLiveRange)) {
454 aLiveRange->unionWith(otherRange: *bLiveRange);
455 for (Value value : bLiveRange->values)
456 liveRanges[value] = aLiveRange;
457 }
458 };
459
460 // Merge the live ranges of new definitions with their tile operands.
461 auto unifyDefinitionsWithOperands = [&](Value value) {
462 auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
463 if (!armSMEOp)
464 return;
465 for (auto operand : armSMEOp->getOperands()) {
466 if (isValidSMETileVectorType(type: operand.getType()))
467 mergeValuesIfNonOverlapping(value, operand);
468 }
469 };
470
471 // Merge the live ranges of block arguments with their predecessors.
472 auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
473 auto blockArg = dyn_cast<BlockArgument>(Val&: value);
474 if (!blockArg)
475 return;
476 forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) {
477 mergeValuesIfNonOverlapping(blockArg, predecessorTile);
478 });
479 };
480
481 auto applyRule = [&](auto rule) {
482 llvm::for_each(llvm::make_first_range(c&: initialLiveRanges), rule);
483 };
484
485 // Unify as many live ranges as we can. This prevents unnecessary moves.
486 applyRule(unifyBlockArgumentsWithPredecessors);
487 applyRule(unifyDefinitionsWithOperands);
488
489 // Remove duplicate live range entries.
490 SetVector<LiveRange *> uniqueLiveRanges;
491 for (auto [_, liveRange] : liveRanges) {
492 if (!liveRange->empty())
493 uniqueLiveRanges.insert(X: liveRange);
494 }
495
496 // Sort the new live ranges by starting point (ready for tile allocation).
497 auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
498 llvm::sort(C&: coalescedLiveRanges,
499 Comp: [](LiveRange *a, LiveRange *b) { return *a < *b; });
500 return std::move(coalescedLiveRanges);
501}
502
/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
/// `OverlappingRangesIterator` is any iterable range of `LiveRange&`.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  // A candidate's tile must be at least as large as `newRange`'s tile type so
  // the freed tile can actually be reused for `newRange`.
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
  // Prefer spilling the latest-ending overlapping range; fall back to
  // spilling `newRange` itself if that range is not a better candidate.
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
534
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
/// This is a linear scan over `liveRangesSortedByStartPoint` (which must be
/// sorted by starting point); on completion every range has a `tileId` set —
/// either a real ZA tile or an in-memory (spilled) tile ID.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive (i.e. entered a
      // 'hole'). Their tile is released so it can be reused in the hole.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active again; re-acquire the
      // tile ID they were previously assigned.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // No free tile: something must be spilled (to an in-memory tile).
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    // Note: `tileId` is engaged here (set in step 4); IDs >=
    // kInMemoryTileIdBase denote spilled (in-memory) tiles, which are never
    // 'active' in ZA.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}
638
639/// Assigns a tile ID to an MLIR value.
640void assignTileIdToValue(IRRewriter &rewriter, Value value,
641 IntegerAttr tileIdAttr) {
642 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
643 rewriter.modifyOpInPlace(root: tileOp, callable: [&] { tileOp.setTileId(tileIdAttr); });
644 for (Operation *user : value.getUsers()) {
645 if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(Val: user)) {
646 // Ensure ArmSME ops that don't produce a value still get a tile ID.
647 if (!hasTileResult(tileOp))
648 rewriter.modifyOpInPlace(root: tileOp, callable: [&] { tileOp.setTileId(tileIdAttr); });
649 }
650 }
651}
652
653/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
654LogicalResult assignTileIdsAndResolveTrivialConflicts(
655 IRRewriter &rewriter, FunctionOpInterface function,
656 ArrayRef<LiveRange *> allocatedLiveRanges) {
657 for (LiveRange const *liveRange : allocatedLiveRanges) {
658 auto tileIdAttr = rewriter.getI32IntegerAttr(value: *liveRange->tileId);
659 auto isAllocatedToSameTile = [&](Value value) {
660 if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
661 tileOp && tileOp.getTileId() == tileIdAttr)
662 return true;
663 return liveRange->values.contains(key: value);
664 };
665
666 /// Eliminates copies where the operand has the same tile ID.
667 auto foldRedundantCopies = [&](Value value) -> LogicalResult {
668 auto copyOp = value.getDefiningOp<CopyTileOp>();
669 if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
670 return failure();
671 rewriter.replaceAllUsesWith(from: copyOp, to: copyOp.getTile());
672 return success();
673 };
674
675 /// Validates each predecessor to a tile block argument has been assigned
676 /// the same tile ID.
677 auto validateBlockArguments = [&](Value value) {
678 auto blockArg = dyn_cast<BlockArgument>(Val&: value);
679 if (!blockArg) {
680 // Not a block argument (nothing to validate).
681 return success();
682 }
683 bool tileMismatch = false;
684 forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) {
685 if (tileMismatch)
686 return;
687 if (!isAllocatedToSameTile(predecessorTile)) {
688 blockArg.getOwner()->getParentOp()->emitOpError(
689 message: "block argument not allocated to the same SME virtial tile as "
690 "predecessors");
691 tileMismatch = true;
692 }
693 });
694 return success(/*isSuccess=*/IsSuccess: !tileMismatch);
695 };
696
697 /// Attempts to resolve (trivial) tile ID conflicts.
698 auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
699 auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
700 OpOperand *tileOperand = getTileOpOperand(tileOp);
701 if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
702 // Operand already allocated to the correct tile.
703 // No conflict to resolve.
704 return success();
705 }
706 auto operandTileOp =
707 tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
708 if (!isTriviallyCloneableTileOp(tileOp: operandTileOp)) {
709 auto error =
710 tileOp.emitOpError(message: "tile operand allocated to different SME "
711 "virtial tile (move required)");
712 error.attachNote(noteLoc: tileOperand->get().getLoc())
713 << "tile operand is: " << tileOperand->get();
714 return error;
715 }
716 // Cloning prevents a move/spill (though may require recomputation).
717 rewriter.setInsertionPoint(tileOp);
718 auto clonedOp = operandTileOp.clone();
719 rewriter.modifyOpInPlace(root: clonedOp,
720 callable: [&] { clonedOp.setTileId(tileOp.getTileId()); });
721 rewriter.insert(op: clonedOp);
722 if (isa<CopyTileOp>(Val: tileOp)) {
723 rewriter.replaceAllUsesWith(from: tileOp->getResult(idx: 0),
724 to: clonedOp->getResult(idx: 0));
725 } else {
726 rewriter.modifyOpInPlace(
727 root: tileOp, callable: [&] { tileOperand->assign(value: clonedOp->getResult(idx: 0)); });
728 }
729 return success();
730 };
731
732 for (Value value : liveRange->values) {
733 // 1. Assign the tile ID to the value.
734 assignTileIdToValue(rewriter, value, tileIdAttr);
735
736 // 2. Attempt to eliminate redundant tile copies.
737 if (succeeded(Result: foldRedundantCopies(value)))
738 continue;
739
740 // 3. Validate tile block arguments.
741 if (failed(Result: validateBlockArguments(value)))
742 return failure();
743
744 // 4. Attempt to resolve (trivial) tile ID conflicts.
745 if (failed(Result: resolveTrivialTileConflicts(value)))
746 return failure();
747 }
748 }
749 return success();
750}
751
752/// Prints live ranges alongside operation names for debugging.
753void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
754 ArrayRef<LiveRange const *> liveRanges,
755 FunctionOpInterface function) {
756 llvm::errs() << "SME Tile Liveness: @" << function.getName()
757 << "\nKey:\nS - Start\nE - End\n| - Live\n";
758 for (auto [blockIdx, block] : llvm::enumerate(First&: function.getBlocks())) {
759 llvm::errs() << "^bb" << blockIdx << ":\n";
760 for (Operation &op : block.getOperations()) {
761 unsigned operationIndex = operationToIndexMap.at(Val: &op);
762 for (LiveRange const *range : liveRanges) {
763 char liveness = ' ';
764 for (auto it = range->ranges->begin(); it != range->ranges->end();
765 ++it) {
766 if (it.start() == operationIndex)
767 liveness = (liveness == 'E' ? '|' : 'S');
768 else if (it.stop() == operationIndex)
769 liveness = (liveness == 'S' ? '|' : 'E');
770 else if (operationIndex >= it.start() && operationIndex < it.stop())
771 liveness = '|';
772 }
773 llvm::errs() << liveness;
774 }
775 llvm::errs() << ' ' << op.getName() << '\n';
776 }
777 }
778 llvm::errs() << "==========\n";
779}
780
/// Test pass wrapping the tile allocator (pass options such as
/// `preprocessOnly` and `dumpTileLiveRanges` come from the generated base
/// class — see Passes.h.inc included above).
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    // With `preprocessOnly` set, only run the IR preprocessing (branch
    // splitting + tile copy insertion), so it can be tested in isolation.
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
794} // namespace
795
/// Allocates SME tiles for all ArmSME ops in `function` (which must already be
/// lowered to CF). If `dumpRanges` is set, prints the initial and coalesced
/// live ranges to stderr. Fails if a tile ID conflict requiring a move is
/// found when assigning IDs back to the IR.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    llvm::sort(initialRanges,
               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}
854

source code of mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp