1//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Mesh/IR/MeshOps.h"
14#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15#include "mlir/IR/Verifier.h"
16#include "mlir/Interfaces/FunctionInterfaces.h"
17#include "mlir/Pass/Pass.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/iterator_range.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/raw_ostream.h"
23#include <algorithm>
24#include <vector>
25
26namespace mlir {
27namespace mesh {
28#define GEN_PASS_DEF_SHARDINGPROPAGATION
29#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
30} // namespace mesh
31} // namespace mlir
32
33#define DEBUG_TYPE "sharding-propagation"
34#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35
36using namespace mlir;
37using namespace mlir::mesh;
38
// How much resharding a candidate operand/result sharding would force,
// relative to the `mesh.shard` annotations already present around an op.
// Smaller values are preferred when choosing between candidates.
enum class ReshardingRquirementKind {
  // Every existing annotation is compatible with the candidate.
  NO_RESHARDING = 0,
  // Only annotations that do not explicitly target the op need resharding.
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
  // At least one annotation explicitly targeting the op needs resharding.
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2
};
44
45#ifdef LLVM_DEBUG
46
47template <typename T>
48static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
49 const SmallVector<T> &vec);
50template <typename... Ts>
51static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52 const std::tuple<Ts...> &t);
53static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
54 ReshardingRquirementKind v);
55
// Prints `range` to `stream` as "[e0, e1, ..., eN]" (an empty range prints
// "[]"). Each element is printed with an unqualified `<<`, so any overload
// visible here (or found via ADL) is used.
//
// Returns `stream` itself, so this works for any Stream type, including
// types derived from the type that their `<<` operators return.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool isFirst = true;
  for (auto &element : range) {
    // Separate elements without leaving a dangling separator at the end.
    if (!isFirst)
      stream << ", ";
    isFirst = false;
    stream << element;
  }
  stream << "]";
  return stream;
}
65
66template <typename T>
67static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
68 const SmallVector<T> &vec) {
69 return printRange(stream, vec);
70}
71
72[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
73 const ShardingOption &v) {
74 return stream << "{empty = " << v.empty << ", mesh" << v.mesh
75 << ", shardingArray = " << v.shardingArray << "}";
76}
77
// Prints the elements of `tuple` selected by `Is...` as "{e0, e1, ..., eN}".
// `Is...` is expected to be std::index_sequence_for<Ts...>, which the
// static_asserts enforce by arity.
//
// The tuple is taken by const reference: elements such as ShardingOption own
// heap storage, and printing must not copy them.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Emit a separator before every element except the first, so there is no
  // dangling ", " before the closing brace.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  stream << "}";
  return stream;
}
89
90template <typename... Ts>
91static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
92 const std::tuple<Ts...> &t) {
93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
94}
95
96[[maybe_unused]] static llvm::raw_ostream &
97operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
98 return stream << static_cast<int>(v);
99}
100
101#endif // LLVM_DEBUG
102
103//===----------------------------------------------------------------------===//
104// Utilities
105//===----------------------------------------------------------------------===//
106
107// This method retrieves all potential sharding attributes, prioritizing
108// specific shardings. For example, mustShardings = [shard0, None] and
109// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
110// [shard0, None]]
111static SmallVector<std::vector<MeshSharding>>
112getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
113 ArrayRef<MeshSharding> optionalShardings) {
114 SmallVector<std::vector<MeshSharding>> allShardingAttrs;
115 std::vector<MeshSharding> curShardingAttrs;
116
117 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
118 if (i == mustShardings.size()) {
119 allShardingAttrs.push_back(Elt: std::vector<MeshSharding>(curShardingAttrs));
120 return;
121 }
122
123 if (mustShardings[i]) {
124 curShardingAttrs.push_back(x: mustShardings[i]);
125 dfsCreateShardingAttrs(i + 1);
126 curShardingAttrs.pop_back();
127 return;
128 }
129
130 if (optionalShardings[i]) {
131 curShardingAttrs.push_back(x: optionalShardings[i]);
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 curShardingAttrs.push_back(x: {});
135 dfsCreateShardingAttrs(i + 1);
136 curShardingAttrs.pop_back();
137 return;
138 }
139
140 curShardingAttrs.push_back(x: {});
141 dfsCreateShardingAttrs(i + 1);
142 curShardingAttrs.pop_back();
143 };
144
145 dfsCreateShardingAttrs(0);
146 return allShardingAttrs;
147}
148
149// The order of preference is form highest to lowest:
150// 1. No resharding is required (all existing annotations are compatible).
151// 2. No resharding for operands/results that have annotation specifically
152// targeting this operation. This means
153// * operands that are the result of `mesh.shard` ops marked with
154// `annotate_for_users`.
155// * results that are annotated with `mesh.shard` ops without
156// `annotate_for_users`.
157// 3. All other cases. Resharding is required for operands/results with
158// annotation targeting explicitly this operation.
159ReshardingRquirementKind getReshardingRquirementKind(
160 Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
161 ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
162
163 size_t operandsCount = op->getOperands().size();
164 auto operandShardings =
165 llvm::make_range(x: operandAndResultShardings.begin(),
166 y: operandAndResultShardings.begin() + operandsCount);
167 auto resultShardings =
168 llvm::make_range(x: operandAndResultShardings.begin() + operandsCount,
169 y: operandAndResultShardings.end());
170
171 for (auto [operand, sharding] :
172 llvm::zip_equal(t: op->getOperands(), u&: operandShardings)) {
173 ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
174 if (!shardOp) {
175 continue;
176 }
177 bool needsResharding = sharding != shardOp.getSharding();
178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179 if (needsResharding) {
180 if (isExplicitAnnotationForThisOp) {
181 // This is the worst case. No need to continue.
182 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
183 }
184 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
185 }
186 }
187
188 for (auto [result, sharding] :
189 llvm::zip_equal(t: op->getResults(), u&: resultShardings)) {
190 for (auto user : result.getUsers()) {
191 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
192 if (!shardOp) {
193 continue;
194 }
195 bool needsResharding = sharding != shardOp.getSharding();
196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197 if (needsResharding) {
198 if (isExplicitAnnotationForThisOp) {
199 // This is the worst case. No need to continue.
200 return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
201 }
202 res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
203 }
204 }
205 }
206
207 return res;
208}
209
210// From all the operand and result sharding combinations,
211// return the one that is most desirable.
212// The order of preference is:
213// 1. No resharding with respect to existing sharding annotations.
214// 2. Resharding for values that have already annotations that do not target
215// this op.
216// 3. Resharding of existing explicit sharding annotations for this op.
217static FailureOr<ShardingOption> selectShardingOption(
218 ShardingInterface shardingOp,
219 ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
220 ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
221 SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
222 shardingOptionsAndReshardingRequirements;
223
224 for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
225 for (ArrayRef<MeshSharding> operandShardings :
226 possibleOperandShardingAttrs) {
227 FailureOr<ShardingOption> shardingOption =
228 shardingOp.getShardingOption(operandShardings, resultShardings);
229 if (failed(Result: shardingOption) || shardingOption->empty) {
230 continue;
231 }
232 // These shardings may not be the same as those in operandShardings and
233 // resultShardings.
234 // They may be missing some annotations.
235 // Whatever is returned by getShardingAnnotations is exactly what the op
236 // needs.
237 FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
238 shardingOp.getShardingAnnotations(*shardingOption);
239 if (failed(Result: operandAndResultShardings)) {
240 return failure();
241 }
242
243 // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
244 // << *operandAndResultShardings << "\n";);
245
246 ReshardingRquirementKind reshardingRquirement =
247 getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
248 if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
249 // This is the best case. No need to go on.
250 return *shardingOption;
251 }
252
253 shardingOptionsAndReshardingRequirements.emplace_back(
254 Args: std::move(*shardingOption), Args&: reshardingRquirement);
255 }
256 }
257
258 if (shardingOptionsAndReshardingRequirements.empty()) {
259 return ShardingOption::makeEmpty();
260 }
261
262 std::partial_sort(
263 first: shardingOptionsAndReshardingRequirements.begin(),
264 middle: shardingOptionsAndReshardingRequirements.begin() + 1,
265 last: shardingOptionsAndReshardingRequirements.end(),
266 comp: [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268 return std::get<ReshardingRquirementKind>(t: a) <
269 std::get<ReshardingRquirementKind>(t: b);
270 });
271
272 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
273 << shardingOptionsAndReshardingRequirements << "\n";);
274
275 return std::get<ShardingOption>(
276 t&: shardingOptionsAndReshardingRequirements.front());
277}
278
279// For each operation that implements the ShardingInterface, infer the sharding
280// option of the operation from its operands and/or results using the
281// `getShardingOption` method. If the inferred sharding option is not empty, add
282// a `mesh.shard` operation for all remaining operands and results that do not
283// have sharding annotations.
284static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
285 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
286 if (op->hasTrait<OpTrait::IsTerminator>() ||
287 (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
288 llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
289 return success();
290
291 if (!shardingOp) {
292 op->emitOpError() << "sharding interface is not implemented.";
293 return failure();
294 }
295
296 // collect MeshSharding from results
297 std::vector<MeshSharding> allowConflictsResultShardings;
298 allowConflictsResultShardings.resize(new_size: op->getNumResults());
299 std::vector<MeshSharding> resultMustShardings;
300 resultMustShardings.resize(new_size: op->getNumResults());
301 for (OpResult result : op->getResults()) {
302 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
303 getMeshSharding(result);
304 if (failed(Result: maybeShardAttr))
305 continue;
306 if (!maybeShardAttr->first)
307 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
308 else
309 allowConflictsResultShardings[result.getResultNumber()] =
310 maybeShardAttr->second;
311 }
312
313 // collect MeshSharding from operands
314 std::vector<MeshSharding> allowConflictsOperandShardings;
315 allowConflictsOperandShardings.resize(new_size: op->getNumOperands());
316 std::vector<MeshSharding> operandMustShardings;
317 operandMustShardings.resize(new_size: op->getNumOperands());
318 for (OpOperand &opOperand : op->getOpOperands()) {
319 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
320 getMeshSharding(opOperand);
321 if (failed(Result: maybeShardAttr))
322 continue;
323
324 if (maybeShardAttr->first)
325 operandMustShardings[opOperand.getOperandNumber()] =
326 maybeShardAttr->second;
327 else
328 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
329 maybeShardAttr->second;
330 }
331
332 // try to get the sharding option
333 SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
334 getOrderedPossibleShardingAttrs(mustShardings: operandMustShardings,
335 optionalShardings: allowConflictsOperandShardings);
336 SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
337 getOrderedPossibleShardingAttrs(mustShardings: resultMustShardings,
338 optionalShardings: allowConflictsResultShardings);
339 FailureOr<ShardingOption> shardingOption = selectShardingOption(
340 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
341
342 if (failed(Result: shardingOption)) {
343 op->emitOpError() << "fail to get sharding option.";
344 return failure();
345 }
346
347 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
348
349 // sharding info is empty, return immediately
350 if (shardingOption->empty)
351 return success();
352
353 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
354 op->emitOpError() << "fail to set sharding annotations.";
355 return failure();
356 }
357 return success();
358}
359
360//===----------------------------------------------------------------------===//
361// ShardingPropagation
362//===----------------------------------------------------------------------===//
363struct ShardingPropagation
364 : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
365 void runOnOperation() override {
366 FunctionOpInterface funcOp = getOperation();
367 MLIRContext *ctx = funcOp.getContext();
368 Region &region = funcOp.getFunctionBody();
369 OpBuilder builder(ctx);
370 if (!region.hasOneBlock()) {
371 funcOp.emitOpError() << "only one block is supported!";
372 return signalPassFailure();
373 }
374 Block &block = region.front();
375
376 LLVM_DEBUG(
377 DBGS() << "print all the ops' iterator types and indexing maps in the "
378 "block.\n";
379 for (Operation &op
380 : block.getOperations()) {
381 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
382 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
383 });
384
385 // 1. propagate in reversed order
386 for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
387 if (failed(visitOp(&op, builder)))
388 return signalPassFailure();
389
390 LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
391 << funcOp << "\n");
392 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
393
394 // 2. propagate in original order
395 for (Operation &op : llvm::make_early_inc_range(block))
396 if (failed(visitOp(&op, builder)))
397 return signalPassFailure();
398 }
399};
400

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp