1//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Passes.h"
10
11#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
14#include "mlir/IR/Verifier.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/iterator_range.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/raw_ostream.h"
20#include <algorithm>
21#include <vector>
22
23namespace mlir {
24namespace mesh {
25#define GEN_PASS_DEF_SHARDINGPROPAGATION
26#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
27} // namespace mesh
28} // namespace mlir
29
30#define DEBUG_TYPE "sharding-propagation"
31#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32
33using namespace mlir;
34using namespace mlir::mesh;
35
/// Ranks how much resharding a candidate sharding option would force on the
/// values surrounding an operation. Enumerators are listed from most to least
/// preferable, so a smaller underlying value denotes a better candidate.
enum class ReshardingRquirementKind {
  /// Every existing annotation already matches the candidate.
  NO_RESHARDING = 0,
  /// Some annotation mismatches, but none that explicitly targets the op.
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
  /// An annotation that explicitly targets the op would need resharding.
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2
};
41
42#ifdef LLVM_DEBUG
43
44template <typename T>
45static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
46 const SmallVector<T> &vec);
47template <typename... Ts>
48static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49 const std::tuple<Ts...> &t);
50static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
51 ReshardingRquirementKind v);
52
/// Writes the elements of `range` to `stream` formatted as `[a, b, c]`.
///
/// Each element is emitted through `operator<<`. Elements are separated by
/// ", " and no separator follows the last one, so an empty range prints as
/// `[]`.
///
/// \param stream Any object supporting `operator<<` for string literals and
///               for the element type of `range`.
/// \param range  Any object usable in a range-based for loop.
/// \returns `stream`, to allow chaining.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  // Starts out empty so that the separator only appears between elements.
  const char *separator = "";
  for (const auto &element : range) {
    stream << separator << element;
    separator = ", ";
  }
  stream << "]";
  // Return the original object rather than the result of `operator<<`, which
  // may be a reference to a base class of `Stream`.
  return stream;
}
62
63template <typename T>
64static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
65 const SmallVector<T> &vec) {
66 return printRange(stream, vec);
67}
68
69[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
70 const ShardingOption &v) {
71 return stream << "{empty = " << v.empty << ", mesh" << v.mesh
72 << ", shardingArray = " << v.shardingArray << "}";
73}
74
/// Writes the elements of `tuple` to `stream` formatted as `{a, b, c}`.
///
/// Elements are emitted in index order through `operator<<`, separated by
/// ", " with no separator after the last element. The tuple is taken by const
/// reference so its elements are not copied just to be printed.
///
/// \param stream Any object supporting `operator<<` for string literals and
///               for every element type of the tuple.
/// \param tuple  The non-empty tuple to print.
/// The unnamed third parameter carries the element indices `0 .. N-1`;
/// callers pass `std::index_sequence_for<Ts...>{}`.
/// \returns `stream`, to allow chaining.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Comma fold over the indices; every element but the first is preceded by
  // the separator.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  stream << "}";
  // Return the original object rather than the result of `operator<<`, which
  // may be a reference to a base class of `Stream`.
  return stream;
}
86
87template <typename... Ts>
88static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
89 const std::tuple<Ts...> &t) {
90 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
91}
92
93[[maybe_unused]] static llvm::raw_ostream &
94operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
95 return stream << static_cast<int>(v);
96}
97
98#endif // LLVM_DEBUG
99
100//===----------------------------------------------------------------------===//
101// Utilities
102//===----------------------------------------------------------------------===//
103
104// This method retrieves all potential sharding attributes, prioritizing
105// specific shardings. For example, mustShardings = [shard0, None] and
106// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
107// [shard0, None]]
108static SmallVector<std::vector<MeshSharding>>
109getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
110 ArrayRef<MeshSharding> optionalShardings) {
111 SmallVector<std::vector<MeshSharding>> allShardingAttrs;
112 std::vector<MeshSharding> curShardingAttrs;
113
114 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
115 if (i == mustShardings.size()) {
116 allShardingAttrs.push_back(Elt: std::vector<MeshSharding>(curShardingAttrs));
117 return;
118 }
119
120 if (mustShardings[i]) {
121 curShardingAttrs.push_back(x: mustShardings[i]);
122 dfsCreateShardingAttrs(i + 1);
123 curShardingAttrs.pop_back();
124 return;
125 }
126
127 if (optionalShardings[i]) {
128 curShardingAttrs.push_back(x: optionalShardings[i]);
129 dfsCreateShardingAttrs(i + 1);
130 curShardingAttrs.pop_back();
131 curShardingAttrs.push_back(x: {});
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 return;
135 }
136
137 curShardingAttrs.push_back(x: {});
138 dfsCreateShardingAttrs(i + 1);
139 curShardingAttrs.pop_back();
140 };
141
142 dfsCreateShardingAttrs(0);
143 return allShardingAttrs;
144}
145
146// The order of preference is form highest to lowest:
147// 1. No resharding is required (all existing annotations are compatible).
148// 2. No resharding for operands/results that have annotation specifically
149// targeting this operation. This means
150// * operands that are the result of `mesh.shard` ops marked with
151// `annotate_for_users`.
152// * results that are annotated with `mesh.shard` ops without
153// `annotate_for_users`.
154// 3. All other cases. Resharding is required for operands/results with
155// annotation targeting explicitly this operation.
156ReshardingRquirementKind getReshardingRquirementKind(
157 Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
158 ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
159
160 size_t operandsCount = op->getOperands().size();
161 auto operandShardings =
162 llvm::make_range(x: operandAndResultShardings.begin(),
163 y: operandAndResultShardings.begin() + operandsCount);
164 auto resultShardings =
165 llvm::make_range(x: operandAndResultShardings.begin() + operandsCount,
166 y: operandAndResultShardings.end());
167
168 for (auto [operand, sharding] :
169 llvm::zip_equal(t: op->getOperands(), u&: operandShardings)) {
170 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(Val: operand.getDefiningOp());
171 if (!shardOp) {
172 continue;
173 }
174 bool needsResharding = sharding != shardOp.getSharding();
175 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
176 if (needsResharding) {
177 if (isExplicitAnnotationForThisOp) {
178 // This is the worst case. No need to continue.
179 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
180 }
181 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
182 }
183 }
184
185 for (auto [result, sharding] :
186 llvm::zip_equal(t: op->getResults(), u&: resultShardings)) {
187 for (auto user : result.getUsers()) {
188 ShardOp shardOp = llvm::dyn_cast<ShardOp>(Val: user);
189 if (!shardOp) {
190 continue;
191 }
192 bool needsResharding = sharding != shardOp.getSharding();
193 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
194 if (needsResharding) {
195 if (isExplicitAnnotationForThisOp) {
196 // This is the worst case. No need to continue.
197 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
198 }
199 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
200 }
201 }
202 }
203
204 return res;
205}
206
207// From all the operand and result sharding combinations,
208// return the one that is most desirable.
209// The order of preference is:
210// 1. No resharding with respect to existing sharding annotations.
211// 2. Resharding for values that have already annotations that do not target
212// this op.
213// 3. Resharding of existing explicit sharding annotations for this op.
214static FailureOr<ShardingOption> selectShardingOption(
215 ShardingInterface shardingOp,
216 ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
217 ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
218 SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
219 shardingOptionsAndReshardingRequirements;
220
221 for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
222 for (ArrayRef<MeshSharding> operandShardings :
223 possibleOperandShardingAttrs) {
224 FailureOr<ShardingOption> shardingOption =
225 shardingOp.getShardingOption(operandShardings, resultShardings);
226 if (failed(Result: shardingOption) || shardingOption->empty) {
227 continue;
228 }
229 // These shardings may not be the same as those in operandShardings and
230 // resultShardings.
231 // They may be missing some annotations.
232 // Whatever is returned by getShardingAnnotations is exactly what the op
233 // needs.
234 FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
235 shardingOp.getShardingAnnotations(shardingOption: *shardingOption);
236 if (failed(Result: operandAndResultShardings)) {
237 return failure();
238 }
239
240 // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
241 // << *operandAndResultShardings << "\n";);
242
243 ReshardingRquirementKind reshardingRquirement =
244 getReshardingRquirementKind(op: shardingOp, operandAndResultShardings: *operandAndResultShardings);
245 if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
246 // This is the best case. No need to go on.
247 return *shardingOption;
248 }
249
250 shardingOptionsAndReshardingRequirements.emplace_back(
251 Args: std::move(*shardingOption), Args&: reshardingRquirement);
252 }
253 }
254
255 if (shardingOptionsAndReshardingRequirements.empty()) {
256 return ShardingOption::makeEmpty();
257 }
258
259 std::partial_sort(
260 first: shardingOptionsAndReshardingRequirements.begin(),
261 middle: shardingOptionsAndReshardingRequirements.begin() + 1,
262 last: shardingOptionsAndReshardingRequirements.end(),
263 comp: [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
264 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
265 return std::get<ReshardingRquirementKind>(t: a) <
266 std::get<ReshardingRquirementKind>(t: b);
267 });
268
269 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
270 << shardingOptionsAndReshardingRequirements << "\n";);
271
272 return std::get<ShardingOption>(
273 t&: shardingOptionsAndReshardingRequirements.front());
274}
275
276// For each operation that implements the ShardingInterface, infer the sharding
277// option of the operation from its operands and/or results using the
278// `getShardingOption` method. If the inferred sharding option is not empty, add
279// a `mesh.shard` operation for all remaining operands and results that do not
280// have sharding annotations.
281static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
282 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(Val: op);
283 if (op->hasTrait<OpTrait::IsTerminator>() ||
284 (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
285 llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(Val: op))
286 return success();
287
288 if (!shardingOp) {
289 op->emitOpError() << "sharding interface is not implemented.";
290 return failure();
291 }
292
293 // collect MeshSharding from results
294 std::vector<MeshSharding> allowConflictsResultShardings;
295 allowConflictsResultShardings.resize(new_size: op->getNumResults());
296 std::vector<MeshSharding> resultMustShardings;
297 resultMustShardings.resize(new_size: op->getNumResults());
298 for (OpResult result : op->getResults()) {
299 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
300 getMeshSharding(result);
301 if (failed(Result: maybeShardAttr))
302 continue;
303 if (!maybeShardAttr->first)
304 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
305 else
306 allowConflictsResultShardings[result.getResultNumber()] =
307 maybeShardAttr->second;
308 }
309
310 // collect MeshSharding from operands
311 std::vector<MeshSharding> allowConflictsOperandShardings;
312 allowConflictsOperandShardings.resize(new_size: op->getNumOperands());
313 std::vector<MeshSharding> operandMustShardings;
314 operandMustShardings.resize(new_size: op->getNumOperands());
315 for (OpOperand &opOperand : op->getOpOperands()) {
316 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
317 getMeshSharding(opOperand);
318 if (failed(Result: maybeShardAttr))
319 continue;
320
321 if (maybeShardAttr->first)
322 operandMustShardings[opOperand.getOperandNumber()] =
323 maybeShardAttr->second;
324 else
325 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
326 maybeShardAttr->second;
327 }
328
329 // try to get the sharding option
330 SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
331 getOrderedPossibleShardingAttrs(mustShardings: operandMustShardings,
332 optionalShardings: allowConflictsOperandShardings);
333 SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
334 getOrderedPossibleShardingAttrs(mustShardings: resultMustShardings,
335 optionalShardings: allowConflictsResultShardings);
336 FailureOr<ShardingOption> shardingOption = selectShardingOption(
337 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
338
339 if (failed(Result: shardingOption)) {
340 op->emitOpError() << "fail to get sharding option.";
341 return failure();
342 }
343
344 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
345
346 // sharding info is empty, return immediately
347 if (shardingOption->empty)
348 return success();
349
350 if (failed(Result: shardingOp.addShardingAnnotations(b&: builder, shardingOption: *shardingOption))) {
351 op->emitOpError() << "fail to set sharding annotations.";
352 return failure();
353 }
354 return success();
355}
356
357//===----------------------------------------------------------------------===//
358// ShardingPropagation
359//===----------------------------------------------------------------------===//
360struct ShardingPropagation
361 : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
362
363 using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
364
365 void runOnOperation() override {
366 FunctionOpInterface funcOp = getOperation();
367 MLIRContext *ctx = funcOp.getContext();
368 Region &region = funcOp.getFunctionBody();
369 OpBuilder builder(ctx);
370 if (!region.hasOneBlock()) {
371 funcOp.emitOpError() << "only one block is supported!";
372 return signalPassFailure();
373 }
374 Block &block = region.front();
375
376 LLVM_DEBUG(
377 DBGS() << "print all the ops' iterator types and indexing maps in the "
378 "block.\n";
379 for (Operation &op
380 : block.getOperations()) {
381 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
382 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
383 });
384
385 auto traverse = [&](auto &&range, OpBuilder &builder,
386 const char *order) -> bool {
387 for (Operation &op : range) {
388 if (failed(Result: visitOp(op: &op, builder))) {
389 signalPassFailure();
390 return true;
391 }
392 }
393 LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
394 << funcOp << "\n");
395 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
396 return false;
397 };
398
399 // 1. Propagate in reversed order.
400 if (traversal == TraversalOrder::Backward ||
401 traversal == TraversalOrder::BackwardForward)
402 traverse(llvm::reverse(C&: block), builder, "backward");
403
404 // 2. Propagate in original order.
405 if (traversal != TraversalOrder::Backward)
406 traverse(block, builder, "forward");
407
408 // 3. Propagate in backward order if needed.
409 if (traversal == TraversalOrder::ForwardBackward)
410 traverse(llvm::reverse(C&: block), builder, "backward");
411 }
412};
413

source code of mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp