1//===- Spmdization.cpp --------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Mesh/IR/MeshOps.h"
14#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinTypeInterfaces.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/ImplicitLocOpBuilder.h"
23#include "mlir/IR/Location.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/SymbolTable.h"
26#include "mlir/IR/Value.h"
27#include "mlir/Interfaces/ControlFlowInterfaces.h"
28#include "mlir/Interfaces/FunctionInterfaces.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Support/LogicalResult.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/Support/Casting.h"
37#include <iterator>
38#include <optional>
39#include <tuple>
40#include <type_traits>
41
42namespace mlir::mesh {
43
// Returns true if every axis in `targetAxes` is also present in `sourceAxes`.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
51
52// Return the reduced value and its corresponding sharding.
53// Example:
54// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
55// targetSharding = <@mesh_1d, [[]]>
56// Then will apply all-reduce on the source value
57// and return it with the sharding <@mesh_1d, [[0]]>.
58static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
59handlePartialAxesDuringResharding(OpBuilder &builder,
60 MeshShardingAttr sourceSharding,
61 MeshShardingAttr targetSharding,
62 TypedValue<ShapedType> sourceShard) {
63 if (sourceSharding.getPartialAxes().empty() &&
64 targetSharding.getPartialAxes().empty()) {
65 return {sourceShard, sourceSharding};
66 }
67 assert(targetSharding.getPartialAxes().empty() ||
68 (!sourceSharding.getPartialAxes().empty() &&
69 sourceSharding.getPartialType() == targetSharding.getPartialType()));
70 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
71 using AxisSet = llvm::SmallDenseSet<Axis>;
72 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
73 sourceSharding.getPartialAxes().end());
74 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
75 targetSharding.getPartialAxes().end());
76 assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
77 targetShardingPartialAxesSet));
78 llvm::SmallVector<MeshAxis> allReduceMeshAxes;
79 llvm::copy_if(sourceShardingPartialAxesSet,
80 std::back_inserter(x&: allReduceMeshAxes),
81 [&targetShardingPartialAxesSet](Axis a) {
82 return !targetShardingPartialAxesSet.contains(a);
83 });
84 if (allReduceMeshAxes.empty()) {
85 return {sourceShard, sourceSharding};
86 }
87
88 builder.setInsertionPointAfterValue(sourceShard);
89 TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
90 builder
91 .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92 sourceSharding.getMesh().getLeafReference(),
93 allReduceMeshAxes, sourceShard,
94 sourceSharding.getPartialType())
95 .getResult());
96
97 llvm::SmallVector<MeshAxis> remainingPartialAxes;
98 llvm::copy_if(sourceShardingPartialAxesSet,
99 std::back_inserter(x&: allReduceMeshAxes),
100 [&targetShardingPartialAxesSet](Axis a) {
101 return targetShardingPartialAxesSet.contains(a);
102 });
103 MeshShardingAttr resultSharding =
104 MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
105 sourceSharding.getSplitAxes(), remainingPartialAxes,
106 sourceSharding.getPartialType());
107 return {resultValue, resultSharding};
108}
109
110static MeshShardingAttr
111targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
112 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
113 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
114 llvm::to_vector(sourceSharding.getSplitAxes());
115 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
116 splitTensorAxis) {
117 targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
118 }
119 auto targetSplitAxes =
120 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
121 targetSplitAxes.push_back(splitMeshAxis);
122 targetShardingSplitAxes[splitTensorAxis] =
123 MeshAxesAttr::get(ctx, targetSplitAxes);
124 return MeshShardingAttr::get(
125 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
126 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
127}
128
129// Split a replicated tensor along a mesh axis.
130// e.g. [[0, 1]] -> [[0, 1, 2]].
131// Returns the spmdized target value with its sharding.
132static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
133splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
134 MeshShardingAttr sourceSharding,
135 TypedValue<ShapedType> sourceShard, MeshOp mesh,
136 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
137 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
138 builder
139 .create<AllSliceOp>(sourceShard, mesh,
140 ArrayRef<MeshAxis>(splitMeshAxis),
141 splitTensorAxis)
142 .getResult());
143 MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
144 builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
145 return {targetShard, targetSharding};
146}
147
148// Detect if the resharding is of type e.g.
149// [[0, 1]] -> [[0, 1, 2]].
150// If detected, returns the corresponding tensor axis mesh axis pair.
151// Does not detect insertions like
152// [[0, 1]] -> [[0, 2, 1]].
153static std::optional<std::tuple<int64_t, MeshAxis>>
154detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
155 MeshShardingAttr targetSharding) {
156 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
157 ++tensorAxis) {
158 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
159 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
160 targetSharding.getSplitAxes()[tensorAxis].size()) {
161 continue;
162 }
163 if (!llvm::equal(
164 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
165 llvm::make_range(
166 targetSharding.getSplitAxes()[tensorAxis]
167 .asArrayRef()
168 .begin(),
169 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
170 1))) {
171 continue;
172 }
173 } else {
174 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
175 continue;
176 }
177 }
178 return std::make_tuple(
179 tensorAxis,
180 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
181 }
182 return std::nullopt;
183}
184
185static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
186trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
187 MeshShardingAttr sourceSharding,
188 MeshShardingAttr targetSharding,
189 TypedValue<ShapedType> sourceShard) {
190 if (auto detectRes =
191 detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
192 auto [tensorAxis, meshAxis] = detectRes.value();
193 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
194 tensorAxis, meshAxis);
195 }
196
197 return std::nullopt;
198}
199
200// Detect if the resharding is of type e.g.
201// [[0, 1, 2]] -> [[0, 1]].
202// If detected, returns the corresponding tensor axis mesh axis pair.
203static std::optional<std::tuple<int64_t, MeshAxis>>
204detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
205 MeshShardingAttr targetSharding) {
206 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
207 ++tensorAxis) {
208 if (targetSharding.getSplitAxes().size() > tensorAxis) {
209 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
210 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
211 continue;
212 if (!llvm::equal(
213 llvm::make_range(
214 sourceSharding.getSplitAxes()[tensorAxis]
215 .asArrayRef()
216 .begin(),
217 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
218 1),
219 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
220 continue;
221 } else {
222 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
223 continue;
224 }
225 return std::make_tuple(
226 tensorAxis,
227 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
228 }
229 return std::nullopt;
230}
231
232static MeshShardingAttr
233targetShardingInUnsplitLastAxis(MLIRContext *ctx,
234 MeshShardingAttr sourceSharding,
235 int64_t splitTensorAxis) {
236 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
237 llvm::to_vector(sourceSharding.getSplitAxes());
238 assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
239 splitTensorAxis);
240 auto targetSplitAxes =
241 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
242
243 targetSplitAxes.pop_back();
244 targetShardingSplitAxes[splitTensorAxis] =
245 MeshAxesAttr::get(ctx, targetSplitAxes);
246 return MeshShardingAttr::get(
247 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
248 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
249}
250
251static ShapedType allGatherResultShapeInUnsplitLastAxis(
252 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
253 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
254 targetShape[splitTensorAxis] =
255 gatherDimension(dimSize: targetShape[splitTensorAxis], shardCount: splitCount);
256 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
257}
258
259static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
260unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
261 MeshShardingAttr sourceSharding,
262 ShapedType sourceUnshardedShape,
263 TypedValue<ShapedType> sourceShard, MeshOp mesh,
264 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
265 MLIRContext *ctx = builder.getContext();
266 builder.setInsertionPointAfterValue(sourceShard);
267
268 MeshShardingAttr targetSharding =
269 targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
270 ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
271 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
272 Value allGatherResult = builder.create<AllGatherOp>(
273 RankedTensorType::get(allGatherResultShape.getShape(),
274 allGatherResultShape.getElementType()),
275 mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
276 APInt(64, splitTensorAxis));
277 ShapedType targetShape =
278 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
279 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
280 builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
281 return {targetShard, targetSharding};
282}
283
284static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
285tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
286 MeshShardingAttr sourceSharding,
287 MeshShardingAttr targetSharding,
288 ShapedType sourceUnshardedShape,
289 TypedValue<ShapedType> sourceShard) {
290 if (auto detectRes =
291 detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
292 auto [tensorAxis, meshAxis] = detectRes.value();
293 return unsplitLastAxisInResharding(builder, sourceSharding,
294 sourceUnshardedShape, sourceShard, mesh,
295 tensorAxis, meshAxis);
296 }
297
298 return std::nullopt;
299}
300
301// Detect if the resharding is of type e.g.
302// [[0, 1], [2]] -> [[0], [1, 2]].
303// Only moving the last axis counts.
304// If detected, returns the corresponding (source_tensor_axis,
305// target_tensor_axis, mesh_axis) tuple.
306static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
307detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
308 MeshShardingAttr targetSharding) {
309 for (size_t sourceTensorAxis = 0;
310 sourceTensorAxis < sourceSharding.getSplitAxes().size();
311 ++sourceTensorAxis) {
312 for (size_t targetTensorAxis = 0;
313 targetTensorAxis < targetSharding.getSplitAxes().size();
314 ++targetTensorAxis) {
315 if (sourceTensorAxis == targetTensorAxis)
316 continue;
317 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
318 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
319 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
320 targetSharding.getSplitAxes()[targetTensorAxis]
321 .asArrayRef()
322 .back())
323 continue;
324 if (!llvm::equal(
325 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
326 .asArrayRef()
327 .begin(),
328 sourceSharding.getSplitAxes()[sourceTensorAxis]
329 .asArrayRef()
330 .end() -
331 1),
332 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
333 .asArrayRef()
334 .begin(),
335 targetSharding.getSplitAxes()[targetTensorAxis]
336 .asArrayRef()
337 .end() -
338 1)))
339 continue;
340 return std::make_tuple(
341 sourceTensorAxis, targetTensorAxis,
342 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
343 }
344 }
345 return std::nullopt;
346}
347
348static MeshShardingAttr
349targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
350 int64_t sourceTensorAxis,
351 int64_t targetTensorAxis) {
352 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
353 llvm::to_vector(sourceSharding.getSplitAxes());
354 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
355 targetTensorAxis) {
356 targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
357 }
358
359 auto sourceSplitAxes =
360 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
361 assert(!sourceSplitAxes.empty());
362 auto meshAxis = sourceSplitAxes.back();
363 sourceSplitAxes.pop_back();
364 targetShardingSplitAxes[sourceTensorAxis] =
365 MeshAxesAttr::get(ctx, sourceSplitAxes);
366
367 auto targetSplitAxes =
368 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
369 targetSplitAxes.push_back(meshAxis);
370 targetShardingSplitAxes[targetTensorAxis] =
371 MeshAxesAttr::get(ctx, targetSplitAxes);
372
373 return MeshShardingAttr::get(
374 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
375 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
376}
377
378static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
379 int64_t splitCount,
380 int64_t sourceTensorAxis,
381 int64_t targetTensorAxis) {
382 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
383 targetShape[sourceTensorAxis] =
384 gatherDimension(dimSize: targetShape[sourceTensorAxis], shardCount: splitCount);
385 targetShape[targetTensorAxis] =
386 shardDimension(dimSize: targetShape[targetTensorAxis], shardCount: splitCount);
387 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
388}
389
390static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
391moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
392 MeshShardingAttr sourceSharding,
393 ShapedType sourceUnshardedShape,
394 TypedValue<ShapedType> sourceShard,
395 int64_t sourceTensorAxis,
396 int64_t targetTensorAxis, MeshAxis meshAxis) {
397 MLIRContext *ctx = builder.getContext();
398 builder.setInsertionPointAfterValue(sourceShard);
399
400 MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
401 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
402 ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
403 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
404 targetTensorAxis);
405 Value allToAllResult = builder.create<AllToAllOp>(
406 RankedTensorType::get(allToAllResultShape.getShape(),
407 allToAllResultShape.getElementType()),
408 mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
409 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
410 ShapedType targetShape =
411 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
412 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
413 builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
414 return {targetShard, targetSharding};
415}
416
417static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
418tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
419 MeshShardingAttr sourceSharding,
420 MeshShardingAttr targetSharding,
421 ShapedType sourceUnshardedShape,
422 TypedValue<ShapedType> sourceShard) {
423 if (auto detectRes =
424 detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
425 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
426 return moveLastSplitAxisInResharding(
427 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
428 sourceTensorAxis, targetTensorAxis, meshAxis);
429 }
430
431 return std::nullopt;
432}
433
434// Handles only resharding on a 1D mesh.
435// Currently the sharded tensor axes must be exactly divisible by the single
436// mesh axis size.
437static TypedValue<ShapedType>
438reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
439 MeshShardingAttr sourceSharding,
440 MeshShardingAttr targetSharding,
441 TypedValue<ShapedType> sourceUnshardedValue,
442 TypedValue<ShapedType> sourceShard) {
443 assert(sourceShard.getType() ==
444 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
445 [[maybe_unused]] ShapedType targetShardType =
446 shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
447 assert(sourceShard.getType().getRank() == targetShardType.getRank());
448 assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
449
450 auto [reducedSourceShard, reducedSourceSharding] =
451 handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
452 sourceShard);
453
454 if (reducedSourceSharding == targetSharding) {
455 return reducedSourceShard;
456 }
457
458 TypedValue<ShapedType> targetShard;
459 MeshShardingAttr actualTargetSharding;
460 if (auto tryRes = tryMoveLastSplitAxisInResharding(
461 builder, mesh, reducedSourceSharding, targetSharding,
462 sourceUnshardedValue.getType(), reducedSourceShard)) {
463 std::tie(targetShard, actualTargetSharding) = tryRes.value();
464 } else if (auto tryRes = trySplitLastAxisInResharding(
465 builder, mesh, reducedSourceSharding, targetSharding,
466 reducedSourceShard)) {
467 std::tie(targetShard, actualTargetSharding) = tryRes.value();
468 } else if (auto tryRes = tryUnsplitLastAxisInResharding(
469 builder, mesh, reducedSourceSharding, targetSharding,
470 sourceUnshardedValue.getType(), reducedSourceShard)) {
471 std::tie(targetShard, actualTargetSharding) = tryRes.value();
472 } else {
473 assert(false && "Did not find any pattern to apply.");
474 }
475
476 assert(actualTargetSharding == targetSharding);
477 assert(targetShard.getType() == targetShardType);
478 return targetShard;
479}
480
481TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
482 MeshShardingAttr sourceSharding,
483 MeshShardingAttr targetSharding,
484 TypedValue<ShapedType> sourceUnshardedValue,
485 TypedValue<ShapedType> sourceShard) {
486 // Resort to handling only 1D meshes since the general case is complicated if
487 // it needs to be communication efficient in terms of minimizing the data
488 // transfered between devices.
489 return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
490 sourceUnshardedValue, sourceShard);
491}
492
493TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
494 ShardOp target,
495 TypedValue<ShapedType> sourceShardValue) {
496 assert(!source.getAnnotateForUsers());
497 assert(target.getAnnotateForUsers());
498 assert(source.getResult() == target.getOperand());
499 ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
500 return reshard(
501 implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
502 cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
503}
504
505TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
506 ShardOp target,
507 TypedValue<ShapedType> sourceShardValue,
508 SymbolTableCollection &symbolTableCollection) {
509 MeshOp srcMesh = getMesh(source, symbolTableCollection);
510 assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
511 return reshard(builder, srcMesh, source, target, sourceShardValue);
512}
513
514void reshardingRegisterDependentDialects(DialectRegistry &registry) {
515 registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
516}
517
518#define GEN_PASS_DEF_SPMDIZATION
519#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
520
521using UnshardedToShardedValueMap = DenseMap<Value, Value>;
522
523// Get the types of block arguments for an spmdized block.
524// Reads the sharding annotations of the arguments to deduce the sharded types.
525// Types that are not ranked tensors are left unchanged.
526SmallVector<Type>
527shardedBlockArgumentTypes(Block &block,
528 SymbolTableCollection &symbolTableCollection) {
529 SmallVector<Type> res;
530 llvm::transform(
531 block.getArguments(), std::back_inserter(res),
532 [&symbolTableCollection](BlockArgument arg) {
533 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
534 if (!rankedTensorArg) {
535 return arg.getType();
536 }
537
538 assert(rankedTensorArg.hasOneUse());
539 Operation *useOp = *rankedTensorArg.getUsers().begin();
540 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
541 assert(shardOp);
542 MeshOp mesh = getMesh(shardOp, symbolTableCollection);
543 return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
544 shardOp.getShardAttr()));
545 });
546 return res;
547}
548
549static LogicalResult spmdizeOperation(
550 Operation &op, ArrayRef<Value> spmdizedOperands,
551 ArrayRef<MeshShardingAttr> operandShardings,
552 ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
553 SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
554 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
555 if (!shardingInterface) {
556 // If there is no sharding interface we are conservative and assume that
557 // the op should be fully replicated no all devices.
558 spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
559 resultShardings, spmdizationMap,
560 symbolTableCollection, builder);
561 } else {
562 if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
563 resultShardings, spmdizationMap,
564 symbolTableCollection, builder))) {
565 return failure();
566 }
567 }
568
569 assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
570 return spmdizationMap.contains(result);
571 }));
572
573 return success();
574}
575
576// Retrieve the sharding annotations for the operands of the given operation.
577// If the type is not a ranked tensor it is not require to have an annotation.
578static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
579 SmallVector<MeshShardingAttr> res;
580 res.reserve(op.getNumOperands());
581 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
582 TypedValue<RankedTensorType> rankedTensor =
583 dyn_cast<TypedValue<RankedTensorType>>(operand);
584 if (!rankedTensor) {
585 return MeshShardingAttr();
586 }
587
588 Operation *definingOp = operand.getDefiningOp();
589 assert(definingOp);
590 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
591 return shardOp.getShard();
592 });
593 return res;
594}
595
596// Retrieve the sharding annotations for the results of the given operation.
597// If the type is not a ranked tensor it is not require to have an annotation.
598static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
599 SmallVector<MeshShardingAttr> res;
600 res.reserve(op.getNumResults());
601 llvm::transform(op.getResults(), std::back_inserter(res),
602 [](OpResult result) {
603 TypedValue<RankedTensorType> rankedTensor =
604 dyn_cast<TypedValue<RankedTensorType>>(result);
605 if (!rankedTensor) {
606 return MeshShardingAttr();
607 }
608
609 assert(result.hasOneUse());
610 Operation *userOp = *result.getUsers().begin();
611 ShardOp shardOp = llvm::cast<ShardOp>(userOp);
612 return shardOp.getShard();
613 });
614 return res;
615}
616
617static LogicalResult
618spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
619 SymbolTableCollection &symbolTableCollection,
620 OpBuilder &builder) {
621 Value targetSpmdValue;
622
623 // Check if 2 shard ops are chained. If not there is no need for resharding
624 // as the source and target shared the same sharding.
625 ShardOp srcShardOp =
626 dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
627 if (!srcShardOp) {
628 targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
629 } else {
630 // Insert resharding.
631 assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
632 TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
633 spmdizationMap.lookup(srcShardOp.getOperand()));
634 targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
635 symbolTableCollection);
636 }
637
638 assert(!spmdizationMap.contains(shardOp.getResult()));
639 spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
640 return success();
641}
642
643static LogicalResult
644spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
645 SymbolTableCollection &symbolTableCollection,
646 OpBuilder &builder) {
647 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
648 if (shardOp) {
649 return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
650 builder);
651 }
652
653 SmallVector<Value> spmdizedOperands;
654 llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
655 [&spmdizationMap](Value operand) {
656 assert(spmdizationMap.contains(operand));
657 return spmdizationMap.lookup(operand);
658 });
659 return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
660 getResultShardings(op), spmdizationMap,
661 symbolTableCollection, builder);
662}
663
664static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
665 SymbolTableCollection &symbolTableCollection,
666 OpBuilder &builder) {
667 SmallVector<Location> argLocations;
668 llvm::transform(block.getArguments(), std::back_inserter(argLocations),
669 [](BlockArgument arg) { return arg.getLoc(); });
670 Block *newBlock = builder.createBlock(
671 block.getParent(), {},
672 shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
673 for (auto [unshardedBlockArg, spmdizedBlockArg] :
674 llvm::zip(block.getArguments(), newBlock->getArguments())) {
675 spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
676 }
677
678 OpBuilder::InsertionGuard insertionGuard(builder);
679 builder.setInsertionPointToEnd(newBlock);
680 for (Operation &op : block.getOperations()) {
681 if (failed(result: spmdizeOperation(op, spmdizationMap, symbolTableCollection,
682 builder))) {
683 return failure();
684 }
685 }
686
687 return success();
688}
689
690static LogicalResult
691spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
692 SymbolTableCollection &symbolTableCollection) {
693 OpBuilder builder(op.getFunctionBody());
694
695 // Snapshot the original blocks to not mess up the iteration when adding new
696 // blocks.
697 SmallVector<Block *> originalBlocks;
698 llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
699 [](Block &b) { return &b; });
700
701 for (Block *block : originalBlocks) {
702 if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
703 builder))) {
704 return failure();
705 }
706 }
707
708 for (Block *block : originalBlocks) {
709 block->erase();
710 }
711
712 // Find a return op and change the function results signature to its operands
713 // signature.
714 Operation *returnOp = nullptr;
715 for (Block &block : op.getFunctionBody()) {
716 if (block.empty()) {
717 continue;
718 }
719
720 if (block.back().hasTrait<OpTrait::ReturnLike>()) {
721 returnOp = &block.back();
722 break;
723 }
724 }
725 assert(returnOp);
726 op.setType(FunctionType::get(op->getContext(),
727 op.getFunctionBody().front().getArgumentTypes(),
728 returnOp->getOperandTypes()));
729
730 return success();
731}
732
733namespace {
734
735struct Spmdization : public impl::SpmdizationBase<Spmdization> {
736 void runOnOperation() override {
737 IRMapping spmdizationMap;
738 SymbolTableCollection symbolTableCollection;
739 if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
740 symbolTableCollection))) {
741 return signalPassFailure();
742 }
743 }
744
745 void getDependentDialects(DialectRegistry &registry) const override {
746 reshardingRegisterDependentDialects(registry);
747 registry.insert<mesh::MeshDialect>();
748 }
749};
750
751} // namespace
752
753} // namespace mlir::mesh
754

// Source: mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp