1//===- Spmdization.cpp --------------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
10
11#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/BuiltinTypeInterfaces.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/ImplicitLocOpBuilder.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/MLIRContext.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Interfaces/ControlFlowInterfaces.h"
27#include "mlir/Interfaces/FunctionInterfaces.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Support/LLVM.h"
30#include "llvm/ADT/DenseSet.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/Support/Casting.h"
34#include <iterator>
35#include <optional>
36#include <tuple>
37#include <type_traits>
38
39namespace mlir::mesh {
40
// Returns true if every axis in `targetAxes` is also present in `sourceAxes`,
// i.e. the target partial axes are a subset of the source partial axes.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  for (const auto &axis : targetAxes) {
    if (!sourceAxes.contains(axis))
      return false;
  }
  return true;
}
48
49// Return the reduced value and its corresponding sharding.
50// Example:
51// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
52// targetSharding = <@mesh_1d, [[]]>
53// Then will apply all-reduce on the source value
54// and return it with the sharding <@mesh_1d, [[0]]>.
55static std::tuple<TypedValue<ShapedType>, MeshSharding>
56handlePartialAxesDuringResharding(OpBuilder &builder,
57 MeshSharding sourceSharding,
58 MeshSharding targetSharding,
59 TypedValue<ShapedType> sourceShard) {
60 if (sourceSharding.getPartialAxes().empty() &&
61 targetSharding.getPartialAxes().empty()) {
62 return {sourceShard, sourceSharding};
63 }
64 assert(targetSharding.getPartialAxes().empty() ||
65 (!sourceSharding.getPartialAxes().empty() &&
66 sourceSharding.getPartialType() == targetSharding.getPartialType()));
67 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
68 using AxisSet = llvm::SmallDenseSet<Axis>;
69 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
70 sourceSharding.getPartialAxes().end());
71 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
72 targetSharding.getPartialAxes().end());
73 assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
74 targetShardingPartialAxesSet));
75 llvm::SmallVector<MeshAxis> allReduceMeshAxes;
76 llvm::copy_if(Range&: sourceShardingPartialAxesSet,
77 Out: std::back_inserter(x&: allReduceMeshAxes),
78 P: [&targetShardingPartialAxesSet](Axis a) {
79 return !targetShardingPartialAxesSet.contains(V: a);
80 });
81 if (allReduceMeshAxes.empty()) {
82 return {sourceShard, sourceSharding};
83 }
84
85 builder.setInsertionPointAfterValue(sourceShard);
86 TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
87 Val: builder
88 .create<AllReduceOp>(location: sourceShard.getLoc(), args: sourceShard.getType(),
89 args: sourceSharding.getMeshAttr().getLeafReference(),
90 args&: allReduceMeshAxes, args&: sourceShard,
91 args: sourceSharding.getPartialType())
92 .getResult());
93
94 llvm::SmallVector<MeshAxis> remainingPartialAxes;
95 llvm::copy_if(Range&: sourceShardingPartialAxesSet,
96 Out: std::back_inserter(x&: allReduceMeshAxes),
97 P: [&targetShardingPartialAxesSet](Axis a) {
98 return targetShardingPartialAxesSet.contains(V: a);
99 });
100 MeshSharding resultSharding = MeshSharding::get(
101 mesh_: sourceSharding.getMeshAttr(), split_axes_: sourceSharding.getSplitAxes(),
102 partial_axes_: remainingPartialAxes, partial_type_: sourceSharding.getPartialType());
103 return {resultValue, resultSharding};
104}
105
106static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
107 MeshSharding sourceSharding,
108 int64_t splitTensorAxis,
109 MeshAxis splitMeshAxis) {
110 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
111 llvm::to_vector(Range: sourceSharding.getSplitAxes());
112 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
113 splitTensorAxis) {
114 targetShardingSplitAxes.push_back(Elt: MeshAxesAttr::get(context: ctx, content: {}));
115 }
116 auto targetSplitAxes =
117 llvm::to_vector(Range: targetShardingSplitAxes[splitTensorAxis].asArrayRef());
118 targetSplitAxes.push_back(Elt: splitMeshAxis);
119 targetShardingSplitAxes[splitTensorAxis] =
120 MeshAxesAttr::get(context: ctx, content: targetSplitAxes);
121 return MeshSharding::get(
122 mesh_: sourceSharding.getMeshAttr(), split_axes_: targetShardingSplitAxes,
123 partial_axes_: sourceSharding.getPartialAxes(), partial_type_: sourceSharding.getPartialType());
124}
125
126// Split a replicated tensor along a mesh axis.
127// E.g. [[0, 1]] -> [[0, 1, 2]].
128// Returns the spmdized target value with its sharding.
129static std::tuple<TypedValue<ShapedType>, MeshSharding>
130splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
131 MeshSharding sourceSharding,
132 TypedValue<ShapedType> sourceShard, MeshOp mesh,
133 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
134 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
135 Val: builder
136 .create<AllSliceOp>(args&: sourceShard, args&: mesh,
137 args: ArrayRef<MeshAxis>(splitMeshAxis),
138 args&: splitTensorAxis)
139 .getResult());
140 MeshSharding targetSharding = targetShardingInSplitLastAxis(
141 ctx: builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
142 return {targetShard, targetSharding};
143}
144
145// Detect if the resharding is of type e.g.
146// [[0, 1]] -> [[0, 1, 2]].
147// If detected, returns the corresponding tensor axis mesh axis pair.
148// Does not detect insertions like
149// [[0, 1]] -> [[0, 2, 1]].
150static std::optional<std::tuple<int64_t, MeshAxis>>
151detectSplitLastAxisInResharding(MeshSharding sourceSharding,
152 MeshSharding targetSharding) {
153 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
154 ++tensorAxis) {
155 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
156 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
157 targetSharding.getSplitAxes()[tensorAxis].size()) {
158 continue;
159 }
160 if (!llvm::equal(
161 LRange: sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
162 RRange: llvm::make_range(
163 x: targetSharding.getSplitAxes()[tensorAxis]
164 .asArrayRef()
165 .begin(),
166 y: targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
167 1))) {
168 continue;
169 }
170 } else {
171 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
172 continue;
173 }
174 }
175 return std::make_tuple(
176 args&: tensorAxis,
177 args: targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
178 }
179 return std::nullopt;
180}
181
182static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
183trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
184 MeshSharding sourceSharding,
185 MeshSharding targetSharding,
186 TypedValue<ShapedType> sourceShard) {
187 if (auto detectRes =
188 detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
189 auto [tensorAxis, meshAxis] = detectRes.value();
190 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
191 splitTensorAxis: tensorAxis, splitMeshAxis: meshAxis);
192 }
193
194 return std::nullopt;
195}
196
197// Detect if the resharding is of type e.g.
198// [[0, 1, 2]] -> [[0, 1]].
199// If detected, returns the corresponding tensor axis mesh axis pair.
200static std::optional<std::tuple<int64_t, MeshAxis>>
201detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
202 MeshSharding targetSharding) {
203 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
204 ++tensorAxis) {
205 if (targetSharding.getSplitAxes().size() > tensorAxis) {
206 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
207 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
208 continue;
209 if (!llvm::equal(
210 LRange: llvm::make_range(
211 x: sourceSharding.getSplitAxes()[tensorAxis]
212 .asArrayRef()
213 .begin(),
214 y: sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
215 1),
216 RRange: targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
217 continue;
218 } else {
219 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
220 continue;
221 }
222 return std::make_tuple(
223 args&: tensorAxis,
224 args: sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
225 }
226 return std::nullopt;
227}
228
229static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
230 MeshSharding sourceSharding,
231 int64_t splitTensorAxis) {
232 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
233 llvm::to_vector(Range: sourceSharding.getSplitAxes());
234 assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
235 splitTensorAxis);
236 auto targetSplitAxes =
237 llvm::to_vector(Range: targetShardingSplitAxes[splitTensorAxis].asArrayRef());
238
239 targetSplitAxes.pop_back();
240 targetShardingSplitAxes[splitTensorAxis] =
241 MeshAxesAttr::get(context: ctx, content: targetSplitAxes);
242 return MeshSharding::get(
243 mesh_: sourceSharding.getMeshAttr(), split_axes_: targetShardingSplitAxes,
244 partial_axes_: sourceSharding.getPartialAxes(), partial_type_: sourceSharding.getPartialType());
245}
246
247static ShapedType allGatherResultShapeInUnsplitLastAxis(
248 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
249 SmallVector<int64_t> targetShape = llvm::to_vector(Range: sourceShape.getShape());
250 targetShape[splitTensorAxis] =
251 gatherDimension(dimSize: targetShape[splitTensorAxis], shardCount: splitCount);
252 return sourceShape.cloneWith(shape: targetShape, elementType: sourceShape.getElementType());
253}
254
255static std::tuple<TypedValue<ShapedType>, MeshSharding>
256unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
257 MeshSharding sourceSharding,
258 ShapedType sourceUnshardedShape,
259 TypedValue<ShapedType> sourceShard, MeshOp mesh,
260 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
261 MLIRContext *ctx = builder.getContext();
262 builder.setInsertionPointAfterValue(sourceShard);
263
264 MeshSharding targetSharding =
265 targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
266 ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
267 sourceShape: sourceShard.getType(), splitCount: mesh.getShape()[splitMeshAxis], splitTensorAxis);
268 Value allGatherResult = builder.create<AllGatherOp>(
269 args: RankedTensorType::get(shape: allGatherResultShape.getShape(),
270 elementType: allGatherResultShape.getElementType()),
271 args: mesh.getSymName(), args: SmallVector<MeshAxis>({splitMeshAxis}), args&: sourceShard,
272 args: APInt(64, splitTensorAxis));
273 ShapedType targetShape =
274 shardShapedType(shape: sourceUnshardedShape, mesh, sharding: targetSharding);
275 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
276 Val: builder.create<tensor::CastOp>(args&: targetShape, args&: allGatherResult).getResult());
277 return {targetShard, targetSharding};
278}
279
280static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
281tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
282 MeshSharding sourceSharding,
283 MeshSharding targetSharding,
284 ShapedType sourceUnshardedShape,
285 TypedValue<ShapedType> sourceShard) {
286 if (auto detectRes =
287 detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
288 auto [tensorAxis, meshAxis] = detectRes.value();
289 return unsplitLastAxisInResharding(builder, sourceSharding,
290 sourceUnshardedShape, sourceShard, mesh,
291 splitTensorAxis: tensorAxis, splitMeshAxis: meshAxis);
292 }
293
294 return std::nullopt;
295}
296
297// Detect if the resharding is of type e.g.
298// [[0, 1], [2]] -> [[0], [1, 2]].
299// Only moving the last axis counts.
300// If detected, returns the corresponding (source_tensor_axis,
301// target_tensor_axis, mesh_axis) tuple.
302static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
303detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
304 MeshSharding targetSharding) {
305 for (size_t sourceTensorAxis = 0;
306 sourceTensorAxis < sourceSharding.getSplitAxes().size();
307 ++sourceTensorAxis) {
308 for (size_t targetTensorAxis = 0;
309 targetTensorAxis < targetSharding.getSplitAxes().size();
310 ++targetTensorAxis) {
311 if (sourceTensorAxis == targetTensorAxis)
312 continue;
313 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
314 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
315 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
316 targetSharding.getSplitAxes()[targetTensorAxis]
317 .asArrayRef()
318 .back())
319 continue;
320 if (!llvm::equal(
321 LRange: llvm::make_range(x: sourceSharding.getSplitAxes()[sourceTensorAxis]
322 .asArrayRef()
323 .begin(),
324 y: sourceSharding.getSplitAxes()[sourceTensorAxis]
325 .asArrayRef()
326 .end() -
327 1),
328 RRange: llvm::make_range(x: targetSharding.getSplitAxes()[targetTensorAxis]
329 .asArrayRef()
330 .begin(),
331 y: targetSharding.getSplitAxes()[targetTensorAxis]
332 .asArrayRef()
333 .end() -
334 1)))
335 continue;
336 return std::make_tuple(
337 args&: sourceTensorAxis, args&: targetTensorAxis,
338 args: sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
339 }
340 }
341 return std::nullopt;
342}
343
344static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
345 MeshSharding sourceSharding,
346 int64_t sourceTensorAxis,
347 int64_t targetTensorAxis) {
348 SmallVector<MeshAxesAttr> targetShardingSplitAxes =
349 llvm::to_vector(Range: sourceSharding.getSplitAxes());
350 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
351 targetTensorAxis) {
352 targetShardingSplitAxes.push_back(Elt: MeshAxesAttr::get(context: ctx, content: {}));
353 }
354
355 auto sourceSplitAxes =
356 llvm::to_vector(Range: targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
357 assert(!sourceSplitAxes.empty());
358 auto meshAxis = sourceSplitAxes.back();
359 sourceSplitAxes.pop_back();
360 targetShardingSplitAxes[sourceTensorAxis] =
361 MeshAxesAttr::get(context: ctx, content: sourceSplitAxes);
362
363 auto targetSplitAxes =
364 llvm::to_vector(Range: targetShardingSplitAxes[targetTensorAxis].asArrayRef());
365 targetSplitAxes.push_back(Elt: meshAxis);
366 targetShardingSplitAxes[targetTensorAxis] =
367 MeshAxesAttr::get(context: ctx, content: targetSplitAxes);
368
369 return MeshSharding::get(
370 mesh_: sourceSharding.getMeshAttr(), split_axes_: targetShardingSplitAxes,
371 partial_axes_: sourceSharding.getPartialAxes(), partial_type_: sourceSharding.getPartialType());
372}
373
374static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
375 int64_t splitCount,
376 int64_t sourceTensorAxis,
377 int64_t targetTensorAxis) {
378 SmallVector<int64_t> targetShape = llvm::to_vector(Range: sourceShape.getShape());
379 targetShape[sourceTensorAxis] =
380 gatherDimension(dimSize: targetShape[sourceTensorAxis], shardCount: splitCount);
381 targetShape[targetTensorAxis] =
382 shardDimension(dimSize: targetShape[targetTensorAxis], shardCount: splitCount);
383 return sourceShape.cloneWith(shape: targetShape, elementType: sourceShape.getElementType());
384}
385
386static std::tuple<TypedValue<ShapedType>, MeshSharding>
387moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
388 MeshSharding sourceSharding,
389 ShapedType sourceUnshardedShape,
390 TypedValue<ShapedType> sourceShard,
391 int64_t sourceTensorAxis,
392 int64_t targetTensorAxis, MeshAxis meshAxis) {
393 MLIRContext *ctx = builder.getContext();
394 builder.setInsertionPointAfterValue(sourceShard);
395
396 MeshSharding targetSharding = targetShardingInMoveLastAxis(
397 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
398 ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
399 sourceShape: sourceShard.getType(), splitCount: mesh.getShape()[meshAxis], sourceTensorAxis,
400 targetTensorAxis);
401 Value allToAllResult = builder.create<AllToAllOp>(
402 args: RankedTensorType::get(shape: allToAllResultShape.getShape(),
403 elementType: allToAllResultShape.getElementType()),
404 args: mesh.getSymName(), args: SmallVector<MeshAxis>({meshAxis}), args&: sourceShard,
405 args: APInt(64, targetTensorAxis), args: APInt(64, sourceTensorAxis));
406 ShapedType targetShape =
407 shardShapedType(shape: sourceUnshardedShape, mesh, sharding: targetSharding);
408 TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
409 Val: builder.create<tensor::CastOp>(args&: targetShape, args&: allToAllResult).getResult());
410 return {targetShard, targetSharding};
411}
412
413static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
414tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
415 MeshSharding sourceSharding,
416 MeshSharding targetSharding,
417 ShapedType sourceUnshardedShape,
418 TypedValue<ShapedType> sourceShard) {
419 if (auto detectRes =
420 detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
421 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
422 return moveLastSplitAxisInResharding(
423 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
424 sourceTensorAxis, targetTensorAxis, meshAxis);
425 }
426
427 return std::nullopt;
428}
429
// Detect a change in the halo size (only) and create necessary operations if
// needed. A changed halo sizes requires copying the "core" of the source tensor
// into the "core" of the destination tensor followed by an update halo
// operation.
// Returns std::nullopt when the shardings differ in anything other than their
// halo sizes; otherwise returns the updated shard value and target sharding.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(rhs: targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(rhs: targetSharding)) {
    return std::nullopt;
  }

  // Halo sizes are stored flattened as (low, high) pairs, one pair per tensor
  // dimension (indexed below as [i * 2] and [i * 2 + 1]).
  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
          ShapedType::isStaticShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  // Per-dimension slice parameters: offsets of the core within the source and
  // target shards, unit strides, the target shard shape, and the core shape.
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine "core" of source and destination.
  // The core is the local part of the shard excluding halo regions.
  for (auto i = 0u; i < rank; ++i) {
    // Only dimensions that are actually sharded carry halos.
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract core from source and copy into destination core.
  auto noVals = ValueRange{};
  // Destination buffer sized for core plus the target halos.
  auto initVal = builder.create<tensor::EmptyOp>(
      location: sourceShard.getLoc(), args&: outShape, args: sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      location: sourceShard.getLoc(),
      args: RankedTensorType::get(shape: coreShape, elementType: sourceShard.getType().getElementType()),
      args&: sourceShard, args&: noVals, args&: noVals, args&: noVals, args&: srcCoreOffs, args&: coreShape, args&: strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      location: sourceShard.getLoc(), args&: core, args&: initVal, args&: noVals, args&: noVals, args&: noVals, args&: tgtCoreOffs,
      args&: coreShape, args&: strides);

  // Finally update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              location: sourceShard.getLoc(),
              args: RankedTensorType::get(shape: outShape,
                                    elementType: sourceShard.getType().getElementType()),
              args&: initOprnd, args: mesh.getSymName(),
              args: MeshAxesArrayAttr::get(context: builder.getContext(),
                                     axes: sourceSharding.getSplitAxes()),
              args: targetSharding.getDynamicHaloSizes(),
              args: targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(args: cast<TypedValue<ShapedType>>(Val&: updateHaloResult),
                         args&: targetSharding);
}
506
507// Handles only resharding on a 1D mesh.
508// Currently the sharded tensor axes must be exactly divisible by the single
509// mesh axis size.
510static TypedValue<ShapedType>
511reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
512 MeshSharding sourceSharding, MeshSharding targetSharding,
513 TypedValue<ShapedType> sourceUnshardedValue,
514 TypedValue<ShapedType> sourceShard) {
515 assert(sourceShard.getType() ==
516 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
517 [[maybe_unused]] ShapedType targetShardType =
518 shardShapedType(shape: sourceUnshardedValue.getType(), mesh, sharding: targetSharding);
519 assert(sourceShard.getType().getRank() == targetShardType.getRank());
520 assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
521
522 auto [reducedSourceShard, reducedSourceSharding] =
523 handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
524 sourceShard);
525
526 if (reducedSourceSharding == targetSharding) {
527 return reducedSourceShard;
528 }
529
530 TypedValue<ShapedType> targetShard;
531 MeshSharding actualTargetSharding;
532 if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
533 targetSharding.getStaticShardedDimsOffsets().empty() &&
534 reducedSourceSharding.getStaticHaloSizes().empty() &&
535 targetSharding.getStaticHaloSizes().empty()) {
536 if (auto tryRes = tryMoveLastSplitAxisInResharding(
537 builder, mesh, sourceSharding: reducedSourceSharding, targetSharding,
538 sourceUnshardedShape: sourceUnshardedValue.getType(), sourceShard: reducedSourceShard)) {
539 std::tie(args&: targetShard, args&: actualTargetSharding) = tryRes.value();
540 } else if (auto tryRes = trySplitLastAxisInResharding(
541 builder, mesh, sourceSharding: reducedSourceSharding, targetSharding,
542 sourceShard: reducedSourceShard)) {
543 std::tie(args&: targetShard, args&: actualTargetSharding) = tryRes.value();
544 } else if (auto tryRes = tryUnsplitLastAxisInResharding(
545 builder, mesh, sourceSharding: reducedSourceSharding, targetSharding,
546 sourceUnshardedShape: sourceUnshardedValue.getType(), sourceShard: reducedSourceShard)) {
547 std::tie(args&: targetShard, args&: actualTargetSharding) = tryRes.value();
548 }
549 }
550 assert(targetShard && "Did not find any pattern to apply.");
551 assert(actualTargetSharding == targetSharding);
552 assert(targetShard.getType() == targetShardType);
553 return targetShard;
554}
555
556TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
557 MeshSharding sourceSharding,
558 MeshSharding targetSharding,
559 TypedValue<ShapedType> sourceUnshardedValue,
560 TypedValue<ShapedType> sourceShard) {
561 // If source and destination sharding are the same, no need to do anything.
562 if (sourceSharding == targetSharding || (isFullReplication(sharding: sourceSharding) &&
563 isFullReplication(sharding: targetSharding))) {
564 return sourceShard;
565 }
566
567 // Tries to handle the case where the resharding is needed because the halo
568 // sizes are different. Supports arbitrary mesh dimensionality.
569 if (auto tryRes = tryUpdateHaloInResharding(
570 builder, mesh, sourceSharding, targetSharding,
571 sourceUnshardedShape: sourceUnshardedValue.getType(), sourceShard)) {
572 return std::get<0>(t&: tryRes.value()); // targetShard
573 }
574
575 // Resort to handling only 1D meshes since the general case is complicated if
576 // it needs to be communication efficient in terms of minimizing the data
577 // transfered between devices.
578 return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
579 sourceUnshardedValue, sourceShard);
580}
581
582TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
583 ShardOp target,
584 TypedValue<ShapedType> sourceShardValue) {
585 assert(source.getResult() == target.getSrc());
586 auto sourceSharding = source.getSharding();
587 auto targetSharding = target.getSharding();
588 ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
589 return reshard(builder&: implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
590 sourceUnshardedValue: cast<TypedValue<ShapedType>>(Val: source.getSrc()),
591 sourceShard: sourceShardValue);
592}
593
594TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
595 ShardOp target,
596 TypedValue<ShapedType> sourceShardValue,
597 SymbolTableCollection &symbolTableCollection) {
598 MeshOp srcMesh = getMesh(op: source, symbolTableCollection);
599 assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
600 return reshard(builder, mesh: srcMesh, source, target, sourceShardValue);
601}
602
// Register the dialects whose ops resharding may create (mesh collectives and
// tensor slice/cast ops).
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}
606
607#define GEN_PASS_DEF_SPMDIZATION
608#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
609
610using UnshardedToShardedValueMap = DenseMap<Value, Value>;
611
612// Get the types of block arguments for an spmdized block.
613// Reads the sharding annotations of the arguments to deduce the sharded types.
614// Types that are not ranked tensors are left unchanged.
615SmallVector<Type>
616shardedBlockArgumentTypes(Block &block,
617 SymbolTableCollection &symbolTableCollection) {
618 SmallVector<Type> res;
619 llvm::transform(
620 Range: block.getArguments(), d_first: std::back_inserter(x&: res),
621 F: [&symbolTableCollection](BlockArgument arg) {
622 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(Val&: arg);
623 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
624 return arg.getType();
625 }
626
627 assert(rankedTensorArg.hasOneUse());
628 Operation *useOp = *rankedTensorArg.getUsers().begin();
629 ShardOp shardOp = llvm::dyn_cast<ShardOp>(Val: useOp);
630 assert(shardOp);
631 MeshOp mesh = getMesh(op: shardOp, symbolTableCollection);
632 return cast<Type>(Val: shardShapedType(shape: rankedTensorArg.getType(), mesh,
633 sharding: shardOp.getSharding()));
634 });
635 return res;
636}
637
638static LogicalResult spmdizeOperation(
639 Operation &op, ArrayRef<Value> spmdizedOperands,
640 ArrayRef<MeshSharding> operandShardings,
641 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
642 SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
643 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(Val&: op);
644 if (!shardingInterface) {
645 // If there is no sharding interface we are conservative and assume that
646 // the op should be fully replicated no all devices.
647 spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
648 resultShardings, spmdizationMap,
649 symbolTable&: symbolTableCollection, builder);
650 } else {
651 if (failed(Result: shardingInterface.spmdize(spmdizedOperands, operandShardings,
652 resultShardings, spmdizationMap,
653 symbolTableCollection, builder))) {
654 return failure();
655 }
656 }
657
658 assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
659 return spmdizationMap.contains(result);
660 }));
661
662 return success();
663}
664
665// Retrieve the sharding annotations for the operands of the given operation.
666// If the type is not a ranked tensor it is not require to have an annotation.
667static std::vector<MeshSharding> getOperandShardings(Operation &op) {
668 std::vector<MeshSharding> res;
669 res.reserve(n: op.getNumOperands());
670 llvm::transform(Range: op.getOperands(), d_first: std::back_inserter(x&: res), F: [](Value operand) {
671 TypedValue<RankedTensorType> rankedTensor =
672 dyn_cast<TypedValue<RankedTensorType>>(Val&: operand);
673 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
674 return MeshSharding();
675 }
676
677 Operation *definingOp = operand.getDefiningOp();
678 assert(definingOp);
679 ShardOp shardOp = llvm::cast<ShardOp>(Val: definingOp);
680 return MeshSharding(shardOp.getSharding());
681 });
682 return res;
683}
684
685// Retrieve the sharding annotations for the results of the given operation.
686// If the type is not a ranked tensor it is not require to have an annotation.
687static std::vector<MeshSharding> getResultShardings(Operation &op) {
688 std::vector<MeshSharding> res;
689 res.reserve(n: op.getNumResults());
690 llvm::transform(
691 Range: op.getResults(), d_first: std::back_inserter(x&: res), F: [&op](OpResult result) {
692 if (!result.hasOneUse() || result.use_empty()) {
693 return MeshSharding();
694 }
695 TypedValue<RankedTensorType> rankedTensor =
696 dyn_cast<TypedValue<RankedTensorType>>(Val&: result);
697 if (!rankedTensor) {
698 return MeshSharding();
699 }
700 Operation *userOp = *result.getUsers().begin();
701 ShardOp shardOp = llvm::dyn_cast<ShardOp>(Val: userOp);
702 if (shardOp) {
703 return MeshSharding(shardOp.getSharding());
704 }
705 if (rankedTensor.getType().getRank() == 0) {
706 // This is a 0d tensor result without explicit sharding.
707 // Find mesh symbol from operands, if any.
708 // Shardings without mesh are not always fully supported yet.
709 for (auto operand : op.getOperands()) {
710 if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
711 return MeshSharding(sharding.getMeshAttr());
712 }
713 }
714 }
715 return MeshSharding();
716 });
717 return res;
718}
719
720static LogicalResult
721spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
722 SymbolTableCollection &symbolTableCollection,
723 OpBuilder &builder) {
724 Value targetSpmdValue;
725
726 // Check if 2 shard ops are chained. If not there is no need for resharding
727 // as the source and target shared the same sharding.
728 ShardOp srcShardOp =
729 dyn_cast_or_null<ShardOp>(Val: shardOp.getSrc().getDefiningOp());
730 if (!srcShardOp) {
731 targetSpmdValue = spmdizationMap.lookup(from: shardOp.getSrc());
732 } else {
733 // Insert resharding.
734 TypedValue<ShapedType> srcSpmdValue =
735 cast<TypedValue<ShapedType>>(Val: spmdizationMap.lookup(from: srcShardOp));
736 targetSpmdValue = reshard(builder, source: srcShardOp, target: shardOp, sourceShardValue: srcSpmdValue,
737 symbolTableCollection);
738 }
739
740 assert(!spmdizationMap.contains(shardOp.getResult()));
741 spmdizationMap.map(from: shardOp.getResult(), to: targetSpmdValue);
742 return success();
743}
744
745static LogicalResult
746spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
747 SymbolTableCollection &symbolTableCollection,
748 OpBuilder &builder) {
749 if (isa<ShardingOp>(Val: op)) {
750 return success();
751 }
752 if (auto getShardingOp = dyn_cast<GetShardingOp>(Val&: op)) {
753 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
754 if (!shardOp) {
755 return op.emitError(message: "expected a shard op as source of get_sharding");
756 }
757 auto newSharding = builder.clone(op&: *shardOp.getSharding().getDefiningOp());
758 spmdizationMap.map(from: op.getResult(idx: 0), to: newSharding->getResult(idx: 0));
759 return success();
760 }
761
762 ShardOp shardOp = llvm::dyn_cast<ShardOp>(Val&: op);
763 if (shardOp) {
764 return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
765 builder);
766 }
767
768 SmallVector<Value> spmdizedOperands;
769 llvm::transform(Range: op.getOperands(), d_first: std::back_inserter(x&: spmdizedOperands),
770 F: [&spmdizationMap](Value operand) {
771 assert(spmdizationMap.contains(operand));
772 return spmdizationMap.lookup(from: operand);
773 });
774 return spmdizeOperation(op, spmdizedOperands, operandShardings: getOperandShardings(op),
775 resultShardings: getResultShardings(op), spmdizationMap,
776 symbolTableCollection, builder);
777}
778
779static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
780 SymbolTableCollection &symbolTableCollection,
781 OpBuilder &builder) {
782
783 SmallVector<Location> argLocations;
784 llvm::transform(Range: block.getArguments(), d_first: std::back_inserter(x&: argLocations),
785 F: [](BlockArgument arg) { return arg.getLoc(); });
786 Block *newBlock = builder.createBlock(
787 parent: block.getParent(), insertPt: {},
788 argTypes: shardedBlockArgumentTypes(block, symbolTableCollection), locs: argLocations);
789 for (auto [unshardedBlockArg, spmdizedBlockArg] :
790 llvm::zip(t: block.getArguments(), u: newBlock->getArguments())) {
791 spmdizationMap.map(from: unshardedBlockArg, to: spmdizedBlockArg);
792 }
793
794 OpBuilder::InsertionGuard insertionGuard(builder);
795 builder.setInsertionPointToEnd(newBlock);
796 for (Operation &op : block.getOperations()) {
797 if (failed(Result: spmdizeOperation(op, spmdizationMap, symbolTableCollection,
798 builder))) {
799 return failure();
800 }
801 }
802
803 return success();
804}
805
806static LogicalResult
807spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
808 SymbolTableCollection &symbolTableCollection) {
809 OpBuilder builder(op.getFunctionBody());
810
811 // Snapshot the original blocks to not mess up the iteration when adding new
812 // blocks.
813 SmallVector<Block *> originalBlocks;
814 for (Block &b : op.getBlocks()) {
815 if (llvm::any_of(Range&: b.getOperations(),
816 P: [](Operation &op) { return isa<ShardOp>(Val: op); })) {
817 originalBlocks.push_back(Elt: &b);
818 }
819 }
820
821 for (Block *block : originalBlocks) {
822 if (failed(Result: spmdizeBlock(block&: *block, spmdizationMap, symbolTableCollection,
823 builder))) {
824 return failure();
825 }
826 }
827
828 for (Block *block : originalBlocks) {
829 block->erase();
830 }
831
832 // Find a return op and change the function results signature to its operands
833 // signature.
834 Operation *returnOp = nullptr;
835 for (Block &block : op.getFunctionBody()) {
836 if (block.empty()) {
837 continue;
838 }
839
840 if (block.back().hasTrait<OpTrait::ReturnLike>()) {
841 returnOp = &block.back();
842 break;
843 }
844 }
845 if (returnOp) {
846 op.setType(FunctionType::get(
847 context: op->getContext(), inputs: op.getFunctionBody().front().getArgumentTypes(),
848 results: returnOp->getOperandTypes()));
849 }
850
851 return success();
852}
853
854namespace {
855
856struct Spmdization : public impl::SpmdizationBase<Spmdization> {
857 void runOnOperation() override {
858 IRMapping spmdizationMap;
859 SymbolTableCollection symbolTableCollection;
860 if (failed(Result: spmdizeFuncOp(op: getOperation(), spmdizationMap,
861 symbolTableCollection))) {
862 return signalPassFailure();
863 }
864 }
865
866 void getDependentDialects(DialectRegistry &registry) const override {
867 reshardingRegisterDependentDialects(registry);
868 registry.insert<mesh::MeshDialect>();
869 }
870};
871
872} // namespace
873
874} // namespace mlir::mesh
875

source code of mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp