//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

35 | static LogicalResult |
36 | checkOperandAffineExprRecursively(AffineExpr expr, |
37 | SmallVectorImpl<bool> &seenIds) { |
38 | switch (expr.getKind()) { |
39 | case AffineExprKind::Add: { |
40 | auto binOpExpr = cast<AffineBinaryOpExpr>(expr); |
41 | AffineExpr lhs = binOpExpr.getLHS(); |
42 | AffineExpr rhs = binOpExpr.getRHS(); |
43 | if (failed(Result: checkOperandAffineExprRecursively(expr: lhs, seenIds))) |
44 | return failure(); |
45 | if (failed(Result: checkOperandAffineExprRecursively(expr: rhs, seenIds))) |
46 | return failure(); |
47 | return success(); |
48 | } |
49 | case AffineExprKind::Mul: { |
50 | auto binOpExpr = cast<AffineBinaryOpExpr>(expr); |
51 | AffineExpr lhs = binOpExpr.getLHS(); |
52 | AffineExpr rhs = binOpExpr.getRHS(); |
53 | AffineExpr dimExpr; |
54 | if (lhs.getKind() == AffineExprKind::DimId && |
55 | rhs.getKind() == AffineExprKind::Constant) { |
56 | dimExpr = lhs; |
57 | } else if (rhs.getKind() == AffineExprKind::DimId && |
58 | lhs.getKind() == AffineExprKind::Constant) { |
59 | dimExpr = rhs; |
60 | } else { |
61 | return failure(); |
62 | } |
63 | unsigned position = cast<AffineDimExpr>(dimExpr).getPosition(); |
64 | if ((size_t)position >= seenIds.size() || seenIds[position]) |
65 | return failure(); |
66 | seenIds[position] = true; |
67 | return success(); |
68 | } |
69 | case AffineExprKind::DimId: { |
70 | unsigned position = cast<AffineDimExpr>(expr).getPosition(); |
71 | if ((size_t)position >= seenIds.size() || seenIds[position]) |
72 | return failure(); |
73 | seenIds[position] = true; |
74 | return success(); |
75 | } |
76 | default: |
77 | return failure(); |
78 | } |
79 | } |
80 | |
81 | static FailureOr<llvm::SmallSet<unsigned, 2>> |
82 | checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { |
83 | SmallVector<bool> seenIds(numDims, false); |
84 | if (failed(Result: checkOperandAffineExprRecursively(expr, seenIds))) |
85 | return failure(); |
86 | |
87 | llvm::SmallSet<unsigned, 2> positions; |
88 | for (auto it : llvm::enumerate(First&: seenIds)) { |
89 | if (it.value()) |
90 | positions.insert(V: (unsigned)it.index()); |
91 | } |
92 | return positions; |
93 | } |
94 | |
95 | template <typename T> |
96 | SmallVector<MeshAxesAttr> |
97 | fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) { |
98 | SmallVector<MeshAxesAttr> res; |
99 | for (const auto &v : vec) { |
100 | res.emplace_back(MeshAxesAttr::get(ctxt, v)); |
101 | } |
102 | return res; |
103 | } |
104 | |
105 | //===----------------------------------------------------------------------===// |
106 | // mesh::getMeshSharding |
107 | //===----------------------------------------------------------------------===// |
108 | |
109 | FailureOr<std::pair<bool, MeshSharding>> |
110 | mesh::getMeshSharding(OpResult result) { |
111 | Value val = cast<Value>(Val&: result); |
112 | bool anyShardedForDef = llvm::any_of(Range: val.getUsers(), P: [](Operation *user) { |
113 | auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); |
114 | if (!shardOp) |
115 | return false; |
116 | return !shardOp.getAnnotateForUsers(); |
117 | }); |
118 | |
119 | if (anyShardedForDef) { |
120 | // expected to have exact one use if it has a use of `mesh.shard` without |
121 | // unit attr annotate_for_users |
122 | if (!val.hasOneUse()) |
123 | return failure(); |
124 | auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin()); |
125 | return std::make_pair(x: false, y: MeshSharding(shardOp.getSharding())); |
126 | } |
127 | |
128 | bool anyShardedForUsers = llvm::any_of(Range: val.getUsers(), P: [](Operation *user) { |
129 | auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); |
130 | if (!shardOp) |
131 | return false; |
132 | return shardOp.getAnnotateForUsers(); |
133 | }); |
134 | if (anyShardedForUsers) { |
135 | SmallVector<ShardOp> shardOps; |
136 | for (Operation *user : val.getUsers()) { |
137 | ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); |
138 | if (shardOp) |
139 | shardOps.push_back(shardOp); |
140 | } |
141 | MeshSharding shardForDef = shardOps[0].getSharding(); |
142 | for (size_t i = 1; i < shardOps.size(); ++i) { |
143 | // TODO: Deduce a reasonable mesh sharding attr for def when they are |
144 | // different |
145 | assert(shardForDef == shardOps[i].getSharding() && |
146 | "only support all shard ops have the same mesh sharding attr"); |
147 | } |
148 | return std::make_pair(x: true, y&: shardForDef); |
149 | } |
150 | return failure(); |
151 | } |
152 | |
153 | FailureOr<std::pair<bool, MeshSharding>> |
154 | mesh::getMeshSharding(OpOperand &opOperand) { |
155 | Value val = opOperand.get(); |
156 | if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) |
157 | return std::make_pair(shardOp.getAnnotateForUsers(), |
158 | MeshSharding(shardOp.getSharding())); |
159 | |
160 | return failure(); |
161 | } |
162 | |
163 | //===----------------------------------------------------------------------===// |
164 | // ShardingInterface::verifyShardingInterfaceImpl |
165 | //===----------------------------------------------------------------------===// |
166 | |
167 | LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { |
168 | Operation *op = getOperation(); |
169 | |
170 | // check operands and results type |
171 | for (Type type : op->getOperandTypes()) |
172 | if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat()) |
173 | return failure(); |
174 | for (Type type : op->getResultTypes()) |
175 | if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat()) |
176 | return failure(); |
177 | |
178 | // check maps |
179 | SmallVector<AffineMap> maps = getIndexingMaps(); |
180 | if (maps.empty()) |
181 | return failure(); |
182 | unsigned numOperands = op->getNumOperands(); |
183 | unsigned numResults = op->getNumResults(); |
184 | if (numOperands + numResults != maps.size()) |
185 | return failure(); |
186 | |
187 | for (OpResult result : op->getResults()) { |
188 | auto resultType = dyn_cast<RankedTensorType>(result.getType()); |
189 | if (!resultType) |
190 | return failure(); |
191 | AffineMap map = maps[numOperands + result.getResultNumber()]; |
192 | if (!map.isProjectedPermutation()) { |
193 | return failure(); |
194 | } |
195 | } |
196 | |
197 | return success(); |
198 | } |
199 | |
200 | //===----------------------------------------------------------------------===// |
201 | // ShardingInterface::printLoopTypesAndIndexingMaps |
202 | //===----------------------------------------------------------------------===// |
203 | |
204 | void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { |
205 | os << "print loop types and indexing maps for: \n"; |
206 | getOperation()->print(os); |
207 | os << "\n"; |
208 | os << "loop types: ["; |
209 | for (utils::IteratorType type : getLoopIteratorTypes()) { |
210 | os << stringifyEnum(type) << " "; |
211 | } |
212 | os << "]\n"; |
213 | os << "indexing maps: \n"; |
214 | for (AffineMap map : getIndexingMaps()) |
215 | os << map << "\n"; |
216 | os << "\n"; |
217 | } |
218 | |
219 | //===----------------------------------------------------------------------===// |
220 | // detail::defaultGetShardingOption |
221 | //===----------------------------------------------------------------------===// |
222 | |
223 | namespace { |
224 | |
225 | // Update the given `shardingOption` according to `meshAxes` and `loopIdx` |
226 | static LogicalResult fillShardingOption(Operation *op, |
227 | ShardingOption &shardingOption, |
228 | FlatSymbolRefAttr mesh, |
229 | ArrayRef<MeshAxis> meshAxes, |
230 | unsigned loopIdx) { |
231 | if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || |
232 | (!shardingOption.shardingArray[loopIdx].empty() && |
233 | shardingOption.shardingArray[loopIdx] != meshAxes)) { |
234 | LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " |
235 | << loopIdx << "\n"); |
236 | return failure(); |
237 | } |
238 | for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) { |
239 | if (i == loopIdx) |
240 | continue; |
241 | |
242 | for (MeshAxis axis : meshAxes) { |
243 | if (llvm::is_contained(Range&: shardingOption.shardingArray[i], Element: axis)) { |
244 | LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes " |
245 | << axis << " duplicate"); |
246 | return failure(); |
247 | } |
248 | } |
249 | } |
250 | if (mesh) |
251 | shardingOption.mesh = mesh; |
252 | if (shardingOption.shardingArray[loopIdx].empty()) |
253 | shardingOption.shardingArray[loopIdx].append(in_start: meshAxes.begin(), |
254 | in_end: meshAxes.end()); |
255 | return success(); |
256 | } |
257 | |
258 | } // namespace |
259 | |
260 | FailureOr<ShardingOption> |
261 | mesh::detail::defaultGetShardingOption(Operation *op, |
262 | ArrayRef<MeshSharding> operandShardings, |
263 | ArrayRef<MeshSharding> resultShardings) { |
264 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
265 | ShardingOption shardingOption; |
266 | |
267 | if (failed(shardingOp.verifyShardingInterfaceImpl())) |
268 | return op->emitOpError() << "invalid sharding interface implementation"; |
269 | SmallVector<utils::IteratorType> loopTypes = |
270 | shardingOp.getLoopIteratorTypes(); |
271 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
272 | unsigned numOperands = op->getNumOperands(); |
273 | shardingOption.shardingArray.resize(loopTypes.size()); |
274 | llvm::SmallVector<MeshAxis> partialMeshAxes; |
275 | llvm::SmallSet<unsigned, 4> visitedLoopIndices; |
276 | bool anyShardingInResultsOrOperands = false; |
277 | |
278 | // 1. Fill sharding option based on op results |
279 | for (auto shardingIt : llvm::enumerate(First&: resultShardings)) { |
280 | MeshSharding shardAttr = shardingIt.value(); |
281 | if (!shardAttr) |
282 | continue; |
283 | AffineMap map = maps[numOperands + shardingIt.index()]; |
284 | anyShardingInResultsOrOperands = true; |
285 | if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { |
286 | shardingOption.mesh = shardAttr.getMeshAttr(); |
287 | } else { |
288 | // Handle the split axes: calculate the corresponding loop index for each |
289 | // split axes sub-array, and then store the sub-array to |
290 | // shardingOption[index] |
291 | for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { |
292 | AffineExpr expr = std::get<0>(it); |
293 | ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); |
294 | auto dim = cast<AffineDimExpr>(expr); |
295 | unsigned index = dim.getPosition(); |
296 | visitedLoopIndices.insert(index); |
297 | if (failed(fillShardingOption(op, shardingOption, |
298 | shardAttr.getMeshAttr(), axes, index))) |
299 | return failure(); |
300 | } |
301 | } |
302 | |
303 | // Handle the partial axes: at this stage, the exact loop index/indices |
304 | // cannot be decided because there could be multiple reduction loops. |
305 | ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes(); |
306 | if (!partialAxes.empty()) { |
307 | if (!partialMeshAxes.empty()) |
308 | return op->emitOpError() << "at most one result with partial axes is " |
309 | "supported at present"; |
310 | partialMeshAxes.append(in_start: partialAxes.begin(), in_end: partialAxes.end()); |
311 | // Add all the reduction loop indices to `visitedLoopIndices` if |
312 | // `partialAxes` is not empty |
313 | for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) { |
314 | if (isReductionLoop(loopTypes[loopIdx])) |
315 | visitedLoopIndices.insert(V: loopIdx); |
316 | } |
317 | } |
318 | } |
319 | |
320 | // 2. Fill sharding option based on operands |
321 | for (auto shardingIt : llvm::enumerate(First&: operandShardings)) { |
322 | MeshSharding shardAttr = shardingIt.value(); |
323 | if (!shardAttr) |
324 | continue; |
325 | |
326 | anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty(); |
327 | AffineMap map = maps[shardingIt.index()]; |
328 | unsigned numDims = map.getNumDims(); |
329 | |
330 | // Handle the split axes. Partial axes don't need to be handled because they |
331 | // only affect the defining op of the operand. |
332 | // |
333 | // TODO: Change to process the operands with single loop index first and |
334 | // then the operands with multiple loop indices. |
335 | for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { |
336 | AffineExpr expr = std::get<0>(it); |
337 | ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); |
338 | FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = |
339 | checkOperandAffineExpr(expr, numDims); |
340 | if (failed(loopIndices)) |
341 | return op->emitOpError() |
342 | << "operand's affine expression is restricted to const_i * " |
343 | "dim_i + const_j + dim_j + ..."; |
344 | if (loopIndices->empty()) |
345 | continue; |
346 | if (loopIndices->size() == 1) { |
347 | unsigned loopIdx = *loopIndices->begin(); |
348 | visitedLoopIndices.insert(loopIdx); |
349 | if (failed(fillShardingOption(op, shardingOption, |
350 | shardAttr.getMeshAttr(), axes, loopIdx))) |
351 | return failure(); |
352 | } |
353 | // If multiple loop indices correspond to a dimension of an operand, it is |
354 | // difficult to infer which loop indices are responsible for sharding. |
355 | // Therefore, the exact loop index must be specified by others. |
356 | if (loopIndices->size() > 1) { |
357 | bool seenLoopIndices = false; |
358 | for (unsigned loopIdx : *loopIndices) { |
359 | if (visitedLoopIndices.contains(loopIdx)) { |
360 | seenLoopIndices = true; |
361 | break; |
362 | } |
363 | } |
364 | if (!seenLoopIndices) |
365 | return op->emitOpError() |
366 | << "the operand "<< shardingIt.index() |
367 | << " has multiple loop indices in a dimension, but none of " |
368 | "them could be found in the exactly specified annotation " |
369 | "of op results or operands."; |
370 | } |
371 | } |
372 | } |
373 | |
374 | // 3. Finalize sharding option |
375 | if (!partialMeshAxes.empty()) { |
376 | bool anyNonEmptyReductionLoop = llvm::any_of( |
377 | Range: llvm::enumerate(First&: shardingOption.shardingArray), P: [&](auto it) { |
378 | SmallVector<MeshAxis> &subArray = it.value(); |
379 | int64_t idx = it.index(); |
380 | return isReductionLoop(loopTypes[idx]) && !subArray.empty(); |
381 | }); |
382 | if (!anyNonEmptyReductionLoop) { |
383 | bool filled = false; |
384 | for (size_t idx = 0; idx < loopTypes.size(); ++idx) { |
385 | if (isReductionLoop(loopTypes[idx])) { |
386 | std::ignore = fillShardingOption(op, shardingOption, nullptr, |
387 | partialMeshAxes, idx); |
388 | filled = true; |
389 | break; |
390 | } |
391 | } |
392 | if (!filled) |
393 | return op->emitOpError() << "no matched reduction loop found for the " |
394 | "result's partial type"; |
395 | } |
396 | } |
397 | removeTrailingEmptySubArray(array&: shardingOption.shardingArray); |
398 | if (!anyShardingInResultsOrOperands) |
399 | shardingOption.empty = true; |
400 | return shardingOption; |
401 | } |
402 | |
403 | // Get the sharding attributed for the given result and sharding option. |
404 | MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, |
405 | AffineMap map, ArrayRef<utils::IteratorType> loopTypes, |
406 | ArrayRef<ReductionKind> reductionLoopKinds) { |
407 | auto resultType = cast<RankedTensorType>(result.getType()); |
408 | SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank()); |
409 | SmallVector<MeshAxis> partialAxes; |
410 | |
411 | // process the split axes |
412 | for (auto it : llvm::enumerate(First: map.getResults())) { |
413 | AffineExpr expr = it.value(); |
414 | // `expr` must be an `AffineDimExpr` because `map` is verified by |
415 | // isProjectedPermutation |
416 | auto dim = cast<AffineDimExpr>(Val&: expr); |
417 | unsigned loopIdx = dim.getPosition(); |
418 | if (loopIdx < shardingOption.shardingArray.size()) |
419 | splitAxes[it.index()].append(RHS: shardingOption.shardingArray[loopIdx]); |
420 | } |
421 | |
422 | // process the partial axes |
423 | // partialType will be ignored if partialAxes is empty |
424 | ReductionKind partialType = ReductionKind::Sum; |
425 | size_t reductionLoopKindsIdx = 0; |
426 | for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) { |
427 | utils::IteratorType iType = std::get<0>(it); |
428 | if (isReductionLoop(iType)) { |
429 | ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx]; |
430 | ++reductionLoopKindsIdx; |
431 | if (!partialAxes.empty()) |
432 | assert(partialType == curPartialType && |
433 | "Only one reduction type is supported"); |
434 | partialType = curPartialType; |
435 | const SmallVector<MeshAxis> &axis = std::get<1>(it); |
436 | partialAxes.append(axis); |
437 | } |
438 | } |
439 | |
440 | removeTrailingEmptySubArray(array&: splitAxes); |
441 | return MeshSharding::get(shardingOption.mesh, |
442 | fromArrayOfVector(result.getContext(), splitAxes), |
443 | partialAxes, partialType); |
444 | } |
445 | |
446 | static FailureOr<MeshSharding> getSharding(OpOperand &opOperand, |
447 | const ShardingOption &shardingOption, |
448 | AffineMap map) { |
449 | Value operandValue = opOperand.get(); |
450 | auto operandType = dyn_cast<RankedTensorType>(operandValue.getType()); |
451 | if (!operandType) { |
452 | if (operandValue.getType().isIntOrIndexOrFloat()) |
453 | return MeshSharding(); |
454 | return failure(); |
455 | } |
456 | // 0d tensors cannot be sharded and must get replicated |
457 | if (operandType.getRank() == 0) { |
458 | return MeshSharding(shardingOption.mesh); |
459 | } |
460 | SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank()); |
461 | unsigned numDims = map.getNumDims(); |
462 | for (auto it : llvm::enumerate(First: map.getResults())) { |
463 | int64_t idx = it.index(); |
464 | AffineExpr expr = it.value(); |
465 | FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = |
466 | checkOperandAffineExpr(expr, numDims); |
467 | if (failed(Result: loopIndices)) |
468 | return failure(); |
469 | SmallVector<unsigned> shardedLoopIndices; |
470 | for (unsigned loopIdx : *loopIndices) { |
471 | if ((size_t)loopIdx < shardingOption.shardingArray.size() && |
472 | !shardingOption.shardingArray[loopIdx].empty()) |
473 | shardedLoopIndices.push_back(Elt: loopIdx); |
474 | } |
475 | // mostly one sharded loop index is accepted |
476 | if (shardedLoopIndices.size() > 1) |
477 | return failure(); |
478 | if (shardedLoopIndices.size() == 1) { |
479 | splitAxes[idx].append( |
480 | RHS: shardingOption.shardingArray[shardedLoopIndices[0]]); |
481 | } |
482 | } |
483 | |
484 | removeTrailingEmptySubArray(array&: splitAxes); |
485 | return MeshSharding::get( |
486 | shardingOption.mesh, |
487 | fromArrayOfVector(opOperand.get().getContext(), splitAxes)); |
488 | } |
489 | |
490 | FailureOr<std::vector<MeshSharding>> |
491 | mesh::detail::defaultGetShardingAnnotations( |
492 | Operation *op, const ShardingOption &shardingOption) { |
493 | std::vector<MeshSharding> res; |
494 | |
495 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
496 | SmallVector<utils::IteratorType> loopTypes = |
497 | shardingOp.getLoopIteratorTypes(); |
498 | SmallVector<ReductionKind> reductionKinds = |
499 | shardingOp.getReductionLoopIteratorKinds(); |
500 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
501 | unsigned numOperands = op->getNumOperands(); |
502 | |
503 | for (OpOperand &opOperand : op->getOpOperands()) { |
504 | FailureOr<MeshSharding> shardingAttr = getSharding( |
505 | opOperand, shardingOption, maps[opOperand.getOperandNumber()]); |
506 | if (failed(Result: shardingAttr)) |
507 | return failure(); |
508 | res.push_back(*shardingAttr); |
509 | } |
510 | |
511 | for (OpResult result : op->getResults()) { |
512 | res.push_back(getSharding(result, shardingOption, |
513 | maps[numOperands + result.getResultNumber()], |
514 | loopTypes, reductionKinds)); |
515 | } |
516 | |
517 | return res; |
518 | } |
519 | |
520 | //===----------------------------------------------------------------------===// |
521 | // detail::defaultAddShardingAnnotations |
522 | //===----------------------------------------------------------------------===// |
523 | |
524 | // To add a `mesh.shard` op for the given result, based on the details provided |
525 | // in `shardingOption`, `map`, and `loopTypes`. |
526 | static LogicalResult addShardOp(OpBuilder &b, OpResult result, |
527 | const ShardingOption &shardingOption, |
528 | AffineMap map, |
529 | ArrayRef<utils::IteratorType> loopTypes, |
530 | ArrayRef<ReductionKind> reductionLoopKinds) { |
531 | MeshSharding sharding = |
532 | getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds); |
533 | maybeInsertTargetShardingAnnotation(sharding, result, builder&: b); |
534 | |
535 | return success(); |
536 | } |
537 | |
538 | // To add a `mesh.shard` op for the given operand, based on the details provided |
539 | // in `shardingOption`, `map`, and `loopTypes`. |
540 | static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, |
541 | const ShardingOption &shardingOption, |
542 | AffineMap map) { |
543 | |
544 | FailureOr<MeshSharding> sharding = |
545 | getSharding(opOperand, shardingOption, map); |
546 | if (failed(Result: sharding)) { |
547 | return failure(); |
548 | } |
549 | OpBuilder::InsertionGuard guard(b); |
550 | maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b); |
551 | |
552 | return success(); |
553 | } |
554 | |
555 | LogicalResult mesh::detail::defaultAddShardingAnnotations( |
556 | Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { |
557 | assert(!shardingOption.empty && shardingOption.mesh); |
558 | |
559 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
560 | SmallVector<utils::IteratorType> loopTypes = |
561 | shardingOp.getLoopIteratorTypes(); |
562 | SmallVector<ReductionKind> reductionKinds = |
563 | shardingOp.getReductionLoopIteratorKinds(); |
564 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
565 | unsigned numOperands = op->getNumOperands(); |
566 | |
567 | // 1. add mesh.shard ops for all op results |
568 | for (OpResult result : op->getResults()) { |
569 | if (failed(addShardOp(b, result, shardingOption, |
570 | maps[numOperands + result.getResultNumber()], |
571 | loopTypes, reductionKinds))) |
572 | return failure(); |
573 | } |
574 | |
575 | // 2. add mesh.shard ops for all operands |
576 | for (OpOperand &opOperand : op->getOpOperands()) { |
577 | if (failed(Result: addShardOp(b, opOperand, shardingOption, |
578 | map: maps[opOperand.getOperandNumber()]))) |
579 | return failure(); |
580 | } |
581 | |
582 | return success(); |
583 | } |
584 | |
585 | #ifndef NDEBUG |
586 | static bool |
587 | isValueCompatibleWithFullReplicationSharding(Value value, |
588 | MeshSharding sharding) { |
589 | if (isa<RankedTensorType>(Val: value.getType())) { |
590 | return isFullReplication(sharding); |
591 | } |
592 | |
593 | return !sharding; |
594 | } |
595 | |
596 | template <typename ValueRange, typename MeshShardingRage> |
597 | static bool |
598 | areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, |
599 | MeshShardingRage &&shardings) { |
600 | if (std::size(values) != std::size(shardings)) { |
601 | return false; |
602 | } |
603 | return llvm::all_of( |
604 | llvm::zip_equal(std::forward<ValueRange>(values), |
605 | std::forward<MeshShardingRage>(shardings)), |
606 | [](auto valueAndSharding) { |
607 | return isValueCompatibleWithFullReplicationSharding( |
608 | std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); |
609 | }); |
610 | } |
611 | #endif // NDEBUG |
612 | |
613 | void mesh::spmdizeFullyReplicatedOperation( |
614 | Operation &op, ArrayRef<Value> spmdizedOperands, |
615 | ArrayRef<MeshSharding> operandShardings, |
616 | ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, |
617 | SymbolTableCollection &symbolTable, OpBuilder &builder) { |
618 | assert(spmdizedOperands.size() == operandShardings.size()); |
619 | assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), |
620 | operandShardings)); |
621 | assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), |
622 | resultShardings)); |
623 | // `clone` will populate the mapping of old to new results. |
624 | builder.clone(op, mapper&: spmdizationMap); |
625 | } |
626 | |
627 | static void updateMeshAxisAssignmentForLoopIterators( |
628 | ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, |
629 | SmallVector<std::optional<SmallVector<MeshAxis>>> |
630 | &meshAxesAssignmentForLoopIterators) { |
631 | AffineDimExpr affineDimExpr = cast<AffineDimExpr>(Val&: indexingExpr); |
632 | unsigned loopIteratorIdx = affineDimExpr.getPosition(); |
633 | if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { |
634 | assert(llvm::equal(meshAxesAssignmentForTensorAxis, |
635 | *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); |
636 | } else { |
637 | meshAxesAssignmentForLoopIterators[loopIteratorIdx] = |
638 | llvm::to_vector(Range&: meshAxesAssignmentForTensorAxis); |
639 | } |
640 | } |
641 | |
642 | ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( |
643 | ArrayRef<MeshSharding> operandShardings, |
644 | ArrayRef<MeshSharding> resultShardings, |
645 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
646 | ArrayRef<AffineMap> indexingMaps) { |
647 | SmallVector<std::optional<SmallVector<MeshAxis>>> |
648 | meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); |
649 | std::vector<MeshSharding> operatorAndResultShardings; |
650 | operatorAndResultShardings.reserve(n: operandShardings.size() + |
651 | resultShardings.size()); |
652 | llvm::append_range(C&: operatorAndResultShardings, R&: operandShardings); |
653 | for (auto [sharding, affineMap] : |
654 | llvm::zip_equal(t&: operatorAndResultShardings, u&: indexingMaps)) { |
655 | if (!sharding) { |
656 | continue; |
657 | } |
658 | for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : |
659 | llvm::zip(t: sharding.getSplitAxes(), u: affineMap.getResults())) { |
660 | updateMeshAxisAssignmentForLoopIterators( |
661 | meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, |
662 | meshAxisAssignmentForLoopIterators); |
663 | } |
664 | // Missing trailing split axes means replication on those tensor dimensions. |
665 | for (unsigned i = sharding.getSplitAxes().size(); |
666 | i < affineMap.getNumResults(); ++i) { |
667 | updateMeshAxisAssignmentForLoopIterators( |
668 | meshAxesAssignmentForTensorAxis: {}, indexingExpr: affineMap.getResults()[i], meshAxesAssignmentForLoopIterators&: meshAxisAssignmentForLoopIterators); |
669 | } |
670 | } |
671 | |
672 | ShardingArray res; |
673 | llvm::transform(Range&: meshAxisAssignmentForLoopIterators, d_first: std::back_inserter(x&: res), |
674 | F: [](std::optional<SmallVector<MeshAxis>> &axes) { |
675 | if (!axes) { |
676 | return SmallVector<MeshAxis>(); |
677 | }; |
678 | return std::move(*axes); |
679 | }); |
680 | return res; |
681 | } |
682 | |
683 | bool mesh::isAtLeastOneReductionIteratorSharded( |
684 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
685 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { |
686 | for (auto [loopIteratorType, meshAxisAssignment] : |
687 | llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { |
688 | if (loopIteratorType == utils::IteratorType::reduction && |
689 | !meshAxisAssignment.empty()) { |
690 | return true; |
691 | } |
692 | } |
693 | return false; |
694 | } |
695 | |
696 | SmallVector<MeshAxis> mesh::getReductionMeshAxes( |
697 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
698 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { |
699 | SmallVector<MeshAxis> meshAxes; |
700 | for (auto [loopIteratorType, meshAxisAssignment] : |
701 | llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { |
702 | if (loopIteratorType == utils::IteratorType::reduction) { |
703 | llvm::append_range(meshAxes, meshAxisAssignment); |
704 | } |
705 | } |
706 | return meshAxes; |
707 | } |
708 | |
709 | void mesh::spmdizeTriviallyShardableOperation( |
710 | Operation &op, ArrayRef<Value> spmdizedOperands, |
711 | ArrayRef<MeshSharding> operandShardings, |
712 | ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, |
713 | SymbolTableCollection &symbolTable, OpBuilder &builder) { |
714 | // `clone` will populate the mapping of old to new results. |
715 | Operation *newOp = builder.clone(op, mapper&: spmdizationMap); |
716 | // Set the result types to the sharded counterparts. |
717 | for (auto [oldResult, newResult, sharding] : |
718 | llvm::zip_equal(t: op.getResults(), u: newOp->getResults(), args&: resultShardings)) { |
719 | newResult.setType(shardType( |
720 | newResult.getType(), |
721 | getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding)); |
722 | } |
723 | } |
724 |
// NOTE(review): the scraped copy of this file ended with code-browser residue
// (a "Definitions" index listing checkOperandAffineExprRecursively,
// checkOperandAffineExpr, fromArrayOfVector, getMeshSharding,
// fillShardingOption, defaultGetShardingOption, getSharding,
// defaultGetShardingAnnotations, addShardOp, defaultAddShardingAnnotations,
// isValueCompatibleWithFullReplicationSharding,
// areValuesCompatibleWithFullReplicationShardings,
// spmdizeFullyReplicatedOperation, updateMeshAxisAssignmentForLoopIterators,
// getMeshAxisAssignmentForLoopIterators, isAtLeastOneReductionIteratorSharded,
// getReductionMeshAxes, plus a promotional footer). It is not part of the
// source and has been neutralized into this comment.