1 | //===- ShardingPropagation.cpp ------------------------------------- C++ --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Mesh/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
13 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
14 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
15 | #include "mlir/Interfaces/FunctionInterfaces.h" |
16 | #include "mlir/Pass/Pass.h" |
17 | #include "llvm/Support/Debug.h" |
18 | #include <vector> |
19 | |
20 | namespace mlir { |
21 | namespace mesh { |
22 | #define GEN_PASS_DEF_SHARDINGPROPAGATION |
23 | #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" |
24 | } // namespace mesh |
25 | } // namespace mlir |
26 | |
27 | #define DEBUG_TYPE "sharding-propagation" |
28 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::mesh; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Utilities |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | // This method retrieves all potential sharding attributes, prioritizing |
38 | // specific shardings. For example, mustShardings = [shard0, None] and |
39 | // optionalShardings = [None, shard1], the result will be [[shard0, shard1], |
40 | // [shard0, None]] |
41 | static SmallVector<SmallVector<MeshShardingAttr>> |
42 | getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings, |
43 | ArrayRef<MeshShardingAttr> optionalShardings) { |
44 | SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs; |
45 | SmallVector<MeshShardingAttr> curShardingAttrs; |
46 | |
47 | std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) { |
48 | if (i == mustShardings.size()) { |
49 | allShardingAttrs.push_back( |
50 | SmallVector<MeshShardingAttr>(curShardingAttrs)); |
51 | return; |
52 | } |
53 | |
54 | if (mustShardings[i]) { |
55 | curShardingAttrs.push_back(mustShardings[i]); |
56 | dfsCreateShardingAttrs(i + 1); |
57 | curShardingAttrs.pop_back(); |
58 | return; |
59 | } |
60 | |
61 | if (optionalShardings[i]) { |
62 | curShardingAttrs.push_back(optionalShardings[i]); |
63 | dfsCreateShardingAttrs(i + 1); |
64 | curShardingAttrs.pop_back(); |
65 | curShardingAttrs.push_back(nullptr); |
66 | dfsCreateShardingAttrs(i + 1); |
67 | curShardingAttrs.pop_back(); |
68 | return; |
69 | } |
70 | |
71 | curShardingAttrs.push_back(nullptr); |
72 | dfsCreateShardingAttrs(i + 1); |
73 | curShardingAttrs.pop_back(); |
74 | }; |
75 | |
76 | dfsCreateShardingAttrs(0); |
77 | return allShardingAttrs; |
78 | } |
79 | |
80 | // For each operation that implements the ShardingInterface, infer the sharding |
81 | // option of the operation from its operands and/or results using the |
82 | // `getShardingOption` method. If the inferred sharding option is not empty, add |
83 | // a `mesh.shard` operation for all remaining operands and results that do not |
84 | // have sharding annotations. |
85 | static LogicalResult visitOp(Operation *op, OpBuilder &builder) { |
86 | if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op)) |
87 | return success(); |
88 | |
89 | ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op); |
90 | if (!shardingOp) { |
91 | op->emitOpError() << "sharding interface is not implemented." ; |
92 | return failure(); |
93 | } |
94 | |
95 | // collect MeshShardingAttr from results |
96 | SmallVector<MeshShardingAttr> allowConflictsResultShardings; |
97 | allowConflictsResultShardings.resize(op->getNumResults()); |
98 | SmallVector<MeshShardingAttr> resultMustShardings; |
99 | resultMustShardings.resize(op->getNumResults()); |
100 | for (OpResult result : op->getResults()) { |
101 | FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr = |
102 | getMeshShardingAttr(result); |
103 | if (failed(maybeShardAttr)) |
104 | continue; |
105 | if (!maybeShardAttr->first) |
106 | resultMustShardings[result.getResultNumber()] = maybeShardAttr->second; |
107 | else |
108 | allowConflictsResultShardings[result.getResultNumber()] = |
109 | maybeShardAttr->second; |
110 | } |
111 | |
112 | // collect MeshShardingAttr from operands |
113 | SmallVector<MeshShardingAttr> allowConflictsOperandShardings; |
114 | allowConflictsOperandShardings.resize(op->getNumOperands()); |
115 | SmallVector<MeshShardingAttr> operandMustShardings; |
116 | operandMustShardings.resize(op->getNumOperands()); |
117 | for (OpOperand &opOperand : op->getOpOperands()) { |
118 | FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr = |
119 | getMeshShardingAttr(opOperand); |
120 | if (failed(maybeShardAttr)) |
121 | continue; |
122 | |
123 | if (maybeShardAttr->first) |
124 | operandMustShardings[opOperand.getOperandNumber()] = |
125 | maybeShardAttr->second; |
126 | else |
127 | allowConflictsOperandShardings[opOperand.getOperandNumber()] = |
128 | maybeShardAttr->second; |
129 | } |
130 | |
131 | // try to get the sharding option |
132 | SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs = |
133 | getOrderedPossibleShardingAttrs(operandMustShardings, |
134 | allowConflictsOperandShardings); |
135 | SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs = |
136 | getOrderedPossibleShardingAttrs(resultMustShardings, |
137 | allowConflictsResultShardings); |
138 | FailureOr<ShardingOption> finalShardingOption = failure(); |
139 | for (ArrayRef<MeshShardingAttr> resultShardings : |
140 | possibleResultShardingAttrs) { |
141 | if (succeeded(finalShardingOption)) |
142 | break; |
143 | for (ArrayRef<MeshShardingAttr> operandShardings : |
144 | possibleOperandShardingAttrs) { |
145 | FailureOr<ShardingOption> shardingOption = |
146 | shardingOp.getShardingOption(operandShardings, resultShardings); |
147 | if (succeeded(shardingOption)) { |
148 | finalShardingOption = shardingOption; |
149 | break; |
150 | } |
151 | } |
152 | } |
153 | |
154 | if (failed(result: finalShardingOption)) { |
155 | op->emitOpError() << "fail to get sharding option." ; |
156 | return failure(); |
157 | } |
158 | // sharding info is empty, return immediately |
159 | if (finalShardingOption->empty) |
160 | return success(); |
161 | |
162 | if (failed( |
163 | shardingOp.addShardingAnnotations(builder, *finalShardingOption))) { |
164 | op->emitOpError() << "fail to set sharding annotations." ; |
165 | return failure(); |
166 | } |
167 | return success(); |
168 | } |
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // ShardingPropagation |
172 | //===----------------------------------------------------------------------===// |
173 | struct ShardingPropagation |
174 | : public mesh::impl::ShardingPropagationBase<ShardingPropagation> { |
175 | void runOnOperation() override { |
176 | FunctionOpInterface funcOp = getOperation(); |
177 | MLIRContext *ctx = funcOp.getContext(); |
178 | Region ®ion = funcOp.getFunctionBody(); |
179 | OpBuilder builder(ctx); |
180 | if (!region.hasOneBlock()) { |
181 | funcOp.emitOpError() << "only one block is supported!" ; |
182 | signalPassFailure(); |
183 | } |
184 | Block &block = region.front(); |
185 | |
186 | LLVM_DEBUG( |
187 | DBGS() << "print all the ops' iterator types and indexing maps in the " |
188 | "block.\n" ; |
189 | for (Operation &op |
190 | : block.getOperations()) { |
191 | if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) |
192 | shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); |
193 | }); |
194 | |
195 | // 1. propagate in reversed order |
196 | for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) |
197 | if (failed(visitOp(&op, builder))) |
198 | return signalPassFailure(); |
199 | |
200 | LLVM_DEBUG(DBGS() << "After reversed order propagation:\n" |
201 | << funcOp << "\n" ); |
202 | |
203 | // 2. propagate in original order |
204 | for (Operation &op : llvm::make_early_inc_range(block)) |
205 | if (failed(visitOp(&op, builder))) |
206 | return signalPassFailure(); |
207 | } |
208 | }; |
209 | |