//===- ShardingInterface.cpp ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE "sharding-interface"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

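// The two helpers below verify that an operand's indexing expression is a
// sum of distinct dims, each optionally multiplied by a constant.
// Illustrative cases (not exhaustive):
//   d0              -> accepted, d0 marked as seen
//   4 * d0 + d1     -> accepted
//   d0 + d0         -> rejected (d0 appears twice)
//   d0 floordiv 2   -> rejected (unsupported expression kind)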
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {
  switch (expr.getKind()) {
  case AffineExprKind::Add: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
      return failure();
    return success();
  }
  case AffineExprKind::Mul: {
    auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
    AffineExpr lhs = binOpExpr.getLHS();
    AffineExpr rhs = binOpExpr.getRHS();
    AffineExpr dimExpr;
    if (lhs.getKind() == AffineExprKind::DimId &&
        rhs.getKind() == AffineExprKind::Constant) {
      dimExpr = lhs;
    } else if (rhs.getKind() == AffineExprKind::DimId &&
               lhs.getKind() == AffineExprKind::Constant) {
      dimExpr = rhs;
    } else {
      return failure();
    }
    unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  case AffineExprKind::DimId: {
    unsigned position = cast<AffineDimExpr>(expr).getPosition();
    if ((size_t)position >= seenIds.size() || seenIds[position])
      return failure();
    seenIds[position] = true;
    return success();
  }
  default:
    return failure();
  }
}

static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
  SmallVector<bool> seenIds(numDims, false);
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
    return failure();

  llvm::SmallSet<unsigned, 2> positions;
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
  }
  return positions;
}

//===----------------------------------------------------------------------===//
// mesh::getMeshShardingAttr
//===----------------------------------------------------------------------===//

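// Illustrative IR (hypothetical, assumes a mesh @mesh0 has been declared
// elsewhere):
//
//   %0 = "some_op"() : () -> tensor<4x8xf32>
//   %1 = mesh.shard %0 to <@mesh0, [[0]]> : tensor<4x8xf32>
//
// For the OpResult %0, the function below returns {false, <@mesh0, [[0]]>};
// had the shard op carried `annotate_for_users`, the bool would be true.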
FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpResult result) {
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return !shardOp.getAnnotateForUsers();
  });

  if (anyShardedForDef) {
    // The value is expected to have exactly one use if it is used by a
    // `mesh.shard` op without the unit attr annotate_for_users.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, shardOp.getShard());
  }

  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
    if (!shardOp)
      return false;
    return shardOp.getAnnotateForUsers();
  });
  if (anyShardedForUsers) {
    SmallVector<ShardOp> shardOps;
    for (Operation *user : val.getUsers()) {
      ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
      if (shardOp)
        shardOps.push_back(shardOp);
    }
    MeshShardingAttr shardForDef = shardOps[0].getShard();
    for (size_t i = 1; i < shardOps.size(); ++i) {
      // TODO: Deduce a reasonable mesh sharding attr for def when they are
      // different
      assert(shardOps[i].getShard() == shardForDef &&
             "only support all shard ops having the same mesh sharding attr");
    }
    return std::make_pair(true, shardForDef);
  }
  return failure();
}

FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpOperand &opOperand) {
  Value val = opOperand.get();
  if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());

  return failure();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
  Operation *op = getOperation();

  // Check that all operand and result types are ranked tensors.
  for (Type type : op->getOperandTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();
  for (Type type : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(type))
      return failure();

  // Check the loop types.
  SmallVector<utils::IteratorType> loopTypes = getLoopIteratorTypes();
  if (loopTypes.empty())
    return failure();

  // Check the indexing maps: one per operand and result, and the result maps
  // must be projected permutations.
  SmallVector<AffineMap> maps = getIndexingMaps();
  if (maps.empty())
    return failure();
  unsigned numOperands = op->getNumOperands();
  unsigned numResults = op->getNumResults();
  if (numOperands + numResults != maps.size())
    return failure();

  for (OpResult result : op->getResults()) {
    auto resultType = dyn_cast<RankedTensorType>(result.getType());
    if (!resultType)
      return failure();
    AffineMap map = maps[numOperands + result.getResultNumber()];
    if (!map.isProjectedPermutation())
      return failure();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
  os << "print loop types and indexing maps for: \n";
  getOperation()->print(os);
  os << "\n";
  os << "loop types: [";
  for (utils::IteratorType type : getLoopIteratorTypes()) {
    os << stringifyEnum(type) << " ";
  }
  os << "]\n";
  os << "indexing maps: \n";
  for (AffineMap map : getIndexingMaps())
    os << map << "\n";
  os << "\n";
}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `meshAxes` and `loopIdx`.
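//
// A failure indicates a conflict, e.g. the annotation names a different mesh
// than one already recorded, the loop iterator already carries a different
// mesh-axis list, or a mesh axis is already assigned to another loop
// iterator. Sketch with hypothetical values: given
//   shardingOption.shardingArray = [[0], []]
// a call with meshAxes = [0] and loopIdx = 1 fails, because mesh axis 0
// would then shard two loop iterators.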
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        unsigned loopIdx) {
  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
      (!shardingOption.shardingArray[loopIdx].empty() &&
       shardingOption.shardingArray[loopIdx] != meshAxes)) {
    LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                      << loopIdx << "\n");
    return failure();
  }
  for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
    if (i == loopIdx)
      continue;

    for (MeshAxis axis : meshAxes) {
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axis "
                          << axis << " is duplicated\n");
        return failure();
      }
    }
  }
  if (mesh)
    shardingOption.mesh = mesh;
  if (shardingOption.shardingArray[loopIdx].empty())
    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 meshAxes.end());
  return success();
}

} // namespace

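// A hypothetical worked example: for a matmul-like op with loop iterator
// types [parallel, parallel, reduction] and indexing maps
//   (d0, d2), (d2, d1), (d0, d1),
// a result annotated <@mesh0, [[0], [1]]> assigns loop d0 to mesh axis 0 and
// loop d1 to mesh axis 1; the reduction loop d2 stays unsharded, and its
// trailing empty sub-array is dropped, so the resulting sharding array is
// [[0], [1]].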
FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
    Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  ShardingOption shardingOption;

  if (failed(shardingOp.verifyShardingInterfaceImpl()))
    return op->emitOpError() << "invalid sharding interface implementation";
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();
  shardingOption.shardingArray.resize(loopTypes.size());
  llvm::SmallVector<MeshAxis> partialMeshAxes;
  llvm::SmallSet<unsigned, 4> visitedLoopIndices;
  bool anyShardingInResultsOrOperands = false;

  // 1. Fill the sharding option based on the op results.
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
    MeshShardingAttr shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;
    AffineMap map = maps[numOperands + shardingIt.index()];
    anyShardingInResultsOrOperands = true;
    // Handle the split axes: calculate the corresponding loop index for each
    // split-axes sub-array, and then store the sub-array at
    // shardingOption[index].
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      auto dim = cast<AffineDimExpr>(expr);
      unsigned index = dim.getPosition();
      visitedLoopIndices.insert(index);
      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
                                    axes, index)))
        return failure();
    }

    // Handle the partial axes: at this stage, the exact loop index/indices
    // cannot be decided because there could be multiple reduction loops.
    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
    if (!partialAxes.empty()) {
      if (!partialMeshAxes.empty())
        return op->emitOpError() << "at most one result with partial axes is "
                                    "supported at present";
      partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
      // Add all the reduction loop indices to `visitedLoopIndices` if
      // `partialAxes` is not empty.
      for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
        if (isReductionLoop(loopTypes[loopIdx]))
          visitedLoopIndices.insert(loopIdx);
      }
    }
  }

  // 2. Fill the sharding option based on the op operands.
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
    MeshShardingAttr shardAttr = shardingIt.value();
    if (!shardAttr)
      continue;

    anyShardingInResultsOrOperands = true;
    AffineMap map = maps[shardingIt.index()];
    unsigned numDims = map.getNumDims();

    // Handle the split axes. Partial axes don't need to be handled because
    // they only affect the defining op of the operand.
    //
    // TODO: Change to process the operands with single loop index first and
    // then the operands with multiple loop indices.
    for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
      AffineExpr expr = std::get<0>(it);
      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
      FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
          checkOperandAffineExpr(expr, numDims);
      if (failed(loopIndices))
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
      if (loopIndices->empty())
        continue;
      if (loopIndices->size() == 1) {
        unsigned loopIdx = *loopIndices->begin();
        visitedLoopIndices.insert(loopIdx);
        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
                                      axes, loopIdx)))
          return failure();
      }
      // If multiple loop indices correspond to a dimension of an operand, it
      // is difficult to infer which loop indices are responsible for
      // sharding. Therefore, the exact loop index must be specified by
      // others.
      if (loopIndices->size() > 1) {
        bool seenLoopIndices = false;
        for (unsigned loopIdx : *loopIndices) {
          if (visitedLoopIndices.contains(loopIdx)) {
            seenLoopIndices = true;
            break;
          }
        }
        if (!seenLoopIndices)
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
      }
    }
  }

  // 3. Finalize the sharding option.
  if (!partialMeshAxes.empty()) {
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
          SmallVector<MeshAxis> &subArray = it.value();
          int64_t idx = it.index();
          return isReductionLoop(loopTypes[idx]) && !subArray.empty();
        });
    if (!anyNonEmptyReductionLoop) {
      bool filled = false;
      for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
        if (isReductionLoop(loopTypes[idx])) {
          std::ignore = fillShardingOption(op, shardingOption, nullptr,
                                           partialMeshAxes, idx);
          filled = true;
          break;
        }
      }
      if (!filled)
        return op->emitOpError() << "no matched reduction loop found for the "
                                    "result's partial type";
    }
  }
  removeTrailingEmptySubArray(shardingOption.shardingArray);
  if (!anyShardingInResultsOrOperands)
    shardingOption.empty = true;
  return shardingOption;
}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Add a `mesh.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {
  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
      getMeshShardingAttr(result);
  if (succeeded(maybeSharding) && !maybeSharding->first)
    return success();

  auto resultType = cast<RankedTensorType>(result.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
  SmallVector<MeshAxis> partialAxes;

  // Process the split axes.
  for (auto it : llvm::enumerate(map.getResults())) {
    AffineExpr expr = it.value();
    // `expr` must be an `AffineDimExpr` because `map` is verified by
    // isProjectedPermutation.
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
  }

  // Process the partial axes.
  // `partialType` is ignored if `partialAxes` is empty.
  ReductionKind partialType = ReductionKind::Sum;
  size_t reductionLoopKindsIdx = 0;
  for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
    utils::IteratorType iType = std::get<0>(it);
    if (isReductionLoop(iType)) {
      ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
      ++reductionLoopKindsIdx;
      if (!partialAxes.empty())
        assert(partialType == curPartialType &&
               "Only one reduction type is supported");
      partialType = curPartialType;
      const SmallVector<MeshAxis> &axis = std::get<1>(it);
      partialAxes.append(axis);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  MeshShardingAttr shardAttr = MeshShardingAttr::get(
      b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfterValue(result);
  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
                                   shardAttr, /*annotate_for_users=*/false);
  result.replaceAllUsesExcept(shardOp, shardOp);
  return success();
}

// Add a `mesh.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
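// The created shard op sets `annotate_for_users = true`, i.e. the annotation
// describes how this op expects to consume the operand rather than how the
// operand is produced.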
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {
  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
    return success();
  Value operand = opOperand.get();
  auto operandType = cast<RankedTensorType>(operand.getType());
  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
  unsigned numDims = map.getNumDims();
  for (auto it : llvm::enumerate(map.getResults())) {
    int64_t idx = it.index();
    AffineExpr expr = it.value();
    FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
        checkOperandAffineExpr(expr, numDims);
    if (failed(loopIndices))
      return failure();
    SmallVector<unsigned> shardedLoopIndices;
    for (unsigned loopIdx : *loopIndices) {
      if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
          !shardingOption.shardingArray[loopIdx].empty())
        shardedLoopIndices.push_back(loopIdx);
    }
    // At most one sharded loop index is accepted.
    if (shardedLoopIndices.size() > 1)
      return failure();
    if (shardedLoopIndices.size() == 1) {
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
    }
  }

  removeTrailingEmptySubArray(splitAxes);
  MeshShardingAttr shardAttr =
      MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(opOperand.getOwner());
  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
                                   shardAttr, /*annotate_for_users=*/true);
  opOperand.set(shardOp);

  return success();
}

LogicalResult mesh::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
  ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
  SmallVector<utils::IteratorType> loopTypes =
      shardingOp.getLoopIteratorTypes();
  SmallVector<ReductionKind> reductionKinds =
      shardingOp.getReductionLoopIteratorKinds();
  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
  unsigned numOperands = op->getNumOperands();

  // 1. Add mesh.shard ops for all op results.
  for (OpResult result : op->getResults()) {
    if (failed(addShardOp(b, result, shardingOption,
                          maps[numOperands + result.getResultNumber()],
                          loopTypes, reductionKinds)))
      return failure();
  }

  // 2. Add mesh.shard ops for all operands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
      return failure();
  }

  return success();
}
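
// As a rough illustration (hypothetical IR, assuming a mesh @mesh0 exists
// and the chosen ShardingOption maps loop d0 to mesh axis 0), the two steps
// above rewrite
//   %r = "some_op"(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
// into
//   %s0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users
//       : tensor<4x8xf32>
//   %r  = "some_op"(%s0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
//   %s1 = mesh.shard %r to <@mesh0, [[0]]> : tensor<4x8xf32>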

#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshShardingAttr sharding) {
  if (isa<RankedTensorType>(value.getType())) {
    return sharding && isFullReplication(sharding);
  }

  return !sharding;
}

template <typename ValueRange, typename MeshShardingAttrRange>
static bool areValuesCompatibleWithFullReplicationShardings(
    ValueRange &&values, MeshShardingAttrRange &&shardings) {
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(llvm::zip_equal(
                          std::forward<ValueRange>(values),
                          std::forward<MeshShardingAttrRange>(shardings)),
                      [](auto valueAndSharding) {
                        return isValueCompatibleWithFullReplicationSharding(
                            std::get<0>(valueAndSharding),
                            std::get<1>(valueAndSharding));
                      });
}
#endif // NDEBUG

void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  assert(spmdizedOperands.size() == operandShardings.size());
  assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                         operandShardings));
  assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                         resultShardings));
  // `clone` will populate the mapping of old to new results.
  builder.clone(op, spmdizationMap);
}

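// Record the mesh-axis assignment implied by one tensor axis for the loop
// iterator that indexes it, asserting that it agrees with any assignment
// recorded earlier for the same iterator.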
static void updateMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<MeshAxis>>>
        &meshAxesAssignmentForLoopIterators) {
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
  unsigned loopIteratorIdx = affineDimExpr.getPosition();
  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
  } else {
    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
        llvm::to_vector(meshAxesAssignmentForTensorAxis);
  }
}

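// Illustrative example: for a matmul-like op with indexing maps
//   (d0, d2), (d2, d1), (d0, d1)
// and operand/result shardings [[0]], [], [[0]], both the first operand and
// the result pin loop d0 to mesh axis 0, so the returned assignment is
// [[0], [], []].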
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {
  SmallVector<std::optional<SmallVector<MeshAxis>>>
      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
  SmallVector<MeshShardingAttr> operandAndResultShardings;
  operandAndResultShardings.reserve(operandShardings.size() +
                                    resultShardings.size());
  llvm::append_range(operandAndResultShardings, operandShardings);
  llvm::append_range(operandAndResultShardings, resultShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operandAndResultShardings, indexingMaps)) {
    if (!sharding) {
      continue;
    }
    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
      updateMeshAxisAssignmentForLoopIterators(
          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
          meshAxisAssignmentForLoopIterators);
    }
    // Missing trailing split axes mean replication on those tensor
    // dimensions.
    for (unsigned i = sharding.getSplitAxes().size();
         i < affineMap.getNumResults(); ++i) {
      updateMeshAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
    }
  }

  ShardingArray res;
  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<MeshAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<MeshAxis>();
                    }
                    return std::move(*axes);
                  });
  return res;
}


bool mesh::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction &&
        !meshAxisAssignment.empty()) {
      return true;
    }
  }
  return false;
}

SmallVector<MeshAxis> mesh::getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
  SmallVector<MeshAxis> meshAxes;
  for (auto [loopIteratorType, meshAxisAssignment] :
       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
    if (loopIteratorType == utils::IteratorType::reduction) {
      llvm::append_range(meshAxes, meshAxisAssignment);
    }
  }
  return meshAxes;
}

void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {
  // `clone` will populate the mapping of old to new results.
  Operation *newOp = builder.clone(op, spmdizationMap);
  // Set the result types to the sharded counterparts.
  for (auto [oldResult, newResult, sharding] :
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
    newResult.setType(shardType(newResult.getType(),
                                getMesh(&op, sharding.getMesh(), symbolTable),
                                sharding));
  }
}
