//===- ShardingInterface.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

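// Check that `expr` is built only from additions of loop dimensions that are
// each optionally multiplied by a constant, with no dimension repeated, and
// record every dimension seen in `seenIds`. For example, `d0 + 2 * d1` is
// accepted, while `d0 * d1` and `d0 + d0` are rejected.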
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {
  switch (expr.getKind()) {
  case AffineExprKind::Add: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
      return failure();
    return success();
  }
  case AffineExprKind::Mul: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    AffineExpr dimExpr;
    if (lhs.getKind() == AffineExprKind::DimId &&
        rhs.getKind() == AffineExprKind::Constant) {
      dimExpr = lhs;
    } else if (rhs.getKind() == AffineExprKind::DimId &&
               lhs.getKind() == AffineExprKind::Constant) {
      dimExpr = rhs;
    } else {
      return failure();
    }
    unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  case AffineExprKind::DimId: {
    unsigned position = cast<AffineDimExpr>(expr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  default:
    return failure();
  }
}

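// Return the positions of all loop dimensions referenced by `expr`, or
// failure if `expr` is not of the form accepted by
// checkOperandAffineExprRecursively. For example, with `numDims` = 3 and
// `expr` = `d0 + 2 * d2`, the returned set is {0, 2}.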
static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
  SmallVector<bool> seenIds(numDims, false);
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
    return failure();

  llvm::SmallSet<unsigned, 2> positions;
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
  }
  return positions;
}

//===----------------------------------------------------------------------===//
// mesh::getMeshShardingAttr
//===----------------------------------------------------------------------===//

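// Get the sharding annotation attached to `result` via a `mesh.shard` user,
// together with a flag that is true when the annotation targets the users of
// the value rather than its defining op. Illustrative IR (printed form
// sketched, not verbatim):
//
//   %0 = "some_op"() : () -> tensor<4x8xf32>
//   // Annotates the defining op of %0:
//   %1 = mesh.shard %0 to <@mesh0, [[0]]> : tensor<4x8xf32>
//   // Annotates the users of %2:
//   %2 = mesh.shard %1 to <@mesh0, [[0]]> annotate_for_users
//       : tensor<4x8xf32>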
FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpResult result) {
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // The value is expected to have exactly one use when it has a
    // `mesh.shard` user without the unit attribute annotate_for_users.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, shardOp.getShard());
  }

  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    MeshShardingAttr shardForDef = shardOps[0].getShard();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable mesh sharding attr for def when they are
      // different
      assert(shardOps[i].getShard() == shardForDef &&
             "only support all shard ops having the same mesh sharding attr");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}

FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpOperand &opOperand) {
  Value val = opOperand.get();
  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());

  return failure();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

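// Default verification: every operand and result type must be a ranked
// tensor, the loop iterator types and indexing maps must be non-empty, there
// must be exactly one indexing map per operand and result, and each result's
// indexing map must be a projected permutation.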
LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
  Operation *op = getOperation();

  // check operand and result types
  for (Type type : op->getOperandTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();
  for (Type type : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();

  // check loop types
  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
  if (loopTypes.empty())
    return failure();

  // check maps
  SmallVector<AffineMap> maps = getIndexingMaps();
  if (maps.empty())
    return failure();
  unsigned numOperands = op->getNumOperands();
  unsigned numResults = op->getNumResults();
  if (numOperands + numResults != maps.size())
    return failure();

  for (OpResult result : op->getResults()) {
    auto resultType = dyn_cast<RankedTensorType>(result.getType());
    if (!resultType)
      return failure();
    AffineMap map = maps[numOperands + result.getResultNumber()];
    if (!map.isProjectedPermutation()) {
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
  os << "print loop types and indexing maps for: \n";
  getOperation()->print(os);
  os << "\n";
  os << "loop types: [";
  for (utils::IteratorType type : getLoopIteratorTypes()) {
    os << stringifyEnum(type) << " ";
  }
  os << "]\n";
  os << "indexing maps: \n";
  for (AffineMap map : getIndexingMaps())
    os << map << "\n";
  os << "\n";
}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `meshAxes` and `loopIdx`.
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        unsigned loopIdx) {
  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
      (!shardingOption.shardingArray[loopIdx].empty() &&
       shardingOption.shardingArray[loopIdx] != meshAxes)) {
    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                      << loopIdx << "\n");
    return failure();
  }
  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
    if (i == loopIdx)
      continue;

    for (MeshAxis axis : meshAxes) {
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
                          << axis << " duplicate");
        return failure();
      }
    }
  }
  if (mesh)
    shardingOption.mesh = mesh;
  if (shardingOption.shardingArray[loopIdx].empty())
    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 meshAxes.end());
  return success();
}

} // namespace

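// Derive a `ShardingOption` (i.e. the mesh axes assigned to each loop
// iterator) from the sharding annotations of the op's results and operands:
//   1. Results: each split-axes sub-array is assigned to the loop iterator
//      used by the corresponding result dimension; partial axes are collected
//      for later assignment to a reduction loop.
//   2. Operands: a split tensor dimension determines a loop iterator only
//      when its affine expression involves a single loop index; with multiple
//      indices, one of them must already be determined elsewhere.
//   3. Finalization: pending partial axes are attached to a reduction loop.
// As a sketched example, for a matmul with indexing maps (d0, d2), (d2, d1),
// (d0, d1), a result annotated with split axes [[0]] assigns mesh axis 0 to
// loop d0.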
FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
    Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallVector<MeshAxis> partialMeshAxes;
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill sharding option based on op results.
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    MeshShardingAttr shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    // Handle the split axes: calculate the corresponding loop index for each
    // split-axes sub-array, and then store the sub-array to
    // shardingOption[index].
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      auto dim = cast<AffineDimExpr>(expr);
      unsigned index = dim.getPosition();
      visitedLoopIndices.insert(index);
      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
                                    axes, index)))
        return failure();
    }

    // Handle the partial axes: at this stage, the exact loop index/indices
    // cannot be decided because there could be multiple reduction loops.
    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
    if (!partialAxes.empty()) {
      if (!partialMeshAxes.empty())
        return op->emitOpError() << "at most one result with partial axes is "
                                    "supported at present";
      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
      // Add all the reduction loop indices to `visitedLoopIndices` if
      // `partialAxes` is not empty.
      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
        if (isReductionLoop(loopTypes[loopIdx]))
          visitedLoopIndices.insert(loopIdx);
      }
    }
  }

  // 2. Fill sharding option based on operands.
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    MeshShardingAttr shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands = true;
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes. Partial axes don't need to be handled because
    // they only affect the defining op of the operand.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
                                      axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it
      // is difficult to infer which of them is responsible for the sharding.
      // Therefore, the exact loop index must already be specified by another
      // annotation.
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
      }
    }
  }

  // 3. Finalize sharding option.
  if (!partialMeshAxes.empty()) {
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
          SmallVector<MeshAxis> &subArray = it.value();
          int64_t idx = it.index();
          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
        });
    if (!anyNonEmptyReductionLoop) {
      bool filled = false;
      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
        if (isReductionLoop(loopTypes[idx])) {
          std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                           partialMeshAxes, idx);
          filled = true;
          break;
        }
      }
      if (!filled)
        return op->emitOpError() << "no matching reduction loop found for the "
                                    "result's partial type";
    }
  }
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Add a `mesh.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {
  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
      getMeshShardingAttr(result);
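  // Nothing to do if the result already carries a sharding annotation for its
  // defining op, i.e. a `mesh.shard` user without `annotate_for_users`.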
  if (succeeded(maybeSharding) && !maybeSharding->first)
    return success();

  auto resultType = cast<RankedTensorType>(result.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
  SmallVector<MeshAxis> partialAxes;

  // process the split axes
  for (auto it : llvm::enumerate(map.getResults())) {
    AffineExpr expr = it.value();
    // `expr` must be an `AffineDimExpr` because `map` is verified by
    // isProjectedPermutation
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
  }

  // process the partial axes
  // `partialType` will be ignored if `partialAxes` is empty
  ReductionKind partialType = ReductionKind::Sum;
  size_t reductionLoopKindsIdx = 0;
  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
    utils::IteratorType iType = std::get<0>(it);
    if (isReductionLoop(iType)) {
      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
      ++reductionLoopKindsIdx;
      if (!partialAxes.empty())
        assert(partialType == curPartialType &&
               "Only one reduction type is supported");
      partialType = curPartialType;
      const SmallVector<MeshAxis> &axis = std::get<1>(it);
      partialAxes.append(axis);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  MeshShardingAttr shardAttr = MeshShardingAttr::get(
      b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfterValue(result);
  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
                                   shardAttr, /*annotate_for_users=*/false);
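  // Replace all uses of the original result with the new shard op's result,
  // except for the use inside the shard op itself.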
  result.replaceAllUsesExcept(shardOp, shardOp);
  return success();
}

// Add a `mesh.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {
  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
    return success();
  Value operand = opOperand.get();
  auto operandType = cast<RankedTensorType>(operand.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
  unsigned numDims = map.getNumDims();
  for (auto it : llvm::enumerate(map.getResults())) {
    int64_t idx = it.index();
    AffineExpr expr = it.value();
    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
        checkOperandAffineExpr(expr, numDims);
    if (failed(loopIndices))
      return failure();
    SmallVector<unsigned> shardedLoopIndices;
    for (unsigned loopIdx : *loopIndices) {
      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
          !shardingOption.shardingArray[loopIdx].empty())
        shardedLoopIndices.push_back(loopIdx);
    }
    // At most one sharded loop index is accepted.
    if (shardedLoopIndices.size() > 1)
      return failure();
    if (shardedLoopIndices.size() == 1) {
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  MeshShardingAttr shardAttr =
      MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(opOperand.getOwner());
  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
                                   shardAttr, /*annotate_for_users=*/true);
  opOperand.set(shardOp);

  return success();
}

LogicalResult mesh::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  // 1. Add mesh.shard ops for all op results.
  for (OpResult result : op->getResults()) {
    if (failed(addShardOp(b, result, shardingOption,
                          maps[numOperands + result.getResultNumber()],
                          loopTypes, reductionKinds)))
      return failure();
  }

  // 2. Add mesh.shard ops for all operands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
      return failure();
  }

  return success();
}

#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshShardingAttr sharding) {
  if (isa<RankedTensorType>(value.getType())) {
    return sharding && isFullReplication(sharding);
  }

  return !sharding;
}

template <typename ValueRange, typename MeshShardingAttrRange>
static bool areValuesCompatibleWithFullReplicationShardings(
    ValueRange &&values, MeshShardingAttrRange &&shardings) {
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(llvm::zip_equal(
                          std::forward<ValueRange>(values),
                          std::forward<MeshShardingAttrRange>(shardings)),
                      [](auto valueAndSharding) {
                        return isValueCompatibleWithFullReplicationSharding(
                            std::get<0>(valueAndSharding),
                            std::get<1>(valueAndSharding));
                      });
}
#endif // NDEBUG

void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  assert(spmdizedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
  builder.clone(op, spmdizationMap);
}

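// Record `meshAxesAssignmentForTensorAxis` as the mesh-axis assignment of the
// loop iterator referenced by `indexingExpr`, asserting that any previously
// recorded assignment for that iterator agrees with it.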
static void updateMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<MeshAxis>>>
        &meshAxesAssignmentForLoopIterators) {
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
  unsigned loopIteratorIdx = affineDimExpr.getPosition();
  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
  } else {
    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
        llvm::to_vector(meshAxesAssignmentForTensorAxis);
  }
}

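// Compute, for each loop iterator, the mesh axes it is sharded on, by
// propagating the operand and result shardings through the indexing maps.
// As a sketched example, for a matmul with indexing maps (d0, d2), (d2, d1),
// (d0, d1) and operand 0 sharded with split axes [[0], [1]], loop d0 is
// assigned mesh axis 0 and loop d2 mesh axis 1.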
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  SmallVector<std::optional<SmallVector<MeshAxis>>>
      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
  SmallVector<MeshShardingAttr> operatorAndResultShardings;
  operatorAndResultShardings.reserve(operandShardings.size() +
                                     resultShardings.size());
  llvm::append_range(operatorAndResultShardings, operandShardings);
  llvm::append_range(operatorAndResultShardings, resultShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
    if (!sharding) {
      continue;
    }
    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
      updateMeshAxisAssignmentForLoopIterators(
          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
          meshAxisAssignmentForLoopIterators);
    }
    // Missing trailing split axes means replication on those tensor
    // dimensions.
    for (unsigned i = sharding.getSplitAxes().size();
         i < affineMap.getNumResults(); ++i) {
      updateMeshAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
    }
  }

  ShardingArray res;
  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<MeshAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<MeshAxis>();
                    }
                    return std::move(*axes);
                  });
  return res;
}

bool mesh::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction &&
        !meshAxisAssignment.empty()) {
      return true;
    }
  }
  return false;
}

SmallVector<MeshAxis> mesh::getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  SmallVector<MeshAxis> meshAxes;
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction) {
      llvm::append_range(meshAxes, meshAxisAssignment);
    }
  }
  return meshAxes;
}

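// Spmdize a trivially shardable op: clone it onto the spmdized operands and
// rewrite each result type to its sharded, per-process counterpart.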
void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, spmdizationMap);
  // Set the result types to the sharded counterparts.
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(shardType(newResult.getType(),
                                getMesh(&op, sharding.getMesh(), symbolTable),
                                sharding));
  }
}