1//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass allocates SME tiles at the 'func.func' op level for ArmSME
10// operations. It does this using a 16-bit tile mask that has a bit for each
11// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
12//
13// The 128-bit tiles overlap with other element tiles as follows (see section
14// B2.3.2 of SME spec [1]):
15//
16// Tile Overlaps
17// ---------------------------------------------------------------------------
18// ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
19// ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
20// ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
21// ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
22// ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
23// ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
24// ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
25// ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
26// ZA0.D ZA0.Q, ZA8.Q
27// ZA1.D ZA1.Q, ZA9.Q
28// ZA2.D ZA2.Q, ZA10.Q
29// ZA3.D ZA3.Q, ZA11.Q
30// ZA4.D ZA4.Q, ZA12.Q
31// ZA5.D ZA5.Q, ZA13.Q
32// ZA6.D ZA6.Q, ZA14.Q
33// ZA7.D ZA7.Q, ZA15.Q
34//
35// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
// that is initialized during the first tile allocation within a function and
37// updated on each subsequent allocation.
38//
39// [1] https://developer.arm.com/documentation/ddi0616/aa
40//
41//===----------------------------------------------------------------------===//
42
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
49
50#define DEBUG_TYPE "allocate-arm-sme-tiles"
51
52namespace mlir {
53namespace arm_sme {
54#define GEN_PASS_DEF_TILEALLOCATION
55#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
56} // namespace arm_sme
57} // namespace mlir
58
59using namespace mlir;
60using namespace mlir::arm_sme;
61
62namespace {
63
64static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
65static constexpr StringLiteral
66 kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
67
68enum class TileMask : unsigned {
69 // clang-format off
70 kZA0B = 0xffff, // 1111 1111 1111 1111
71
72 kZA0H = 0xaaaa, // 1010 1010 1010 1010
73 kZA1H = 0x5555, // 0101 0101 0101 0101
74
75 kZA0S = 0x8888, // 1000 1000 1000 1000
76 kZA1S = 0x4444, // 0100 0100 0100 0100
77 kZA2S = 0x2222, // 0010 0010 0010 0010
78 kZA3S = 0x1111, // 0001 0001 0001 0001
79
80 kZA0D = 0x8080, // 1000 0000 1000 0000
81 kZA1D = 0x4040, // 0100 0000 0100 0000
82 kZA2D = 0x2020, // 0010 0000 0010 0000
83 kZA3D = 0x1010, // 0001 0000 0001 0000
84 kZA4D = 0x808, // 0000 1000 0000 1000
85 kZA5D = 0x404, // 0000 0100 0000 0100
86 kZA6D = 0x202, // 0000 0010 0000 0010
87 kZA7D = 0x101, // 0000 0001 0000 0001
88
89 kZA0Q = 0x8000, // 1000 0000 0000 0000
90 kZA1Q = 0x4000, // 0100 0000 0000 0000
91 kZA2Q = 0x2000, // 0010 0000 0000 0000
92 kZA3Q = 0x1000, // 0001 0000 0000 0000
93 kZA4Q = 0x800, // 0000 1000 0000 0000
94 kZA5Q = 0x400, // 0000 0100 0000 0000
95 kZA6Q = 0x200, // 0000 0010 0000 0000
96 kZA7Q = 0x100, // 0000 0001 0000 0000
97 kZA8Q = 0x80, // 0000 0000 1000 0000
98 kZA9Q = 0x40, // 0000 0000 0100 0000
99 kZA10Q = 0x20, // 0000 0000 0010 0000
100 kZA11Q = 0x10, // 0000 0000 0001 0000
101 kZA12Q = 0x8, // 0000 0000 0000 1000
102 kZA13Q = 0x4, // 0000 0000 0000 0100
103 kZA14Q = 0x2, // 0000 0000 0000 0010
104 kZA15Q = 0x1, // 0000 0000 0000 0001
105
106 kNone = 0x0, // 0000 0000 0000 0000
107 // clang-format on
108
109 LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
110};
111
112/// Returns the set of masks relevant for the given type.
113static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
114 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
115 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
116 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117 TileMask::kZA2S, TileMask::kZA3S};
118 static constexpr std::array ZA_D_MASKS = {
119 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
121 static constexpr std::array ZA_Q_MASKS = {
122 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
126 switch (type) {
127 case ArmSMETileType::ZAB:
128 return ZA_B_MASKS;
129 case ArmSMETileType::ZAH:
130 return ZA_H_MASKS;
131 case ArmSMETileType::ZAS:
132 return ZA_S_MASKS;
133 case ArmSMETileType::ZAD:
134 return ZA_D_MASKS;
135 case ArmSMETileType::ZAQ:
136 return ZA_Q_MASKS;
137 }
138}
139
140/// Allocates and returns a tile ID. Returns an error if there are no tiles
141/// left.
142static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
143 TileMask &tilesInUse) {
144 auto masks = getMasks(tileType);
145 for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
146 if ((tilesInUse & tileMask) == TileMask::kNone) {
147 tilesInUse |= tileMask;
148 return tileId;
149 }
150 }
151 return failure();
152}
153
154/// Collects transitive uses of a root value through control flow. This can
155/// handle basic SCF constructs, along with control flow (br and cond_br).
156/// Simple loops work at the SCF level, while more complex control flow can be
157/// dealt with after lowering to CF. This is used to implement basic tile
158/// allocation.
159static void findDependantOps(Value rootValue,
160 SetVector<Operation *> &dependantOps) {
161 auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
162 for (auto [idx, value] : llvm::enumerate(inputValues)) {
163 if (value == rootValue)
164 findDependantOps(exitValues[idx], dependantOps);
165 }
166 };
167 for (Operation *user : rootValue.getUsers()) {
168 if (dependantOps.contains(user))
169 continue;
170 dependantOps.insert(user);
171 TypeSwitch<Operation *>(user)
172 .Case<cf::BranchOp>([&](auto branchOp) {
173 // (CF) Follow branch.
174 traverseCorrespondingValues(branchOp.getDestOperands(),
175 branchOp.getDest()->getArguments());
176 })
177 .Case<cf::CondBranchOp>([&](auto condBranchOp) {
178 // (CF) Follow true branch.
179 traverseCorrespondingValues(
180 condBranchOp.getTrueOperands(),
181 condBranchOp.getTrueDest()->getArguments());
182 // (CF) Follow false branch.
183 traverseCorrespondingValues(
184 condBranchOp.getFalseOperands(),
185 condBranchOp.getFalseDest()->getArguments());
186 })
187 .Case<LoopLikeOpInterface>([&](auto loopOp) {
188 // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
189 traverseCorrespondingValues(loopOp.getInits(),
190 loopOp.getRegionIterArgs());
191 })
192 .Case<scf::YieldOp>([&](auto yieldOp) {
193 // (SCF) Follow yields of (basic) control flow (e.g. for loops).
194 auto parent = user->getParentOp();
195 traverseCorrespondingValues(user->getOperands(),
196 parent->getResults());
197 })
198 .Default([&](auto) {
199 // Otherwise, assume users of _any_ result are dependant.
200 for (Value result : user->getResults())
201 findDependantOps(result, dependantOps);
202 });
203 }
204}
205struct AssignTileIDsPattern
206 : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
207 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
208 LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
209 PatternRewriter &rewriter) const override {
210 if (tileOp.getTileId())
211 return failure();
212
213 auto func = tileOp->getParentOfType<FunctionOpInterface>();
214 auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
215 if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
216 func->getDiscardableAttr(name)))
217 return unsigned(attr.getInt());
218 return defaultVal;
219 };
220 auto setDiscardableIntAttr = [&](StringRef name, auto value) {
221 rewriter.modifyOpInPlace(tileOp, [&] {
222 func->setDiscardableAttr(name,
223 rewriter.getI32IntegerAttr((unsigned)value));
224 });
225 };
226
227 std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
228 if (!tileType)
229 return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
230
231 TileMask tilesInUse =
232 static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
233 auto tileId = allocateTileId(*tileType, tilesInUse);
234 bool tileIsInMemory = failed(tileId);
235 if (tileIsInMemory) {
236 // If we could not find a real tile ID, use an in-memory tile ID (ID >=
237 // 16). A later pass will insert the necessary spills and reloads.
238 tileId =
239 getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
240 tileOp->emitWarning(
241 "failed to allocate SME virtual tile to operation, all tile "
242 "operations will go through memory, expect degraded performance");
243 }
244
245 // Set all operations dependent on `tileOp` to use the same tile ID.
246 // This is a naive tile allocation scheme, but works for common cases. For
247 // example, as this only allocates tile IDs to existing ops, it can't solve
248 // cases like this (%tileA and %tileB come from different root operations):
249 //
250 // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
251 // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
252 // } else {
253 // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
254 // }
255 //
256 // This case would require allocating a new tile for the result of the
257 // scf.if, and moving the contents of %tileA or %tileB to result tile (based
258 // on the %some_cond).
259 // Find all the ops that (transitively) depend on this tile.
260 SetVector<Operation *> dependantOps;
261 findDependantOps(tileOp->getResult(0), dependantOps);
262 auto tileIDAttr = rewriter.getI32IntegerAttr(value: *tileId);
263 for (auto *op : dependantOps) {
264 if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
265 auto currentTileId = dependantTileOp.getTileId();
266 if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
267 return dependantTileOp.emitOpError(
268 "already assigned different SME virtual tile!");
269 }
270 }
271
272 // Rewrite IR.
273 if (!tileIsInMemory)
274 setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
275 else
276 setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
277 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
278 for (auto *op : dependantOps) {
279 if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
280 rewriter.modifyOpInPlace(
281 dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
282 }
283 }
284
285 return success();
286 }
287};
288
289struct TileAllocationPass
290 : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
291 void runOnOperation() override {
292 RewritePatternSet patterns(&getContext());
293 patterns.add<AssignTileIDsPattern>(patterns.getContext());
294 GreedyRewriteConfig config;
295 // Setting useTopDownTraversal ensures tiles are allocated in program
296 // order.
297 config.useTopDownTraversal = true;
298 if (mlir::failed(result: mlir::applyPatternsAndFoldGreedily(
299 getOperation(), std::move(patterns), config))) {
300 signalPassFailure();
301 }
302 }
303};
304} // namespace
305
306std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
307 return std::make_unique<TileAllocationPass>();
308}
309

source code of mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp