//===- Spmdization.cpp ------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

// Return the reduced value and its corresponding sharding.
// Example:
//   sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
//   targetSharding = <@mesh_1d, [[]]>
// Then all-reduce is applied to the source value, which is returned with the
// sharding <@mesh_1d, [[0]]>.
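// As an illustration of the set logic below (hypothetical axes): if the
// source has partial axes {0, 1} and the target has partial axes {1}, only
// mesh axis 0 is all-reduced, while axis 1 stays partial in the resulting
// sharding.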
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getMeshAttr().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());
  // Partial axes that are still present in the target sharding remain partial.
  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshSharding resultSharding = MeshSharding::get(
      sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
      remainingPartialAxes, sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}

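// Compute the sharding resulting from appending `splitMeshAxis` to the split
// axes of tensor dimension `splitTensorAxis`, e.g. with splitTensorAxis = 0
// and splitMeshAxis = 2: [[0, 1]] -> [[0, 1, 2]]. Missing per-dimension
// entries up to `splitTensorAxis` are created empty.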
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

// Split a tensor along a mesh axis along which it is currently replicated,
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
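// For example (hypothetical shapes): on a 1D mesh of size 4, resharding
// [[]] -> [[0]] all-slices a local tensor<8x16xf32> shard along tensor axis 0
// into tensor<2x16xf32>.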
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh,
                              ArrayRef<MeshAxis>(splitMeshAxis),
                              splitTensorAxis)
          .getResult());
  MeshSharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}

// Detect a resharding that appends one mesh axis to a tensor dimension's
// split axes, e.g. [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions like [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
                       targetSharding.getSplitAxes()[tensorAxis]
                           .asArrayRef()
                           .drop_back())) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a resharding that drops the last mesh axis from a tensor dimension's
// split axes, e.g. [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(sourceSharding.getSplitAxes()[tensorAxis]
                           .asArrayRef()
                           .drop_back(),
                       targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

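// Compute the sharding resulting from dropping the last split mesh axis of
// tensor dimension `splitTensorAxis`, e.g. [[0, 1, 2]] -> [[0, 1]].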
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

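// Result shape of the all-gather: the unsplit dimension grows by the gather
// factor. E.g. (hypothetical shapes) with splitCount = 4, a tensor<2x16xf32>
// shard gathers into tensor<8x16xf32> along tensor axis 0.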
static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a resharding that moves the last split mesh axis from one tensor
// dimension to another, e.g. [[0]] -> [[], [0]].
// Only moving the last split axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(sourceSharding.getSplitAxes()[sourceTensorAxis]
                           .asArrayRef()
                           .drop_back(),
                       targetSharding.getSplitAxes()[targetTensorAxis]
                           .asArrayRef()
                           .drop_back()))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

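// Compute the sharding resulting from moving the last split mesh axis of
// tensor dimension `sourceTensorAxis` to the end of the split axes of
// `targetTensorAxis`, e.g. with sourceTensorAxis = 0 and targetTensorAxis = 1:
// [[0]] -> [[], [0]].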
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(
      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

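// Result shape of the all-to-all: the source dimension grows by the split
// count while the target dimension shrinks by it. E.g. (hypothetical shapes)
// with splitCount = 4, a tensor<2x16xf32> shard becomes tensor<8x4xf32> when
// the split moves from tensor axis 0 to tensor axis 1.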
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect a change in the halo size (only) and create the necessary operations
// if needed. Changed halo sizes require copying the "core" of the source
// tensor into the "core" of the destination tensor, followed by an update-halo
// operation.
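// For example (hypothetical sizes): growing the halo of a split dimension from
// [1, 1] to [2, 2] extracts the core (the shard minus the old halo), inserts
// it at offset 2 into a larger destination tensor, and lets the update-halo
// operation fill the new halo regions from the neighboring shards.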
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine "core" of source and destination.
  // The core is the local part of the shard excluding halo regions.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract core from source and copy into destination core.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally update the halo.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              initOprnd, mesh.getSymName(),
              MeshAxesArrayAttr::get(builder.getContext(),
                                     sourceSharding.getSplitAxes()),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

// Handles only resharding on a 1D mesh.
// Currently each sharded tensor axis must be exactly divisible by the size of
// the single mesh axis.
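// For example (hypothetical shapes): resharding tensor<8xf32> from [[]] to
// [[0]] on a mesh of size 4 is supported (8 is divisible by 4), whereas
// tensor<6xf32> would not be.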
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

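// Reshard `sourceShard` from `sourceSharding` to `targetSharding`. A halo-only
// update is attempted first (it supports arbitrary mesh rank); otherwise the
// 1D-mesh patterns above are applied.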
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If the source and destination shardings are the same, there is nothing to
  // do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Try to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary mesh dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Resort to handling only 1D meshes, since the general case is complicated
  // if it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  MeshOp srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

// Get the types of block arguments for an spmdized block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
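// For example (hypothetical types): a block argument of type tensor<8x16xf32>
// whose single user is a mesh.shard op with sharding [[0]] on a 1D mesh of
// size 4 maps to the local type tensor<2x16xf32>.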
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}

static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated on all devices.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return MeshSharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return MeshSharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return MeshSharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return MeshSharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0d tensor result without explicit sharding.
          // Find mesh symbol from operands, if any.
          // Shardings without mesh are not always fully supported yet.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return MeshSharding(sharding.getMeshAttr());
            }
          }
        }
        return MeshSharding();
      });
  return res;
}

static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}

static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }
  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}

static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {

  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
                                builder))) {
      return failure();
    }
  }

  return success();
}

static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its operands
  // signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}

namespace {

struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<mesh::MeshDialect>();
  }
};

} // namespace

} // namespace mlir::mesh