| 1 | //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This transform allocates SME tiles at the 'func.func' op level for ArmSME |
| 10 | // operations. It roughly implements a linear scan register allocator, similar |
| 11 | // to the one outlined in [1], but with simplifications and assumptions made for |
| 12 | // our use case. Note that this is a greedy allocator (so it may not always find |
| 13 | // the most optimal allocation of tiles). |
| 14 | // |
| 15 | // The allocator operates at the CF dialect level. It is the responsibility of |
| 16 | // users to ensure the IR has been lowered to CF before invoking the tile |
| 17 | // allocator. |
| 18 | // |
| 19 | // The 128-bit tiles overlap with other element tiles as follows (see section |
| 20 | // B2.3.2 of SME spec [2]): |
| 21 | // |
| 22 | // Tile Overlaps |
| 23 | // --------------------------------------------------------------------------- |
| 24 | // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q, |
| 25 | // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q |
| 26 | // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q |
| 27 | // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q |
| 28 | // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q |
| 29 | // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q |
| 30 | // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q |
| 31 | // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q |
| 32 | // ZA0.D ZA0.Q, ZA8.Q |
| 33 | // ZA1.D ZA1.Q, ZA9.Q |
| 34 | // ZA2.D ZA2.Q, ZA10.Q |
| 35 | // ZA3.D ZA3.Q, ZA11.Q |
| 36 | // ZA4.D ZA4.Q, ZA12.Q |
| 37 | // ZA5.D ZA5.Q, ZA13.Q |
| 38 | // ZA6.D ZA6.Q, ZA14.Q |
| 39 | // ZA7.D ZA7.Q, ZA15.Q |
| 40 | // |
| 41 | // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register |
| 42 | // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer) |
| 43 | // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf |
| 44 | // [2] https://developer.arm.com/documentation/ddi0616/aa |
| 45 | // |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | |
| 48 | #include "mlir/Analysis/Liveness.h" |
| 49 | #include "mlir/Analysis/TopologicalSortUtils.h" |
| 50 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 51 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
| 52 | #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" |
| 53 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 54 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 55 | #include "mlir/Transforms/RegionUtils.h" |
| 56 | #include "llvm/ADT/IntervalMap.h" |
| 57 | #include "llvm/ADT/TypeSwitch.h" |
| 58 | #include <algorithm> |
| 59 | |
| 60 | namespace mlir::arm_sme { |
| 61 | #define GEN_PASS_DEF_TESTTILEALLOCATION |
| 62 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| 63 | } // namespace mlir::arm_sme |
| 64 | |
| 65 | using namespace mlir; |
| 66 | using namespace mlir::arm_sme; |
| 67 | |
| 68 | namespace { |
| 69 | |
/// Bitmask over the sixteen 128-bit tiles (ZA0.Q ... ZA15.Q). Each named mask
/// marks the set of 128-bit tiles a given SME virtual tile occupies, so two
/// virtual tiles conflict iff their masks intersect (see the overlap table in
/// the file header comment).
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x808,  // 0000 1000 0000 1000
  kZA5D  = 0x404,  // 0000 0100 0000 0100
  kZA6D  = 0x202,  // 0000 0010 0000 0010
  kZA7D  = 0x101,  // 0000 0001 0000 0001

  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x800,  // 0000 1000 0000 0000
  kZA5Q  = 0x400,  // 0000 0100 0000 0000
  kZA6Q  = 0x200,  // 0000 0010 0000 0000
  kZA7Q  = 0x100,  // 0000 0001 0000 0000
  kZA8Q  = 0x80,   // 0000 0000 1000 0000
  kZA9Q  = 0x40,   // 0000 0000 0100 0000
  kZA10Q = 0x20,   // 0000 0000 0010 0000
  kZA11Q = 0x10,   // 0000 0000 0001 0000
  kZA12Q = 0x8,    // 0000 0000 0000 1000
  kZA13Q = 0x4,    // 0000 0000 0000 0100
  kZA14Q = 0x2,    // 0000 0000 0000 0010
  kZA15Q = 0x1,    // 0000 0000 0000 0001

  kNone = 0x0,     // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
| 113 | |
| 114 | /// Returns the set of masks relevant for the given type. |
| 115 | static ArrayRef<TileMask> getMasks(ArmSMETileType type) { |
| 116 | static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B}; |
| 117 | static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H}; |
| 118 | static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S, |
| 119 | TileMask::kZA2S, TileMask::kZA3S}; |
| 120 | static constexpr std::array ZA_D_MASKS = { |
| 121 | TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D, |
| 122 | TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D}; |
| 123 | static constexpr std::array ZA_Q_MASKS = { |
| 124 | TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q, |
| 125 | TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q, |
| 126 | TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q, |
| 127 | TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q}; |
| 128 | switch (type) { |
| 129 | case ArmSMETileType::ZAB: |
| 130 | return ZA_B_MASKS; |
| 131 | case ArmSMETileType::ZAH: |
| 132 | return ZA_H_MASKS; |
| 133 | case ArmSMETileType::ZAS: |
| 134 | return ZA_S_MASKS; |
| 135 | case ArmSMETileType::ZAD: |
| 136 | return ZA_D_MASKS; |
| 137 | case ArmSMETileType::ZAQ: |
| 138 | return ZA_Q_MASKS; |
| 139 | } |
| 140 | llvm_unreachable("unknown type in getMasks" ); |
| 141 | } |
| 142 | |
| 143 | class TileAllocator { |
| 144 | public: |
| 145 | /// Allocates and returns a tile ID. Fails if there are no tiles left. |
| 146 | FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) { |
| 147 | auto masks = getMasks(tileType); |
| 148 | for (auto [tileId, tileMask] : llvm::enumerate(masks)) { |
| 149 | if ((tilesInUse & tileMask) == TileMask::kNone) { |
| 150 | tilesInUse |= tileMask; |
| 151 | return tileId; |
| 152 | } |
| 153 | } |
| 154 | return failure(); |
| 155 | } |
| 156 | |
| 157 | /// Acquires a specific tile ID. Asserts the tile is initially free. |
| 158 | void acquireTileId(ArmSMETileType tileType, unsigned tileId) { |
| 159 | TileMask tileMask = getMasks(tileType)[tileId]; |
| 160 | assert((tilesInUse & tileMask) == TileMask::kNone && |
| 161 | "cannot acquire allocated tile!" ); |
| 162 | tilesInUse |= tileMask; |
| 163 | } |
| 164 | |
| 165 | /// Releases a previously allocated tile ID. |
| 166 | void releaseTileId(ArmSMETileType tileType, unsigned tileId) { |
| 167 | TileMask tileMask = getMasks(tileType)[tileId]; |
| 168 | assert((tilesInUse & tileMask) == tileMask && |
| 169 | "cannot release unallocated tile!" ); |
| 170 | tilesInUse ^= tileMask; |
| 171 | } |
| 172 | |
| 173 | /// Allocates an in-memory tile ID. |
| 174 | unsigned allocateInMemoryTileId() { |
| 175 | // Note: We never release in-memory tile IDs. We could, which may allow |
| 176 | // reusing an allocation, but as we _never_ want to spill an SME tile this |
| 177 | // is not optimized. |
| 178 | return nextInMemoryTileId++; |
| 179 | } |
| 180 | |
| 181 | private: |
| 182 | TileMask tilesInUse = TileMask::kNone; |
| 183 | unsigned nextInMemoryTileId = kInMemoryTileIdBase; |
| 184 | }; |
| 185 | |
| 186 | /// Add new intermediate blocks for the true and false destinations of |
| 187 | /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness |
| 188 | /// overlaps due to copies at branches. |
| 189 | /// |
| 190 | /// BEFORE: |
| 191 | /// ```mlir |
| 192 | /// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2 |
| 193 | /// ``` |
| 194 | /// |
| 195 | /// AFTER: |
| 196 | /// ```mlir |
| 197 | /// cf.cond_br %cond, ^bb1_copy, ^bb2_copy |
| 198 | /// ^bb1_copy: |
| 199 | /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) |
| 200 | /// ^bb2_copy: |
| 201 | /// cf.br ^bb2 |
| 202 | /// ``` |
| 203 | void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { |
| 204 | SmallVector<cf::CondBranchOp> worklist; |
| 205 | function.walk([&](cf::CondBranchOp condBranch) { |
| 206 | if (llvm::any_of(condBranch->getOperands(), [&](Value value) { |
| 207 | return isValidSMETileVectorType(type: value.getType()); |
| 208 | })) { |
| 209 | worklist.push_back(condBranch); |
| 210 | } |
| 211 | }); |
| 212 | |
| 213 | auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) { |
| 214 | rewriter.setInsertionPointToEnd(source); |
| 215 | rewriter.create<cf::BranchOp>(loc, dest, args); |
| 216 | }; |
| 217 | |
| 218 | for (auto condBranch : worklist) { |
| 219 | auto loc = condBranch.getLoc(); |
| 220 | Block *block = condBranch->getBlock(); |
| 221 | auto newTrueBranch = rewriter.splitBlock(block, block->end()); |
| 222 | auto newFalseBranch = rewriter.splitBlock(block, block->end()); |
| 223 | insertJump(loc, newTrueBranch, condBranch.getTrueDest(), |
| 224 | condBranch.getTrueDestOperands()); |
| 225 | insertJump(loc, newFalseBranch, condBranch.getFalseDest(), |
| 226 | condBranch.getFalseDestOperands()); |
| 227 | rewriter.modifyOpInPlace(condBranch, [&] { |
| 228 | condBranch.getFalseDestOperandsMutable().clear(); |
| 229 | condBranch.getTrueDestOperandsMutable().clear(); |
| 230 | condBranch.setSuccessor(newTrueBranch, 0); |
| 231 | condBranch.setSuccessor(newFalseBranch, 1); |
| 232 | }); |
| 233 | } |
| 234 | } |
| 235 | |
| 236 | /// Inserts tile copies at `cf.br` operations. |
| 237 | /// |
| 238 | /// BEFORE: |
| 239 | /// ```mlir |
| 240 | /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) |
| 241 | /// ``` |
| 242 | /// |
| 243 | /// AFTER: |
| 244 | /// ```mlir |
| 245 | /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32> |
| 246 | /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>) |
| 247 | /// ``` |
| 248 | void insertCopiesAtBranches(IRRewriter &rewriter, |
| 249 | FunctionOpInterface function) { |
| 250 | for (Block &block : function.getBlocks()) { |
| 251 | Operation *terminator = block.getTerminator(); |
| 252 | if (!isa<cf::BranchOp>(terminator)) |
| 253 | continue; |
| 254 | rewriter.setInsertionPoint(terminator); |
| 255 | for (OpOperand &operand : terminator->getOpOperands()) { |
| 256 | if (isValidSMETileVectorType(operand.get().getType())) { |
| 257 | auto copy = |
| 258 | rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get()); |
| 259 | rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); }); |
| 260 | } |
| 261 | } |
| 262 | } |
| 263 | } |
| 264 | |
/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // Order matters: splitting first turns `cf.cond_br` destination operands
  // into plain `cf.br` operands, which then receive copies below.
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
| 278 | |
| 279 | /// A live range for a (collection of) tile values. A live range is built up of |
| 280 | /// non-overlapping intervals [start, end) which represent parts of the program |
| 281 | /// where a value in the range needs to be live (i.e. in an SME virtual tile). |
| 282 | /// Note that as the intervals are non-overlapping all values within a live |
| 283 | /// range can be allocated to the same SME virtual tile. |
| 284 | struct LiveRange { |
| 285 | using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16, |
| 286 | llvm::IntervalMapHalfOpenInfo<unsigned>>; |
| 287 | using Allocator = RangeSet::Allocator; |
| 288 | // Dummy value for the IntervalMap. Only the keys matter (the intervals). |
| 289 | static constexpr uint8_t kValidLiveRange = 0xff; |
| 290 | |
| 291 | LiveRange(Allocator &allocator) |
| 292 | : ranges(std::make_unique<RangeSet>(args&: allocator)) {} |
| 293 | |
| 294 | /// Returns true if this range overlaps with `otherRange`. |
| 295 | bool overlaps(LiveRange const &otherRange) const { |
| 296 | return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges, |
| 297 | *otherRange.ranges) |
| 298 | .valid(); |
| 299 | } |
| 300 | |
| 301 | /// Returns true if this range is active at `point` in the program. |
| 302 | bool overlaps(uint64_t point) const { |
| 303 | return ranges->lookup(x: point) == kValidLiveRange; |
| 304 | } |
| 305 | |
| 306 | /// Unions this live range with `otherRange`, aborts if the ranges overlap. |
| 307 | void unionWith(LiveRange const &otherRange) { |
| 308 | for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end(); |
| 309 | ++it) |
| 310 | ranges->insert(a: it.start(), b: it.stop(), y: kValidLiveRange); |
| 311 | values.set_union(otherRange.values); |
| 312 | } |
| 313 | |
| 314 | /// Inserts an interval [start, end) for `value` into this range. |
| 315 | void insert(Value value, unsigned start, unsigned end) { |
| 316 | values.insert(X: value); |
| 317 | if (start != end) |
| 318 | ranges->insert(a: start, b: end, y: kValidLiveRange); |
| 319 | } |
| 320 | |
| 321 | bool empty() const { return ranges->empty(); } |
| 322 | unsigned start() const { return ranges->start(); } |
| 323 | unsigned end() const { return ranges->stop(); } |
| 324 | bool operator<(LiveRange const &other) const { |
| 325 | return start() < other.start(); |
| 326 | } |
| 327 | |
| 328 | ArmSMETileType getTileType() const { |
| 329 | return *getSMETileType(cast<VectorType>(values[0].getType())); |
| 330 | } |
| 331 | |
| 332 | /// The values contained in this live range. |
| 333 | SetVector<Value> values; |
| 334 | |
| 335 | /// A set of (non-overlapping) intervals that mark where any value in `values` |
| 336 | /// is live. |
| 337 | std::unique_ptr<RangeSet> ranges; |
| 338 | |
| 339 | /// The tile ID (or none) assigned to this live range. |
| 340 | std::optional<unsigned> tileId; |
| 341 | }; |
| 342 | |
/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct
/// if all ArmSME ops have been converted to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  unsigned index = 0;
  SetVector<Block *> blocks =
      getBlocksSortedByDominance(function.getFunctionBody());
  DenseMap<Operation *, unsigned> operationToIndexMap;
  for (Block *block : blocks) {
    index++; // We want block args to have their own number.
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      // Debug-only check: any ArmSME op found by the walk must be `op`
      // itself, i.e. no ArmSME ops nested within regions of `op`.
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, index++);
    }
  }
  return operationToIndexMap;
}
| 367 | |
| 368 | /// Gather live ranges for SME tiles from the MLIR liveness analysis. |
| 369 | DenseMap<Value, LiveRange> |
| 370 | gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, |
| 371 | LiveRange::Allocator &liveRangeAllocator, |
| 372 | Liveness &liveness, FunctionOpInterface function) { |
| 373 | assert(!operationToIndexMap.empty() && "expected operation numbering" ); |
| 374 | DenseMap<Value, LiveRange> liveRanges; |
| 375 | /// Defines or updates a live range for an SME tile value. Live-ins may update |
| 376 | /// an existing live range (rather than define a new one). Note: If |
| 377 | /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in |
| 378 | /// the block. |
| 379 | auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef, |
| 380 | LivenessBlockInfo const &livenessInfo, |
| 381 | bool liveAtBlockEntry = false) { |
| 382 | if (!isValidSMETileVectorType(type: value.getType())) |
| 383 | return; |
| 384 | // Find or create a live range for `value`. |
| 385 | auto [it, _] = liveRanges.try_emplace(Key: value, Args&: liveRangeAllocator); |
| 386 | LiveRange &valueLiveRange = it->second; |
| 387 | auto lastUseInBlock = livenessInfo.getEndOperation(value, startOperation: firstUseOrDef); |
| 388 | // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. |
| 389 | unsigned startOpIdx = |
| 390 | operationToIndexMap.at(Val: firstUseOrDef) + (liveAtBlockEntry ? -1 : 0); |
| 391 | unsigned endOpIdx = operationToIndexMap.at(Val: lastUseInBlock); |
| 392 | valueLiveRange.insert(value, start: startOpIdx, end: endOpIdx); |
| 393 | }; |
| 394 | |
| 395 | for (Block &block : function.getBlocks()) { |
| 396 | LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block); |
| 397 | // Handle block arguments: |
| 398 | for (Value argument : block.getArguments()) |
| 399 | defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo, |
| 400 | /*liveAtBlockEntry=*/true); |
| 401 | // Handle live-ins: |
| 402 | for (Value liveIn : livenessInfo->in()) |
| 403 | defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo, |
| 404 | /*liveAtBlockEntry=*/true); |
| 405 | // Handle new definitions: |
| 406 | for (Operation &op : block) { |
| 407 | for (Value result : op.getResults()) |
| 408 | defineOrUpdateValueLiveRange(result, &op, *livenessInfo); |
| 409 | } |
| 410 | } |
| 411 | |
| 412 | return liveRanges; |
| 413 | } |
| 414 | |
| 415 | /// Iterate over all predecessor tile values to a (tile) block argument. |
| 416 | static void forEachPredecessorTileValue(BlockArgument blockArg, |
| 417 | function_ref<void(Value)> callback) { |
| 418 | Block *block = blockArg.getOwner(); |
| 419 | unsigned argNumber = blockArg.getArgNumber(); |
| 420 | for (Block *pred : block->getPredecessors()) { |
| 421 | TypeSwitch<Operation *>(pred->getTerminator()) |
| 422 | .Case<cf::BranchOp>([&](auto branch) { |
| 423 | Value predecessorOperand = branch.getDestOperands()[argNumber]; |
| 424 | callback(predecessorOperand); |
| 425 | }) |
| 426 | .Case<cf::CondBranchOp>([&](auto condBranch) { |
| 427 | if (condBranch.getFalseDest() == block) { |
| 428 | Value predecessorOperand = |
| 429 | condBranch.getFalseDestOperands()[argNumber]; |
| 430 | callback(predecessorOperand); |
| 431 | } |
| 432 | if (condBranch.getTrueDest() == block) { |
| 433 | Value predecessorOperand = |
| 434 | condBranch.getTrueDestOperands()[argNumber]; |
| 435 | callback(predecessorOperand); |
| 436 | } |
| 437 | }); |
| 438 | } |
| 439 | } |
| 440 | |
| 441 | /// Coalesce live ranges where it would prevent unnecessary tile moves. |
| 442 | SmallVector<LiveRange *> |
| 443 | coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) { |
| 444 | DenseMap<Value, LiveRange *> liveRanges; |
| 445 | for (auto &[value, liveRange] : initialLiveRanges) { |
| 446 | liveRanges.insert(KV: {value, &liveRange}); |
| 447 | } |
| 448 | |
| 449 | // Merge the live ranges of values `a` and `b` into one (if they do not |
| 450 | // overlap). After this, the values `a` and `b` will both point to the same |
| 451 | // live range (which will contain multiple values). |
| 452 | auto mergeValuesIfNonOverlapping = [&](Value a, Value b) { |
| 453 | LiveRange *aLiveRange = liveRanges.at(Val: a); |
| 454 | LiveRange *bLiveRange = liveRanges.at(Val: b); |
| 455 | if (aLiveRange != bLiveRange && !aLiveRange->overlaps(otherRange: *bLiveRange)) { |
| 456 | aLiveRange->unionWith(otherRange: *bLiveRange); |
| 457 | for (Value value : bLiveRange->values) |
| 458 | liveRanges[value] = aLiveRange; |
| 459 | } |
| 460 | }; |
| 461 | |
| 462 | // Merge the live ranges of new definitions with their tile operands. |
| 463 | auto unifyDefinitionsWithOperands = [&](Value value) { |
| 464 | auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
| 465 | if (!armSMEOp) |
| 466 | return; |
| 467 | for (auto operand : armSMEOp->getOperands()) { |
| 468 | if (isValidSMETileVectorType(operand.getType())) |
| 469 | mergeValuesIfNonOverlapping(value, operand); |
| 470 | } |
| 471 | }; |
| 472 | |
| 473 | // Merge the live ranges of block arguments with their predecessors. |
| 474 | auto unifyBlockArgumentsWithPredecessors = [&](Value value) { |
| 475 | auto blockArg = dyn_cast<BlockArgument>(Val&: value); |
| 476 | if (!blockArg) |
| 477 | return; |
| 478 | forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) { |
| 479 | mergeValuesIfNonOverlapping(blockArg, predecessorTile); |
| 480 | }); |
| 481 | }; |
| 482 | |
| 483 | auto applyRule = [&](auto rule) { |
| 484 | llvm::for_each(llvm::make_first_range(c&: initialLiveRanges), rule); |
| 485 | }; |
| 486 | |
| 487 | // Unify as many live ranges as we can. This prevents unnecessary moves. |
| 488 | applyRule(unifyBlockArgumentsWithPredecessors); |
| 489 | applyRule(unifyDefinitionsWithOperands); |
| 490 | |
| 491 | // Remove duplicate live range entries. |
| 492 | SetVector<LiveRange *> uniqueLiveRanges; |
| 493 | for (auto [_, liveRange] : liveRanges) { |
| 494 | if (!liveRange->empty()) |
| 495 | uniqueLiveRanges.insert(X: liveRange); |
| 496 | } |
| 497 | |
| 498 | // Sort the new live ranges by starting point (ready for tile allocation). |
| 499 | auto coalescedLiveRanges = uniqueLiveRanges.takeVector(); |
| 500 | llvm::sort(C&: coalescedLiveRanges, |
| 501 | Comp: [](LiveRange *a, LiveRange *b) { return *a < *b; }); |
| 502 | return std::move(coalescedLiveRanges); |
| 503 | } |
| 504 | |
/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
/// Assumes `overlappingRanges` is non-empty (there must be at least one
/// overlapping range, otherwise allocation would not have failed).
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  // A range is a "trivial spill" if it contains a single value whose defining
  // op can simply be cloned (re-materialized), and its tile type is at least
  // as large as `newRange`'s (so releasing it actually frees enough space).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  // The comparator orders `a` before `b` if `a`'s tile type is too small to
  // hold `b`'s, or if `a` ends earlier — so max_element yields the
  // latest-ending range with a large-enough tile type.
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *llvm::max_element(overlappingRanges, isSmallerTileTypeOrEndsEarlier);
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}
| 536 | |
| 537 | /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics. |
| 538 | void allocateTilesToLiveRanges( |
| 539 | ArrayRef<LiveRange *> liveRangesSortedByStartPoint) { |
| 540 | TileAllocator tileAllocator; |
| 541 | // `activeRanges` = Live ranges that need to be in a tile at the |
| 542 | // `currentPoint` in the program. |
| 543 | SetVector<LiveRange *> activeRanges; |
| 544 | // `inactiveRanges` = Live ranges that _do not_ need to be in a tile |
| 545 | // at the `currentPoint` in the program but could become active again later. |
| 546 | // An inactive section of a live range can be seen as a 'hole' in the live |
| 547 | // range, where it is possible to reuse the live range's tile ID _before_ it |
| 548 | // has ended. By identifying 'holes', the allocator can reuse tiles more |
| 549 | // often, which helps avoid costly tile spills. |
| 550 | SetVector<LiveRange *> inactiveRanges; |
| 551 | for (LiveRange * : liveRangesSortedByStartPoint) { |
| 552 | auto currentPoint = nextRange->start(); |
| 553 | // 1. Update the `activeRanges` at `currentPoint`. |
| 554 | activeRanges.remove_if(P: [&](LiveRange *activeRange) { |
| 555 | // Check for live ranges that have expired. |
| 556 | if (activeRange->end() <= currentPoint) { |
| 557 | tileAllocator.releaseTileId(activeRange->getTileType(), |
| 558 | *activeRange->tileId); |
| 559 | return true; |
| 560 | } |
| 561 | // Check for live ranges that have become inactive. |
| 562 | if (!activeRange->overlaps(point: currentPoint)) { |
| 563 | tileAllocator.releaseTileId(activeRange->getTileType(), |
| 564 | *activeRange->tileId); |
| 565 | inactiveRanges.insert(X: activeRange); |
| 566 | return true; |
| 567 | } |
| 568 | return false; |
| 569 | }); |
| 570 | // 2. Update the `inactiveRanges` at `currentPoint`. |
| 571 | inactiveRanges.remove_if(P: [&](LiveRange *inactiveRange) { |
| 572 | // Check for live ranges that have expired. |
| 573 | if (inactiveRange->end() <= currentPoint) { |
| 574 | return true; |
| 575 | } |
| 576 | // Check for live ranges that have become active. |
| 577 | if (inactiveRange->overlaps(point: currentPoint)) { |
| 578 | tileAllocator.acquireTileId(inactiveRange->getTileType(), |
| 579 | *inactiveRange->tileId); |
| 580 | activeRanges.insert(X: inactiveRange); |
| 581 | return true; |
| 582 | } |
| 583 | return false; |
| 584 | }); |
| 585 | |
| 586 | // 3. Collect inactive live ranges that overlap with the new live range. |
| 587 | // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint` |
| 588 | // whereas this checks if there is an overlap at any future point too. |
| 589 | SmallVector<LiveRange *> overlappingInactiveRanges; |
| 590 | for (LiveRange *inactiveRange : inactiveRanges) { |
| 591 | if (inactiveRange->overlaps(otherRange: *nextRange)) { |
| 592 | // We need to reserve the tile IDs of overlapping inactive ranges to |
| 593 | // prevent two (overlapping) live ranges from getting the same tile ID. |
| 594 | tileAllocator.acquireTileId(inactiveRange->getTileType(), |
| 595 | *inactiveRange->tileId); |
| 596 | overlappingInactiveRanges.push_back(Elt: inactiveRange); |
| 597 | } |
| 598 | } |
| 599 | |
| 600 | // 4. Allocate a tile ID to `nextRange`. |
| 601 | auto rangeTileType = nextRange->getTileType(); |
| 602 | auto tileId = tileAllocator.allocateTileId(rangeTileType); |
| 603 | if (succeeded(tileId)) { |
| 604 | nextRange->tileId = *tileId; |
| 605 | } else { |
| 606 | // Create an iterator over all overlapping live ranges. |
| 607 | auto allOverlappingRanges = llvm::concat<LiveRange>( |
| 608 | Ranges: llvm::make_pointee_range(Range: activeRanges.getArrayRef()), |
| 609 | Ranges: llvm::make_pointee_range(Range&: overlappingInactiveRanges)); |
| 610 | // Choose an overlapping live range to spill. |
| 611 | LiveRange *rangeToSpill = |
| 612 | chooseSpillUsingHeuristics(overlappingRanges: allOverlappingRanges, newRange: nextRange); |
| 613 | if (rangeToSpill != nextRange) { |
| 614 | // Spill an (in)active live range (so release its tile ID first). |
| 615 | tileAllocator.releaseTileId(rangeToSpill->getTileType(), |
| 616 | *rangeToSpill->tileId); |
| 617 | // This will always succeed after a spill (of an active live range). |
| 618 | nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType); |
| 619 | // Remove the live range from the active/inactive sets. |
| 620 | if (!activeRanges.remove(X: rangeToSpill)) { |
| 621 | bool removed = inactiveRanges.remove(X: rangeToSpill); |
| 622 | assert(removed && "expected a range to be removed!" ); |
| 623 | (void)removed; |
| 624 | } |
| 625 | } |
| 626 | rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId(); |
| 627 | } |
| 628 | |
| 629 | // 5. Insert the live range into the active ranges. |
| 630 | if (nextRange->tileId < kInMemoryTileIdBase) |
| 631 | activeRanges.insert(X: nextRange); |
| 632 | |
| 633 | // 6. Release tiles reserved for inactive live ranges (in step 3). |
| 634 | for (LiveRange *range : overlappingInactiveRanges) { |
| 635 | if (*range->tileId < kInMemoryTileIdBase) |
| 636 | tileAllocator.releaseTileId(range->getTileType(), *range->tileId); |
| 637 | } |
| 638 | } |
| 639 | } |
| 640 | |
| 641 | /// Assigns a tile ID to an MLIR value. |
| 642 | void assignTileIdToValue(IRRewriter &rewriter, Value value, |
| 643 | IntegerAttr tileIdAttr) { |
| 644 | if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) |
| 645 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); |
| 646 | for (Operation *user : value.getUsers()) { |
| 647 | if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) { |
| 648 | // Ensure ArmSME ops that don't produce a value still get a tile ID. |
| 649 | if (!hasTileResult(tileOp)) |
| 650 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); |
| 651 | } |
| 652 | } |
| 653 | } |
| 654 | |
| 655 | /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts. |
| 656 | LogicalResult assignTileIdsAndResolveTrivialConflicts( |
| 657 | IRRewriter &rewriter, FunctionOpInterface function, |
| 658 | ArrayRef<LiveRange *> allocatedLiveRanges) { |
| 659 | for (LiveRange const *liveRange : allocatedLiveRanges) { |
| 660 | auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId); |
| 661 | auto isAllocatedToSameTile = [&](Value value) { |
| 662 | if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
| 663 | tileOp && tileOp.getTileId() == tileIdAttr) |
| 664 | return true; |
| 665 | return liveRange->values.contains(key: value); |
| 666 | }; |
| 667 | |
| 668 | /// Eliminates copies where the operand has the same tile ID. |
| 669 | auto foldRedundantCopies = [&](Value value) -> LogicalResult { |
| 670 | auto copyOp = value.getDefiningOp<CopyTileOp>(); |
| 671 | if (!copyOp || !isAllocatedToSameTile(copyOp.getTile())) |
| 672 | return failure(); |
| 673 | rewriter.replaceAllUsesWith(copyOp, copyOp.getTile()); |
| 674 | return success(); |
| 675 | }; |
| 676 | |
| 677 | /// Validates each predecessor to a tile block argument has been assigned |
| 678 | /// the same tile ID. |
| 679 | auto validateBlockArguments = [&](Value value) { |
| 680 | auto blockArg = dyn_cast<BlockArgument>(Val&: value); |
| 681 | if (!blockArg) { |
| 682 | // Not a block argument (nothing to validate). |
| 683 | return success(); |
| 684 | } |
| 685 | bool tileMismatch = false; |
| 686 | forEachPredecessorTileValue(blockArg, callback: [&](Value predecessorTile) { |
| 687 | if (tileMismatch) |
| 688 | return; |
| 689 | if (!isAllocatedToSameTile(predecessorTile)) { |
| 690 | blockArg.getOwner()->getParentOp()->emitOpError( |
| 691 | message: "block argument not allocated to the same SME virtial tile as " |
| 692 | "predecessors" ); |
| 693 | tileMismatch = true; |
| 694 | } |
| 695 | }); |
| 696 | return success(/*isSuccess=*/IsSuccess: !tileMismatch); |
| 697 | }; |
| 698 | |
| 699 | /// Attempts to resolve (trivial) tile ID conflicts. |
| 700 | auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult { |
| 701 | auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>(); |
| 702 | OpOperand *tileOperand = getTileOpOperand(tileOp); |
| 703 | if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) { |
| 704 | // Operand already allocated to the correct tile. |
| 705 | // No conflict to resolve. |
| 706 | return success(); |
| 707 | } |
| 708 | auto operandTileOp = |
| 709 | tileOperand->get().getDefiningOp<ArmSMETileOpInterface>(); |
| 710 | if (!isTriviallyCloneableTileOp(operandTileOp)) { |
| 711 | auto error = |
| 712 | tileOp.emitOpError("tile operand allocated to different SME " |
| 713 | "virtial tile (move required)" ); |
| 714 | error.attachNote(tileOperand->get().getLoc()) |
| 715 | << "tile operand is: " << tileOperand->get(); |
| 716 | return error; |
| 717 | } |
| 718 | // Cloning prevents a move/spill (though may require recomputation). |
| 719 | rewriter.setInsertionPoint(tileOp); |
| 720 | auto clonedOp = operandTileOp.clone(); |
| 721 | rewriter.modifyOpInPlace(clonedOp, |
| 722 | [&] { clonedOp.setTileId(tileOp.getTileId()); }); |
| 723 | rewriter.insert(op: clonedOp); |
| 724 | if (isa<CopyTileOp>(tileOp)) { |
| 725 | rewriter.replaceAllUsesWith(tileOp->getResult(0), |
| 726 | clonedOp->getResult(0)); |
| 727 | } else { |
| 728 | rewriter.modifyOpInPlace( |
| 729 | tileOp, [&] { tileOperand->assign(value: clonedOp->getResult(0)); }); |
| 730 | } |
| 731 | return success(); |
| 732 | }; |
| 733 | |
| 734 | for (Value value : liveRange->values) { |
| 735 | // 1. Assign the tile ID to the value. |
| 736 | assignTileIdToValue(rewriter, value, tileIdAttr); |
| 737 | |
| 738 | // 2. Attempt to eliminate redundant tile copies. |
| 739 | if (succeeded(Result: foldRedundantCopies(value))) |
| 740 | continue; |
| 741 | |
| 742 | // 3. Validate tile block arguments. |
| 743 | if (failed(Result: validateBlockArguments(value))) |
| 744 | return failure(); |
| 745 | |
| 746 | // 4. Attempt to resolve (trivial) tile ID conflicts. |
| 747 | if (failed(Result: resolveTrivialTileConflicts(value))) |
| 748 | return failure(); |
| 749 | } |
| 750 | } |
| 751 | return success(); |
| 752 | } |
| 753 | |
| 754 | /// Prints live ranges alongside operation names for debugging. |
| 755 | void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, |
| 756 | ArrayRef<LiveRange const *> liveRanges, |
| 757 | FunctionOpInterface function) { |
| 758 | llvm::errs() << "SME Tile Liveness: @" << function.getName() |
| 759 | << "\nKey:\nS - Start\nE - End\n| - Live\n" ; |
| 760 | for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) { |
| 761 | llvm::errs() << "^bb" << blockIdx << ":\n" ; |
| 762 | for (Operation &op : block.getOperations()) { |
| 763 | unsigned operationIndex = operationToIndexMap.at(&op); |
| 764 | for (LiveRange const *range : liveRanges) { |
| 765 | char liveness = ' '; |
| 766 | for (auto it = range->ranges->begin(); it != range->ranges->end(); |
| 767 | ++it) { |
| 768 | if (it.start() == operationIndex) |
| 769 | liveness = (liveness == 'E' ? '|' : 'S'); |
| 770 | else if (it.stop() == operationIndex) |
| 771 | liveness = (liveness == 'S' ? '|' : 'E'); |
| 772 | else if (operationIndex >= it.start() && operationIndex < it.stop()) |
| 773 | liveness = '|'; |
| 774 | } |
| 775 | llvm::errs() << liveness; |
| 776 | } |
| 777 | llvm::errs() << ' ' << op.getName() << '\n'; |
| 778 | } |
| 779 | } |
| 780 | llvm::errs() << "==========\n" ; |
| 781 | } |
| 782 | |
| 783 | struct TestTileAllocationPass |
| 784 | : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> { |
| 785 | using TestTileAllocationBase::TestTileAllocationBase; |
| 786 | void runOnOperation() override { |
| 787 | FunctionOpInterface function = getOperation(); |
| 788 | if (preprocessOnly) { |
| 789 | IRRewriter rewriter(function); |
| 790 | return preprocessForTileAllocation(rewriter, function); |
| 791 | } |
| 792 | if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges))) |
| 793 | signalPassFailure(); |
| 794 | } |
| 795 | }; |
| 796 | } // namespace |
| 797 | |
| 798 | LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function, |
| 799 | bool dumpRanges) { |
| 800 | if (function.empty()) { |
| 801 | // TODO: Also return early if the function contains no ArmSME ops? |
| 802 | return success(); |
| 803 | } |
| 804 | |
| 805 | LiveRange::Allocator liveRangeAllocator; |
| 806 | IRRewriter rewriter(function.getContext()); |
| 807 | |
| 808 | // 1. Preprocess the IR for tile allocation. |
| 809 | preprocessForTileAllocation(rewriter, function); |
| 810 | |
| 811 | // 2. Gather live ranges for each ArmSME tile within the function. |
| 812 | Liveness liveness(function); |
| 813 | auto operationToIndexMap = generateOperationNumbering(function); |
| 814 | auto initialLiveRanges = gatherTileLiveRanges( |
| 815 | operationToIndexMap, liveRangeAllocator, liveness, function); |
| 816 | if (initialLiveRanges.empty()) |
| 817 | return success(); |
| 818 | |
| 819 | if (dumpRanges) { |
| 820 | // Wrangle initial live ranges into a form suitable for printing. |
| 821 | auto nonEmpty = llvm::make_filter_range( |
| 822 | llvm::make_second_range(initialLiveRanges), |
| 823 | [&](LiveRange const &liveRange) { return !liveRange.empty(); }); |
| 824 | auto initialRanges = llvm::to_vector(llvm::map_range( |
| 825 | nonEmpty, [](LiveRange const &liveRange) { return &liveRange; })); |
| 826 | llvm::sort(initialRanges, |
| 827 | [](LiveRange const *a, LiveRange const *b) { return *a < *b; }); |
| 828 | llvm::errs() << "\n========== Initial Live Ranges:\n" ; |
| 829 | dumpLiveRanges(operationToIndexMap, initialRanges, function); |
| 830 | } |
| 831 | |
| 832 | // 3. Coalesce (non-overlapping) live ranges where it would be beneficial |
| 833 | // for tile allocation. E.g. Unify the result of an operation with its |
| 834 | // operands. |
| 835 | auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges); |
| 836 | |
| 837 | if (dumpRanges) { |
| 838 | llvm::errs() << "\n========== Coalesced Live Ranges:\n" ; |
| 839 | dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function); |
| 840 | } |
| 841 | |
| 842 | // 4. Allocate tile IDs to live ranges. |
| 843 | allocateTilesToLiveRanges(coalescedLiveRanges); |
| 844 | |
| 845 | // 5. Assign the tile IDs back to the ArmSME operations. |
| 846 | if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function, |
| 847 | coalescedLiveRanges))) { |
| 848 | return failure(); |
| 849 | } |
| 850 | |
| 851 | // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no |
| 852 | // users). This prevents the LLVM conversion needlessly inserting spills. |
| 853 | eraseTriviallyDeadTileOps(rewriter, function); |
| 854 | return success(); |
| 855 | } |
| 856 | |