1 | //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass allocates SME tiles at the 'func.func' op level for ArmSME |
10 | // operations. It does this using a 16-bit tile mask that has a bit for each |
11 | // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule. |
12 | // |
13 | // The 128-bit tiles overlap with other element tiles as follows (see section |
14 | // B2.3.2 of SME spec [1]): |
15 | // |
16 | // Tile Overlaps |
17 | // --------------------------------------------------------------------------- |
18 | // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q, |
19 | // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q |
20 | // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q |
21 | // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q |
22 | // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q |
23 | // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q |
24 | // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q |
25 | // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q |
26 | // ZA0.D ZA0.Q, ZA8.Q |
27 | // ZA1.D ZA1.Q, ZA9.Q |
28 | // ZA2.D ZA2.Q, ZA10.Q |
29 | // ZA3.D ZA3.Q, ZA11.Q |
30 | // ZA4.D ZA4.Q, ZA12.Q |
31 | // ZA5.D ZA5.Q, ZA13.Q |
32 | // ZA6.D ZA6.Q, ZA14.Q |
33 | // ZA7.D ZA7.Q, ZA15.Q |
34 | // |
35 | // The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use' |
36 | // that is initialized during the first tile allocation within a function and |
37 | // updated on each subsequent allocation. |
38 | // |
39 | // [1] https://developer.arm.com/documentation/ddi0616/aa |
40 | // |
41 | //===----------------------------------------------------------------------===// |
42 | |
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
49 | |
50 | #define DEBUG_TYPE "allocate-arm-sme-tiles" |
51 | |
52 | namespace mlir { |
53 | namespace arm_sme { |
54 | #define GEN_PASS_DEF_TILEALLOCATION |
55 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
56 | } // namespace arm_sme |
57 | } // namespace mlir |
58 | |
59 | using namespace mlir; |
60 | using namespace mlir::arm_sme; |
61 | |
62 | namespace { |
63 | |
64 | static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use" ); |
65 | static constexpr StringLiteral |
66 | kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id" ); |
67 | |
68 | enum class TileMask : unsigned { |
69 | // clang-format off |
70 | kZA0B = 0xffff, // 1111 1111 1111 1111 |
71 | |
72 | kZA0H = 0xaaaa, // 1010 1010 1010 1010 |
73 | kZA1H = 0x5555, // 0101 0101 0101 0101 |
74 | |
75 | kZA0S = 0x8888, // 1000 1000 1000 1000 |
76 | kZA1S = 0x4444, // 0100 0100 0100 0100 |
77 | kZA2S = 0x2222, // 0010 0010 0010 0010 |
78 | kZA3S = 0x1111, // 0001 0001 0001 0001 |
79 | |
80 | kZA0D = 0x8080, // 1000 0000 1000 0000 |
81 | kZA1D = 0x4040, // 0100 0000 0100 0000 |
82 | kZA2D = 0x2020, // 0010 0000 0010 0000 |
83 | kZA3D = 0x1010, // 0001 0000 0001 0000 |
84 | kZA4D = 0x808, // 0000 1000 0000 1000 |
85 | kZA5D = 0x404, // 0000 0100 0000 0100 |
86 | kZA6D = 0x202, // 0000 0010 0000 0010 |
87 | kZA7D = 0x101, // 0000 0001 0000 0001 |
88 | |
89 | kZA0Q = 0x8000, // 1000 0000 0000 0000 |
90 | kZA1Q = 0x4000, // 0100 0000 0000 0000 |
91 | kZA2Q = 0x2000, // 0010 0000 0000 0000 |
92 | kZA3Q = 0x1000, // 0001 0000 0000 0000 |
93 | kZA4Q = 0x800, // 0000 1000 0000 0000 |
94 | kZA5Q = 0x400, // 0000 0100 0000 0000 |
95 | kZA6Q = 0x200, // 0000 0010 0000 0000 |
96 | kZA7Q = 0x100, // 0000 0001 0000 0000 |
97 | kZA8Q = 0x80, // 0000 0000 1000 0000 |
98 | kZA9Q = 0x40, // 0000 0000 0100 0000 |
99 | kZA10Q = 0x20, // 0000 0000 0010 0000 |
100 | kZA11Q = 0x10, // 0000 0000 0001 0000 |
101 | kZA12Q = 0x8, // 0000 0000 0000 1000 |
102 | kZA13Q = 0x4, // 0000 0000 0000 0100 |
103 | kZA14Q = 0x2, // 0000 0000 0000 0010 |
104 | kZA15Q = 0x1, // 0000 0000 0000 0001 |
105 | |
106 | kNone = 0x0, // 0000 0000 0000 0000 |
107 | // clang-format on |
108 | |
109 | LLVM_MARK_AS_BITMASK_ENUM(kZA0B) |
110 | }; |
111 | |
112 | /// Returns the set of masks relevant for the given type. |
113 | static ArrayRef<TileMask> getMasks(ArmSMETileType type) { |
114 | static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B}; |
115 | static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H}; |
116 | static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S, |
117 | TileMask::kZA2S, TileMask::kZA3S}; |
118 | static constexpr std::array ZA_D_MASKS = { |
119 | TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D, |
120 | TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D}; |
121 | static constexpr std::array ZA_Q_MASKS = { |
122 | TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q, |
123 | TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q, |
124 | TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q, |
125 | TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q}; |
126 | switch (type) { |
127 | case ArmSMETileType::ZAB: |
128 | return ZA_B_MASKS; |
129 | case ArmSMETileType::ZAH: |
130 | return ZA_H_MASKS; |
131 | case ArmSMETileType::ZAS: |
132 | return ZA_S_MASKS; |
133 | case ArmSMETileType::ZAD: |
134 | return ZA_D_MASKS; |
135 | case ArmSMETileType::ZAQ: |
136 | return ZA_Q_MASKS; |
137 | } |
138 | } |
139 | |
140 | /// Allocates and returns a tile ID. Returns an error if there are no tiles |
141 | /// left. |
142 | static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType, |
143 | TileMask &tilesInUse) { |
144 | auto masks = getMasks(tileType); |
145 | for (auto [tileId, tileMask] : llvm::enumerate(masks)) { |
146 | if ((tilesInUse & tileMask) == TileMask::kNone) { |
147 | tilesInUse |= tileMask; |
148 | return tileId; |
149 | } |
150 | } |
151 | return failure(); |
152 | } |
153 | |
154 | /// Collects transitive uses of a root value through control flow. This can |
155 | /// handle basic SCF constructs, along with control flow (br and cond_br). |
156 | /// Simple loops work at the SCF level, while more complex control flow can be |
157 | /// dealt with after lowering to CF. This is used to implement basic tile |
158 | /// allocation. |
159 | static void findDependantOps(Value rootValue, |
160 | SetVector<Operation *> &dependantOps) { |
161 | auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) { |
162 | for (auto [idx, value] : llvm::enumerate(inputValues)) { |
163 | if (value == rootValue) |
164 | findDependantOps(exitValues[idx], dependantOps); |
165 | } |
166 | }; |
167 | for (Operation *user : rootValue.getUsers()) { |
168 | if (dependantOps.contains(user)) |
169 | continue; |
170 | dependantOps.insert(user); |
171 | TypeSwitch<Operation *>(user) |
172 | .Case<cf::BranchOp>([&](auto branchOp) { |
173 | // (CF) Follow branch. |
174 | traverseCorrespondingValues(branchOp.getDestOperands(), |
175 | branchOp.getDest()->getArguments()); |
176 | }) |
177 | .Case<cf::CondBranchOp>([&](auto condBranchOp) { |
178 | // (CF) Follow true branch. |
179 | traverseCorrespondingValues( |
180 | condBranchOp.getTrueOperands(), |
181 | condBranchOp.getTrueDest()->getArguments()); |
182 | // (CF) Follow false branch. |
183 | traverseCorrespondingValues( |
184 | condBranchOp.getFalseOperands(), |
185 | condBranchOp.getFalseDest()->getArguments()); |
186 | }) |
187 | .Case<LoopLikeOpInterface>([&](auto loopOp) { |
188 | // (SCF) Follow iter_args of (basic) loops (e.g. for loops). |
189 | traverseCorrespondingValues(loopOp.getInits(), |
190 | loopOp.getRegionIterArgs()); |
191 | }) |
192 | .Case<scf::YieldOp>([&](auto yieldOp) { |
193 | // (SCF) Follow yields of (basic) control flow (e.g. for loops). |
194 | auto parent = user->getParentOp(); |
195 | traverseCorrespondingValues(user->getOperands(), |
196 | parent->getResults()); |
197 | }) |
198 | .Default([&](auto) { |
199 | // Otherwise, assume users of _any_ result are dependant. |
200 | for (Value result : user->getResults()) |
201 | findDependantOps(result, dependantOps); |
202 | }); |
203 | } |
204 | } |
205 | struct AssignTileIDsPattern |
206 | : public OpInterfaceRewritePattern<ArmSMETileOpInterface> { |
207 | using OpInterfaceRewritePattern::OpInterfaceRewritePattern; |
208 | LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp, |
209 | PatternRewriter &rewriter) const override { |
210 | if (tileOp.getTileId()) |
211 | return failure(); |
212 | |
213 | auto func = tileOp->getParentOfType<FunctionOpInterface>(); |
214 | auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) { |
215 | if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( |
216 | func->getDiscardableAttr(name))) |
217 | return unsigned(attr.getInt()); |
218 | return defaultVal; |
219 | }; |
220 | auto setDiscardableIntAttr = [&](StringRef name, auto value) { |
221 | rewriter.modifyOpInPlace(tileOp, [&] { |
222 | func->setDiscardableAttr(name, |
223 | rewriter.getI32IntegerAttr((unsigned)value)); |
224 | }); |
225 | }; |
226 | |
227 | std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType(); |
228 | if (!tileType) |
229 | return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile" ); |
230 | |
231 | TileMask tilesInUse = |
232 | static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr)); |
233 | auto tileId = allocateTileId(*tileType, tilesInUse); |
234 | bool tileIsInMemory = failed(tileId); |
235 | if (tileIsInMemory) { |
236 | // If we could not find a real tile ID, use an in-memory tile ID (ID >= |
237 | // 16). A later pass will insert the necessary spills and reloads. |
238 | tileId = |
239 | getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); |
240 | tileOp->emitWarning( |
241 | "failed to allocate SME virtual tile to operation, all tile " |
242 | "operations will go through memory, expect degraded performance" ); |
243 | } |
244 | |
245 | // Set all operations dependent on `tileOp` to use the same tile ID. |
246 | // This is a naive tile allocation scheme, but works for common cases. For |
247 | // example, as this only allocates tile IDs to existing ops, it can't solve |
248 | // cases like this (%tileA and %tileB come from different root operations): |
249 | // |
250 | // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> { |
251 | // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32> |
252 | // } else { |
253 | // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32> |
254 | // } |
255 | // |
256 | // This case would require allocating a new tile for the result of the |
257 | // scf.if, and moving the contents of %tileA or %tileB to result tile (based |
258 | // on the %some_cond). |
259 | // Find all the ops that (transitively) depend on this tile. |
260 | SetVector<Operation *> dependantOps; |
261 | findDependantOps(tileOp->getResult(0), dependantOps); |
262 | auto tileIDAttr = rewriter.getI32IntegerAttr(value: *tileId); |
263 | for (auto *op : dependantOps) { |
264 | if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) { |
265 | auto currentTileId = dependantTileOp.getTileId(); |
266 | if (currentTileId && unsigned(currentTileId.getInt()) != tileId) |
267 | return dependantTileOp.emitOpError( |
268 | "already assigned different SME virtual tile!" ); |
269 | } |
270 | } |
271 | |
272 | // Rewrite IR. |
273 | if (!tileIsInMemory) |
274 | setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); |
275 | else |
276 | setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); |
277 | rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); |
278 | for (auto *op : dependantOps) { |
279 | if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) { |
280 | rewriter.modifyOpInPlace( |
281 | dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); }); |
282 | } |
283 | } |
284 | |
285 | return success(); |
286 | } |
287 | }; |
288 | |
289 | struct TileAllocationPass |
290 | : public arm_sme::impl::TileAllocationBase<TileAllocationPass> { |
291 | void runOnOperation() override { |
292 | RewritePatternSet patterns(&getContext()); |
293 | patterns.add<AssignTileIDsPattern>(patterns.getContext()); |
294 | GreedyRewriteConfig config; |
295 | // Setting useTopDownTraversal ensures tiles are allocated in program |
296 | // order. |
297 | config.useTopDownTraversal = true; |
298 | if (mlir::failed(result: mlir::applyPatternsAndFoldGreedily( |
299 | getOperation(), std::move(patterns), config))) { |
300 | signalPassFailure(); |
301 | } |
302 | } |
303 | }; |
304 | } // namespace |
305 | |
306 | std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() { |
307 | return std::make_unique<TileAllocationPass>(); |
308 | } |
309 | |