//===- ShardingInterface.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

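// Returns success if `expr` is a sum of distinct dims, each optionally scaled
// by a constant, e.g. `d0 + 2 * d1`. `seenIds` must be sized to the number of
// dims; a dim that appears more than once, or whose position is out of range,
// is rejected.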
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {
  switch (expr.getKind()) {
  case AffineExprKind::Add: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
      return failure();
    return success();
  }
  case AffineExprKind::Mul: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    AffineExpr dimExpr;
    if (lhs.getKind() == AffineExprKind::DimId &&
        rhs.getKind() == AffineExprKind::Constant) {
      dimExpr = lhs;
    } else if (rhs.getKind() == AffineExprKind::DimId &&
               lhs.getKind() == AffineExprKind::Constant) {
      dimExpr = rhs;
    } else {
      return failure();
    }
    unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  case AffineExprKind::DimId: {
    unsigned position = cast<AffineDimExpr>(expr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  default:
    return failure();
  }
}

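// Returns the set of dim positions that appear in `expr`, or failure if the
// expression is not of the form accepted by
// checkOperandAffineExprRecursively.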
static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
  SmallVector<bool> seenIds(numDims, false);
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
    return failure();

  llvm::SmallSet<unsigned, 2> positions;
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
  }
  return positions;
}

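// Converts an array of mesh-axis vectors into the corresponding list of
// MeshAxesAttr, preserving order.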
template <typename T>
SmallVector<MeshAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
  SmallVector<MeshAxesAttr> res;
  for (const auto &v : vec) {
    res.emplace_back(MeshAxesAttr::get(ctxt, v));
  }
  return res;
}

//===----------------------------------------------------------------------===//
// mesh::getMeshSharding
//===----------------------------------------------------------------------===//

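// Returns the sharding annotation attached to `result` via its `mesh.shard`
// users, together with a flag that is true when the annotation was made for
// the users (annotate_for_users) rather than for the defining op.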
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpResult result) {
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // The value is expected to have exactly one use if it is used by a
    // `mesh.shard` op without the unit attribute `annotate_for_users`.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
  }

  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    MeshSharding shardForDef = shardOps[0].getSharding();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable mesh sharding attr for def when they are
      // different
      assert(shardForDef == shardOps[i].getSharding() &&
             "only shard ops with the same mesh sharding attr are supported");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}

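// Returns the sharding of `opOperand` if its value is produced by a
// `mesh.shard` op, together with that op's annotate_for_users flag.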
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpOperand &opOperand) {
  Value val = opOperand.get();
  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
    return std::make_pair(shardOp.getAnnotateForUsers(),
                          MeshSharding(shardOp.getSharding()));

  return failure();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

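// Verifies the structural requirements of the interface: every operand and
// result is either a ranked tensor or a scalar, one indexing map is provided
// per operand and per result, and each result's map is a projected
// permutation of the loop dims.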
LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
  Operation *op = getOperation();

  // Check the operand and result types.
  for (Type type : op->getOperandTypes())
    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
      return failure();
  for (Type type : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
      return failure();

  // Check the indexing maps.
  SmallVector<AffineMap> maps = getIndexingMaps();
  if (maps.empty())
    return failure();
  unsigned numOperands = op->getNumOperands();
  unsigned numResults = op->getNumResults();
  if (numOperands + numResults != maps.size())
    return failure();

  for (OpResult result : op->getResults()) {
    auto resultType = dyn_cast<RankedTensorType>(result.getType());
    if (!resultType)
      return failure();
    AffineMap map = maps[numOperands + result.getResultNumber()];
    if (!map.isProjectedPermutation()) {
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
  os << "print loop types and indexing maps for: \n";
  getOperation()->print(os);
  os << "\n";
  os << "loop types: [";
  for (utils::IteratorType type : getLoopIteratorTypes()) {
    os << stringifyEnum(type) << " ";
  }
  os << "]\n";
  os << "indexing maps: \n";
  for (AffineMap map : getIndexingMaps())
    os << map << "\n";
  os << "\n";
}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `meshAxes` and `loopIdx`.
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        unsigned loopIdx) {
  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
      (!shardingOption.shardingArray[loopIdx].empty() &&
       shardingOption.shardingArray[loopIdx] != meshAxes)) {
    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                      << loopIdx << "\n");
    return failure();
  }
  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
    if (i == loopIdx)
      continue;

    for (MeshAxis axis : meshAxes) {
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axis "
                          << axis << " is duplicated\n");
        return failure();
      }
    }
  }
  if (mesh)
    shardingOption.mesh = mesh;
  if (shardingOption.shardingArray[loopIdx].empty())
    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 meshAxes.end());
  return success();
}

} // namespace

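// Deduces a ShardingOption for `op` in three steps: first propagate the
// result shardings into loop iterators through the result indexing maps,
// then do the same for the operand shardings, and finally assign any partial
// (reduction) mesh axes to a reduction loop.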
FailureOr<ShardingOption>
mesh::detail::defaultGetShardingOption(Operation *op,
                                       ArrayRef<MeshSharding> operandShardings,
                                       ArrayRef<MeshSharding> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallVector<MeshAxis> partialMeshAxes;
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill sharding option based on op results
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
      shardingOption.mesh = shardAttr.getMeshAttr();
    } else {
      // Handle the split axes: calculate the corresponding loop index for
      // each split axes sub-array, and then store the sub-array to
      // shardingOption[index]
      for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
        AffineExpr expr = std::get<0>(it);
        ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
        auto dim = cast<AffineDimExpr>(expr);
        unsigned index = dim.getPosition();
        visitedLoopIndices.insert(index);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getMeshAttr(), axes, index)))
          return failure();
      }
    }

    // Handle the partial axes: at this stage, the exact loop index/indices
    // cannot be decided because there could be multiple reduction loops.
    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
    if (!partialAxes.empty()) {
      if (!partialMeshAxes.empty())
        return op->emitOpError() << "at most one result with partial axes is "
                                    "supported at present";
      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
      // Add all the reduction loop indices to `visitedLoopIndices` if
      // `partialAxes` is not empty
      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
        if (isReductionLoop(loopTypes[loopIdx]))
          visitedLoopIndices.insert(loopIdx);
      }
    }
  }

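  // For example (illustrative): with matmul-like indexing maps
  //   (d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
  //   (d0, d1, d2) -> (d0, d1)
  // a result sharding with split axes [[0], [1]] assigns mesh axis 0 to loop
  // d0 and mesh axis 1 to loop d1.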
  // 2. Fill sharding option based on operands
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    MeshSharding shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands |= !shardAttr.getSplitAxes().empty();
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes. Partial axes don't need to be handled because
    // they only affect the defining op of the operand.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption,
                                      shardAttr.getMeshAttr(), axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it
      // is difficult to infer which loop indices are responsible for
      // sharding. Therefore, the exact loop index must be specified by
      // others.
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the explicitly specified "
                    "annotations of op results or operands.";
      }
    }
  }

  // 3. Finalize sharding option
  if (!partialMeshAxes.empty()) {
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
          SmallVector<MeshAxis> &subArray = it.value();
          int64_t idx = it.index();
          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
        });
    if (!anyNonEmptyReductionLoop) {
      bool filled = false;
      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
        if (isReductionLoop(loopTypes[idx])) {
          std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                           partialMeshAxes, idx);
          filled = true;
          break;
        }
      }
      if (!filled)
        return op->emitOpError() << "no matched reduction loop found for the "
                                    "result's partial type";
    }
  }
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}

// Get the sharding attribute for the given result and sharding option.
static MeshSharding getSharding(OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {
  auto resultType = cast<RankedTensorType>(result.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
  SmallVector<MeshAxis> partialAxes;

  // Process the split axes.
  for (auto it : llvm::enumerate(map.getResults())) {
    AffineExpr expr = it.value();
    // `expr` must be an `AffineDimExpr` because `map` is verified by
    // isProjectedPermutation
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
  }

  // Process the partial axes.
  // partialType will be ignored if partialAxes is empty
  ReductionKind partialType = ReductionKind::Sum;
  size_t reductionLoopKindsIdx = 0;
  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
    utils::IteratorType iType = std::get<0>(it);
    if (isReductionLoop(iType)) {
      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
      ++reductionLoopKindsIdx;
      if (!partialAxes.empty())
        assert(partialType == curPartialType &&
               "Only one reduction type is supported");
      partialType = curPartialType;
      const SmallVector<MeshAxis> &axes = std::get<1>(it);
      partialAxes.append(axes);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  return MeshSharding::get(shardingOption.mesh,
                           fromArrayOfVector(result.getContext(), splitAxes),
                           partialAxes, partialType);
}

static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                           const ShardingOption &shardingOption,
                                           AffineMap map) {
  Value operandValue = opOperand.get();
  auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
  if (!operandType) {
    if (operandValue.getType().isIntOrIndexOrFloat())
      return MeshSharding();
    return failure();
  }
  // 0d tensors cannot be sharded and must get replicated
  if (operandType.getRank() == 0) {
    return MeshSharding(shardingOption.mesh);
  }
  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
  unsigned numDims = map.getNumDims();
  for (auto it : llvm::enumerate(map.getResults())) {
    int64_t idx = it.index();
    AffineExpr expr = it.value();
    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
        checkOperandAffineExpr(expr, numDims);
    if (failed(loopIndices))
      return failure();
    SmallVector<unsigned> shardedLoopIndices;
    for (unsigned loopIdx : *loopIndices) {
      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
          !shardingOption.shardingArray[loopIdx].empty())
        shardedLoopIndices.push_back(loopIdx);
    }
    // At most one sharded loop index is accepted.
    if (shardedLoopIndices.size() > 1)
      return failure();
    if (shardedLoopIndices.size() == 1) {
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  return MeshSharding::get(
      shardingOption.mesh,
      fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}

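// Computes one sharding annotation per operand followed by one per result,
// in operand/result order, from the deduced `shardingOption`.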
FailureOr<std::vector<MeshSharding>>
mesh::detail::defaultGetShardingAnnotations(
    Operation *op, const ShardingOption &shardingOption) {
  std::vector<MeshSharding> res;

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  for (OpOperand &opOperand : op->getOpOperands()) {
    FailureOr<MeshSharding> shardingAttr = getSharding(
        opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
    if (failed(shardingAttr))
      return failure();
    res.push_back(*shardingAttr);
  }

  for (OpResult result : op->getResults()) {
    res.push_back(getSharding(result, shardingOption,
                              maps[numOperands + result.getResultNumber()],
                              loopTypes, reductionKinds));
  }

  return res;
}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Add a `mesh.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {
  MeshSharding sharding =
      getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
  maybeInsertTargetShardingAnnotation(sharding, result, b);

  return success();
}

// Add a `mesh.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {
  FailureOr<MeshSharding> sharding =
      getSharding(opOperand, shardingOption, map);
  if (failed(sharding)) {
    return failure();
  }
  OpBuilder::InsertionGuard guard(b);
  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);

  return success();
}

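// Annotates all results first and then all operands of `op` with `mesh.shard`
// ops derived from `shardingOption`.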
LogicalResult mesh::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
  assert(!shardingOption.empty && shardingOption.mesh);

  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  // 1. add mesh.shard ops for all op results
  for (OpResult result : op->getResults()) {
    if (failed(addShardOp(b, result, shardingOption,
                          maps[numOperands + result.getResultNumber()],
                          loopTypes, reductionKinds)))
      return failure();
  }

  // 2. add mesh.shard ops for all operands
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
      return failure();
  }

  return success();
}

#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshSharding sharding) {
  if (isa<RankedTensorType>(value.getType())) {
    return isFullReplication(sharding);
  }

  return !sharding;
}

template <typename ValueRange, typename MeshShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
                                                MeshShardingRange &&shardings) {
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(
      llvm::zip_equal(std::forward<ValueRange>(values),
                      std::forward<MeshShardingRange>(shardings)),
      [](auto valueAndSharding) {
        return isValueCompatibleWithFullReplicationSharding(
            std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
      });
}
#endif // NDEBUG

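// Spmdizes a fully replicated op by cloning it unchanged: with full
// replication every process holds the complete tensors, so the original op
// is valid as-is on each process.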
void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  assert(spmdizedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
  builder.clone(op, spmdizationMap);
}

static void updateMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<MeshAxis>>>
        &meshAxesAssignmentForLoopIterators) {
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
  unsigned loopIteratorIdx = affineDimExpr.getPosition();
  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
  } else {
    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
        llvm::to_vector(meshAxesAssignmentForTensorAxis);
  }
}

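// For each loop iterator, collects the mesh axes assigned to it by any
// operand or result sharding through the corresponding indexing map. Loop
// iterators without an assignment get an empty axis list.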
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  SmallVector<std::optional<SmallVector<MeshAxis>>>
      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
  std::vector<MeshSharding> operatorAndResultShardings;
  operatorAndResultShardings.reserve(operandShardings.size() +
                                     resultShardings.size());
  llvm::append_range(operatorAndResultShardings, operandShardings);
  llvm::append_range(operatorAndResultShardings, resultShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
    if (!sharding) {
      continue;
    }
    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
      updateMeshAxisAssignmentForLoopIterators(
          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
          meshAxisAssignmentForLoopIterators);
    }
    // Missing trailing split axes mean replication on the corresponding
    // tensor dimensions.
    for (unsigned i = sharding.getSplitAxes().size();
         i < affineMap.getNumResults(); ++i) {
      updateMeshAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
    }
  }

  ShardingArray res;
  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<MeshAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<MeshAxis>();
                    }
                    return std::move(*axes);
                  });
  return res;
}

bool mesh::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction &&
        !meshAxisAssignment.empty()) {
      return true;
    }
  }
  return false;
}

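// Collects, in loop order, the mesh axes assigned to reduction iterators.
// These are the axes along which partial results typically need to be
// combined (e.g. via all-reduce) after spmdization.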
SmallVector<MeshAxis> mesh::getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  SmallVector<MeshAxis> meshAxes;
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction) {
      llvm::append_range(meshAxes, meshAxisAssignment);
    }
  }
  return meshAxes;
}

void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, spmdizationMap);
  // Set the result types to the sharded counterparts.
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(shardType(
        newResult.getType(),
        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
  }
}
