//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

// Return the reduced value and its corresponding sharding.
// Example:
// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
// targetSharding = <@mesh_1d, [[]]>
// An all-reduce is then applied to the source value
// and it is returned with the sharding <@mesh_1d, [[0]]>.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getMeshAttr().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());

  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshSharding resultSharding = MeshSharding::get(
      sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
      remainingPartialAxes, sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}
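
// Illustrative walk-through (not part of the transform itself): with
//   sourceSharding = <@mesh_1d, [[]], partial = sum[0]>
//   targetSharding = <@mesh_1d, [[]]>
// the set difference of partial axes is {0}, so a single mesh.all_reduce over
// mesh axis 0 is emitted and the result is returned with the partial axes
// dropped from its sharding.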

static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
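
// For example (illustrative only): with source split axes [[0]],
// splitTensorAxis = 1 and splitMeshAxis = 1, the split-axes list is first
// padded to [[0], []] and then extended to [[0], [1]].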
127
128// Split a replicated tensor along a mesh axis.
129// E.g. [[0, 1]] -> [[0, 1, 2]].
130// Returns the spmdized target value with its sharding.
131static std::tuple<TypedValue<ShapedType>, MeshSharding>
132splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
133 MeshSharding sourceSharding,
134 TypedValue<ShapedType> sourceShard, MeshOp mesh,
135 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
136 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
137 builder
138 .create<AllSliceOp>(sourceShard, mesh,
139 ArrayRef<MeshAxis>(splitMeshAxis),
140 splitTensorAxis)
141 .getResult());
142 MeshSharding targetSharding = targetShardingInSplitLastAxis(
143 ctx: builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144 return {targetShard, targetSharding};
145}
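
// Shape sketch (assuming a mesh axis of size 2): slicing a tensor<4x8xf32>
// shard along tensor axis 0 with mesh.all_slice leaves each device with a
// tensor<2x8xf32> shard, and the sharding gains that mesh axis on tensor
// dimension 0.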

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
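
// Examples of what is (and is not) detected, mirroring the comment above:
//   [[0, 1]] -> [[0, 1, 2]]  detected, returns (tensor axis 0, mesh axis 2)
//   [[0, 1]] -> [[0, 2, 1]]  not detected (axis 2 is inserted, not appended)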

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
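
// For instance (illustrative arithmetic): a local shard dimension of 3
// gathered across a mesh axis of size 4 becomes 3 * 4 = 12 in the all-gather
// result shape.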

static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}
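
// Rough shape example (assuming a 1D mesh of size 2): un-splitting
// [[0]] -> [[]] on a tensor<2x8xf32> shard emits mesh.all_gather producing
// tensor<4x8xf32>, followed by a tensor.cast to the exact target shard type.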

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
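
// A minimal instance that satisfies the checks above (illustrative):
//   [[0], []] -> [[], [0]]
// is detected as (source tensor axis 0, target tensor axis 1, mesh axis 0),
// since mesh axis 0 is the last split axis on both sides and the remaining
// prefixes are equal (both empty).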

static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
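
// Worked example (illustrative): with splitCount = 4, a local shape of 2x16
// with sourceTensorAxis = 0 and targetTensorAxis = 1 becomes
// (2 * 4) x (16 / 4) = 8x4.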

static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}
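
// Shape sketch (assuming a 1D mesh of size 2 and an unsharded 4x8 tensor):
// moving the split from tensor axis 0 to tensor axis 1 turns the local
// tensor<2x8xf32> shard into tensor<4x4xf32> via mesh.all_to_all, then casts
// it to the exact target shard type.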

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a change in halo sizes (only) and create the necessary operations if
// needed. A halo-size change requires copying the "core" of the source tensor
// into the "core" of the destination tensor, followed by an update-halo
// operation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the "core" of the source and destination.
  // The core is the local part of the shard excluding the halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract the core from the source and copy it into the destination core.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally, update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              initOprnd, mesh.getSymName(),
              MeshAxesArrayAttr::get(builder.getContext(),
                                     sourceSharding.getSplitAxes()),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}
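
// Small worked example (assuming a 1D mesh and static shapes): a shard of
// size 6 with source halos [1, 1] has a core of 6 - 1 - 1 = 4; with target
// halos [2, 2] the new local size is 4 + 2 + 2 = 8. The core is copied from
// offset 1 in the source to offset 2 in the destination before
// mesh.update_halo fills the widened halo regions.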

// Handles only resharding on a 1D mesh.
// Currently each sharded tensor dimension must be exactly divisible by the
// size of the single mesh axis.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}
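
// Note on the dispatch above: the move-last-split-axis, split-last-axis and
// unsplit-last-axis rewrites are tried in that order and the first match is
// taken; a sharding pair that matches none of them (or carries sharded-dims
// offsets or halos) trips the "Did not find any pattern to apply" assertion.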

TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If the source and destination shardings are the same, there is nothing to
  // do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Try to handle the case where resharding is needed only because the halo
  // sizes differ. Supports arbitrary mesh dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Resort to handling only 1D meshes, since the general case is complicated
  // if it needs to be communication-efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  MeshOp srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for an spmdized block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}
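
// For instance (illustrative): a block argument of type tensor<4x8xf32> whose
// single use is a mesh.shard with split axes [[0]] on a 1D mesh of size 2 is
// rewritten to tensor<2x8xf32>; non-tensor and 0-d tensor arguments keep
// their type.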

static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface, be conservative and assume that the
    // op should be fully replicated on all devices.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return MeshSharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return MeshSharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return MeshSharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return MeshSharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0-d tensor result without an explicit sharding.
          // Find a mesh symbol from the operands, if any.
          // Shardings without a mesh are not always fully supported yet.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return MeshSharding(sharding.getMeshAttr());
            }
          }
        }
        return MeshSharding();
      });
  return res;
}

static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if two shard ops are chained. If they are not, there is no need for
  // resharding, as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
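
// Illustrative scenario: when one mesh.shard op feeds another (the producer's
// annotation chained with the annotation for its users), the second op is
// where the reshard between the two shardings is materialized; a mesh.shard
// whose source is not another mesh.shard simply forwards the already spmdized
// value through spmdizationMap.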

static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }
  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}

static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {

  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
                                builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and update the function's result types to match the
  // return op's operand types.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}
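
// End-to-end sketch (illustrative): a function taking and returning
// tensor<8xf32>, with arguments and results annotated with split axes [[0]]
// on a 1D mesh of size 4, ends up taking and returning tensor<2xf32> after
// spmdization, with the function type rebuilt from the new entry block
// argument types and the return op's operand types.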

namespace {

struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<mesh::MeshDialect>();
  }
};

} // namespace

} // namespace mlir::mesh