1//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Mesh/IR/MeshOps.h"
14#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
16#include "mlir/Pass/Pass.h"
17#include "llvm/Support/Debug.h"
18#include <vector>
19
20namespace mlir {
21namespace mesh {
22#define GEN_PASS_DEF_SHARDINGPROPAGATION
23#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
24} // namespace mesh
25} // namespace mlir
26
27#define DEBUG_TYPE "sharding-propagation"
28#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29
30using namespace mlir;
31using namespace mlir::mesh;
32
33//===----------------------------------------------------------------------===//
34// Utilities
35//===----------------------------------------------------------------------===//
36
37// This method retrieves all potential sharding attributes, prioritizing
38// specific shardings. For example, mustShardings = [shard0, None] and
39// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
40// [shard0, None]]
41static SmallVector<SmallVector<MeshShardingAttr>>
42getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
43 ArrayRef<MeshShardingAttr> optionalShardings) {
44 SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs;
45 SmallVector<MeshShardingAttr> curShardingAttrs;
46
47 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
48 if (i == mustShardings.size()) {
49 allShardingAttrs.push_back(
50 SmallVector<MeshShardingAttr>(curShardingAttrs));
51 return;
52 }
53
54 if (mustShardings[i]) {
55 curShardingAttrs.push_back(mustShardings[i]);
56 dfsCreateShardingAttrs(i + 1);
57 curShardingAttrs.pop_back();
58 return;
59 }
60
61 if (optionalShardings[i]) {
62 curShardingAttrs.push_back(optionalShardings[i]);
63 dfsCreateShardingAttrs(i + 1);
64 curShardingAttrs.pop_back();
65 curShardingAttrs.push_back(nullptr);
66 dfsCreateShardingAttrs(i + 1);
67 curShardingAttrs.pop_back();
68 return;
69 }
70
71 curShardingAttrs.push_back(nullptr);
72 dfsCreateShardingAttrs(i + 1);
73 curShardingAttrs.pop_back();
74 };
75
76 dfsCreateShardingAttrs(0);
77 return allShardingAttrs;
78}
79
80// For each operation that implements the ShardingInterface, infer the sharding
81// option of the operation from its operands and/or results using the
82// `getShardingOption` method. If the inferred sharding option is not empty, add
83// a `mesh.shard` operation for all remaining operands and results that do not
84// have sharding annotations.
85static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
86 if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
87 return success();
88
89 ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
90 if (!shardingOp) {
91 op->emitOpError() << "sharding interface is not implemented.";
92 return failure();
93 }
94
95 // collect MeshShardingAttr from results
96 SmallVector<MeshShardingAttr> allowConflictsResultShardings;
97 allowConflictsResultShardings.resize(op->getNumResults());
98 SmallVector<MeshShardingAttr> resultMustShardings;
99 resultMustShardings.resize(op->getNumResults());
100 for (OpResult result : op->getResults()) {
101 FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
102 getMeshShardingAttr(result);
103 if (failed(maybeShardAttr))
104 continue;
105 if (!maybeShardAttr->first)
106 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
107 else
108 allowConflictsResultShardings[result.getResultNumber()] =
109 maybeShardAttr->second;
110 }
111
112 // collect MeshShardingAttr from operands
113 SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
114 allowConflictsOperandShardings.resize(op->getNumOperands());
115 SmallVector<MeshShardingAttr> operandMustShardings;
116 operandMustShardings.resize(op->getNumOperands());
117 for (OpOperand &opOperand : op->getOpOperands()) {
118 FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
119 getMeshShardingAttr(opOperand);
120 if (failed(maybeShardAttr))
121 continue;
122
123 if (maybeShardAttr->first)
124 operandMustShardings[opOperand.getOperandNumber()] =
125 maybeShardAttr->second;
126 else
127 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
128 maybeShardAttr->second;
129 }
130
131 // try to get the sharding option
132 SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
133 getOrderedPossibleShardingAttrs(operandMustShardings,
134 allowConflictsOperandShardings);
135 SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
136 getOrderedPossibleShardingAttrs(resultMustShardings,
137 allowConflictsResultShardings);
138 FailureOr<ShardingOption> finalShardingOption = failure();
139 for (ArrayRef<MeshShardingAttr> resultShardings :
140 possibleResultShardingAttrs) {
141 if (succeeded(finalShardingOption))
142 break;
143 for (ArrayRef<MeshShardingAttr> operandShardings :
144 possibleOperandShardingAttrs) {
145 FailureOr<ShardingOption> shardingOption =
146 shardingOp.getShardingOption(operandShardings, resultShardings);
147 if (succeeded(shardingOption)) {
148 finalShardingOption = shardingOption;
149 break;
150 }
151 }
152 }
153
154 if (failed(result: finalShardingOption)) {
155 op->emitOpError() << "fail to get sharding option.";
156 return failure();
157 }
158 // sharding info is empty, return immediately
159 if (finalShardingOption->empty)
160 return success();
161
162 if (failed(
163 shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
164 op->emitOpError() << "fail to set sharding annotations.";
165 return failure();
166 }
167 return success();
168}
169
170//===----------------------------------------------------------------------===//
171// ShardingPropagation
172//===----------------------------------------------------------------------===//
173struct ShardingPropagation
174 : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
175 void runOnOperation() override {
176 FunctionOpInterface funcOp = getOperation();
177 MLIRContext *ctx = funcOp.getContext();
178 Region &region = funcOp.getFunctionBody();
179 OpBuilder builder(ctx);
180 if (!region.hasOneBlock()) {
181 funcOp.emitOpError() << "only one block is supported!";
182 signalPassFailure();
183 }
184 Block &block = region.front();
185
186 LLVM_DEBUG(
187 DBGS() << "print all the ops' iterator types and indexing maps in the "
188 "block.\n";
189 for (Operation &op
190 : block.getOperations()) {
191 if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
192 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
193 });
194
195 // 1. propagate in reversed order
196 for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
197 if (failed(visitOp(&op, builder)))
198 return signalPassFailure();
199
200 LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
201 << funcOp << "\n");
202
203 // 2. propagate in original order
204 for (Operation &op : llvm::make_early_inc_range(block))
205 if (failed(visitOp(&op, builder)))
206 return signalPassFailure();
207 }
208};
209

source code of mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp