//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made
// for our use case. Note that this is a greedy allocator (so it may not
// always find an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of the SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//     Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
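
// The masks make tile overlap checks a simple bitwise AND. A minimal
// illustration (hypothetical, not used by the pass):
//
//   // ZA0.S (0x8888) and ZA0.D (0x8080) share 128-bit sub-tiles, so they
//   // overlap. ZA0.S and ZA1.D (0x4040) are disjoint.
//   static_assert((0x8888u & 0x8080u) != 0u, "ZA0.S overlaps ZA0.D");
//   static_assert((0x8888u & 0x4040u) == 0u, "ZA0.S is disjoint from ZA1.D");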

/// Returns the set of masks relevant for the given type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which might allow
    // reusing an allocation, but as we _never_ want to spill an SME tile,
    // this path is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
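
// A minimal usage sketch of `TileAllocator` (illustrative only; the tile IDs
// shown follow from the mask definitions above):
//
//   TileAllocator allocator;
//   // First free 32-bit tile: ZA0.S (mask 0x8888) -> tile ID 0.
//   FailureOr<unsigned> za0s = allocator.allocateTileId(ArmSMETileType::ZAS);
//   // First 64-bit tile not overlapping ZA0.S: ZA1.D (mask 0x4040) -> ID 1.
//   FailureOr<unsigned> za1d = allocator.allocateTileId(ArmSMETileType::ZAD);
//   // ZA0.B overlaps every tile, so allocating it now fails.
//   assert(failed(allocator.allocateTileId(ArmSMETileType::ZAB)));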

/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
///   cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  SmallVector<cf::CondBranchOp> worklist;
  function.walk([&](cf::CondBranchOp condBranch) {
    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
          return isValidSMETileVectorType(value.getType());
        })) {
      worklist.push_back(condBranch);
    }
  });

  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
    rewriter.setInsertionPointToEnd(source);
    rewriter.create<cf::BranchOp>(loc, dest, args);
  };

  for (auto condBranch : worklist) {
    auto loc = condBranch.getLoc();
    Block *block = condBranch->getBlock();
    auto newTrueBranch = rewriter.splitBlock(block, block->end());
    auto newFalseBranch = rewriter.splitBlock(block, block->end());
    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
               condBranch.getTrueDestOperands());
    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
               condBranch.getFalseDestOperands());
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(newTrueBranch, 0);
      condBranch.setSuccessor(newFalseBranch, 1);
    });
  }
}

/// Inserts tile copies at `cf.br` operations.
///
/// BEFORE:
/// ```mlir
/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ```
///
/// AFTER:
/// ```mlir
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *terminator = block.getTerminator();
    if (!isa<cf::BranchOp>(terminator))
      continue;
    rewriter.setInsertionPoint(terminator);
    for (OpOperand &operand : terminator->getOpOperands()) {
      if (isValidSMETileVectorType(operand.get().getType())) {
        auto copy =
            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
      }
    }
  }
}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them from overlapping between the true and false paths
/// of the branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure that, when moving out of SSA, the semantics of the
/// program are preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
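
// Taken together, the two preprocessing steps rewrite a conditional branch
// with a tile operand roughly as follows (a sketch; see the test files
// referenced above for real examples):
//
//   BEFORE:
//     cf.cond_br %cond, ^bb1(%tile : vector<[4]x[4]xf32>), ^bb2
//
//   AFTER:
//     cf.cond_br %cond, ^bb1_copy, ^bb2_copy
//   ^bb1_copy:
//     %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
//     cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
//   ^bb2_copy:
//     cf.br ^bb2
//
// As each copy now sits in its own block, a copy on the true path can never
// appear live at the same program point as a copy on the false path.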

/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in
  /// `values` is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};
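
// A small sketch of how live ranges compose (hypothetical values and indices,
// not part of the pass):
//
//   LiveRange::Allocator allocator;
//   LiveRange rangeA(allocator), rangeB(allocator);
//   rangeA.insert(valueA, /*start=*/0, /*end=*/4); // Live over [0, 4).
//   rangeB.insert(valueB, /*start=*/6, /*end=*/9); // Live over [6, 9).
//   assert(!rangeA.overlaps(rangeB));
//   // Non-overlapping ranges can share a tile, so they may be unioned:
//   rangeA.unionWith(rangeB); // rangeA now covers [0, 4) and [6, 9).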

/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME ops have been converted to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
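
// For instance, a two-block function would be numbered as follows (a sketch
// with abbreviated ops; each block reserves one leading index for its block
// arguments):
//
//   ^bb0:                           // Index 0 (block arguments).
//     %tile = arm_sme.zero ...      // Index 1
//     cf.br ^bb1(%tile : ...)       // Index 2
//   ^bb1(%arg0 : ...):              // Index 3 (block arguments).
//     arm_sme.tile_store %arg0, ... // Index 4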

/// Gather live ranges for SME tiles from the MLIR liveness analysis.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;
  /// Defines or updates a live range for an SME tile value. Live-ins may
  /// update an existing live range (rather than define a new one). Note: If
  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
  /// the block.
  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                          LivenessBlockInfo const &livenessInfo,
                                          bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Find or create a live range for `value`.
    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
    LiveRange &valueLiveRange = it->second;
    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
    unsigned startOpIdx =
        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    valueLiveRange.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
    // Handle block arguments:
    for (Value argument : block.getArguments())
      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle live-ins:
    for (Value liveIn : livenessInfo->in())
      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
                                   /*liveAtBlockEntry=*/true);
    // Handle new definitions:
    for (Operation &op : block) {
      for (Value result : op.getResults())
        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
    }
  }

  return liveRanges;
}
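
// As an illustrative example (with hypothetical indices), given:
//
//   ^bb0:
//     %tile = arm_sme.zero ...      // Numbered 1. Defines %tile.
//     cf.br ^bb1                    // Numbered 2. %tile is live-out.
//   ^bb1:                          // Block-argument slot numbered 3.
//     arm_sme.tile_store %tile, ... // Numbered 4. Last use of %tile.
//
// the live range of %tile would gather the interval [1, 2) from its
// definition in ^bb0, and (via the live-in handling) the interval [3, 4)
// from its use in ^bb1.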

/// Iterate over all predecessor tile values to a (tile) block argument.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *block = blockArg.getOwner();
  unsigned argNumber = blockArg.getArgNumber();
  for (Block *pred : block->getPredecessors()) {
    TypeSwitch<Operation *>(pred->getTerminator())
        .Case<cf::BranchOp>([&](auto branch) {
          Value predecessorOperand = branch.getDestOperands()[argNumber];
          callback(predecessorOperand);
        })
        .Case<cf::CondBranchOp>([&](auto condBranch) {
          if (condBranch.getFalseDest() == block) {
            Value predecessorOperand =
                condBranch.getFalseDestOperands()[argNumber];
            callback(predecessorOperand);
          }
          if (condBranch.getTrueDest() == block) {
            Value predecessorOperand =
                condBranch.getTrueDestOperands()[argNumber];
            callback(predecessorOperand);
          }
        });
  }
}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  DenseMap<Value, LiveRange *> liveRanges;
  for (auto &[value, liveRange] : initialLiveRanges) {
    liveRanges.insert({value, &liveRange});
  }

  // Merge the live ranges of values `a` and `b` into one (if they do not
  // overlap). After this, the values `a` and `b` will both point to the same
  // live range (which will contain multiple values).
  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
    LiveRange *aLiveRange = liveRanges.at(a);
    LiveRange *bLiveRange = liveRanges.at(b);
    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
      aLiveRange->unionWith(*bLiveRange);
      for (Value value : bLiveRange->values)
        liveRanges[value] = aLiveRange;
    }
  };

  // Merge the live ranges of new definitions with their tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!armSMEOp)
      return;
    for (auto operand : armSMEOp->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        mergeValuesIfNonOverlapping(value, operand);
    }
  };

  // Merge the live ranges of block arguments with their predecessors.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
    });
  };

  auto applyRule = [&](auto rule) {
    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
  };

  // Unify as many live ranges as we can. This prevents unnecessary moves.
  applyRule(unifyBlockArgumentsWithPredecessors);
  applyRule(unifyDefinitionsWithOperands);

  // Remove duplicate live range entries.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto [_, liveRange] : liveRanges) {
    if (!liveRange->empty())
      uniqueLiveRanges.insert(liveRange);
  }

  // Sort the new live ranges by starting point (ready for tile allocation).
  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  llvm::sort(coalescedLiveRanges,
             [](LiveRange *a, LiveRange *b) { return *a < *b; });
  return std::move(coalescedLiveRanges);
}
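
// As an illustrative sketch of the definition/operand rule:
//
//   %0 = arm_sme.zero : vector<[4]x[4]xf32>
//   %1 = arm_sme.outerproduct %a, %b acc(%0) : vector<[4]xf32>, vector<[4]xf32>
//
// `%1` is a new definition whose tile operand is `%0`, and their live ranges
// do not overlap (the definition of `%1` ends the live range of `%0`), so the
// two ranges are coalesced and both values receive the same tile ID, avoiding
// a tile move.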

/// Choose a live range to spill (via some heuristics). This picks either a
/// live range from `overlappingRanges`, or the new live range `newRange`.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
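
// For example (a sketch): if every 32-bit tile is taken and a new 32-bit
// range needs one, an overlapping range whose single value is produced by a
// trivially cloneable op (e.g. `arm_sme.zero`) is preferred for spilling, as
// it can simply be re-materialized. Failing that, the overlapping range with
// a compatible tile type that ends last is spilled, as in classic linear scan.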

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
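  //
  // For example (with illustrative indices): a live range with intervals
  // [0, 4) and [8, 10) has a hole over [4, 8); while the current point is
  // inside that hole, the range sits in `inactiveRanges` and its tile ID can
  // be reused by another live range that fits entirely within [4, 8).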
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the
    // `currentPoint`, whereas this checks if there is an overlap at any future
    // point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}

/// Assigns a tile ID to an MLIR value.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
  for (Operation *user : value.getUsers()) {
    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
      // Ensure ArmSME ops that don't produce a value still get a tile ID.
      if (!hasTileResult(tileOp))
        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
    }
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID
/// conflicts.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates that each predecessor of a tile block argument has been
    /// assigned the same tile ID.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtual tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtual tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(clonedOp,
                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges) {
        char liveness = ' ';
        for (auto it = range->ranges->begin(); it != range->ranges->end();
             ++it) {
          if (it.start() == operationIndex)
            liveness = (liveness == 'E' ? '|' : 'S');
          else if (it.stop() == operationIndex)
            liveness = (liveness == 'S' ? '|' : 'E');
          else if (operationIndex >= it.start() && operationIndex < it.stop())
            liveness = '|';
        }
        llvm::errs() << liveness;
      }
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}
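
// The dump looks roughly like this (a hypothetical function with two
// overlapping live ranges; one column of S/E/| per range):
//
//   SME Tile Liveness: @example
//   Key:
//   S - Start
//   E - End
//   | - Live
//   ^bb0:
//   S  arm_sme.zero
//   |S arm_sme.zero
//   || cf.br
//   ...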

struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      return preprocessForTileAllocation(rewriter, function);
    }
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      signalPassFailure();
  }
};
} // namespace

LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty()) {
    // TODO: Also return early if the function contains no ArmSME ops?
    return success();
  }

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Wrangle initial live ranges into a form suitable for printing.
    auto nonEmpty = llvm::make_filter_range(
        llvm::make_second_range(initialLiveRanges),
        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
    auto initialRanges = llvm::to_vector(llvm::map_range(
        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
    llvm::sort(initialRanges,
               [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
  // for tile allocation. E.g. Unify the result of an operation with its
  // operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);

  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalescedLiveRanges))) {
    return failure();
  }

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
  // users). This prevents the LLVM conversion needlessly inserting spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}