1//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16#include "mlir/Dialect/Mesh/IR/MeshOps.h"
17#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
19#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/IR/AffineExpr.h"
24#include "mlir/IR/DialectRegistry.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/OpDefinition.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/SymbolTable.h"
31#include "mlir/IR/Value.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include <iterator>
36#include <numeric>
37#include <optional>
38
39namespace mlir::linalg {
40
41using MeshAxis = mesh::MeshAxis;
42using ReductionKind = mesh::ReductionKind;
43using MeshSharding = mesh::MeshSharding;
44using ShardingArray = mesh::ShardingArray;
45using MeshOp = mesh::MeshOp;
46
47// Returns the corresponding mesh reduction kind for the given arith op.
48static ReductionKind getReductionKind(Operation *op) {
49 return llvm::TypeSwitch<Operation *, ReductionKind>(op)
50 // Floating-point operations.
51 .Case(caseFn: [](arith::AddFOp op) { return ReductionKind::Sum; })
52 .Case(caseFn: [](arith::MulFOp op) { return ReductionKind::Product; })
53 // TODO: handle maxnumf and minnumf.
54 .Case(caseFn: [](arith::MaximumFOp op) { return ReductionKind::Max; })
55 .Case(caseFn: [](arith::MinimumFOp op) { return ReductionKind::Min; })
56 // Integer operations.
57 .Case(caseFn: [](arith::AddIOp op) { return ReductionKind::Sum; })
58 .Case(caseFn: [](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
59 .Case(caseFn: [](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
60 .Case(caseFn: [](arith::AndIOp op) { return ReductionKind::Sum; })
61 // TODO: handle signless, signed and unsigned types properly.
62 // It is assumed that the element type of the collective operands and
63 // result drive the meaning of the reduction kind, whether it is signed
64 // or unsigned.
65 // The reduction op inside the linalg op may have different result type
66 // from the element type of the linalg op's result.
67 // Also signed and unsigned Arith dialect ops may accept signed, unsigned
68 // or signless operands.
69 // Maybe expand the reduction kinds.
70 .Case(caseFn: [](arith::MaxUIOp op) { return ReductionKind::Max; })
71 .Case(caseFn: [](arith::MinUIOp op) { return ReductionKind::Min; })
72 .Case(caseFn: [](arith::MaxSIOp op) { return ReductionKind::Max; })
73 .Case(caseFn: [](arith::MinSIOp op) { return ReductionKind::Min; })
74 .Case(caseFn: [](arith::MulIOp op) { return ReductionKind::Product; })
75 .Default(defaultFn: [](Operation *op) { return ReductionKind::Generic; });
76}
77
78static std::optional<Operation *> getCombinerOp(LinalgOp op) {
79 SmallVector<Operation *> combinerOps;
80 Value reducedValue = matchReduction(iterCarriedArgs: op.getRegionOutputArgs(), redPos: 0, combinerOps);
81 if (!reducedValue || combinerOps.size() != 1) {
82 return std::nullopt;
83 }
84
85 return combinerOps[0];
86}
87
88static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
89 std::optional<Operation *> reductionOp = getCombinerOp(op);
90 if (!reductionOp) {
91 return ReductionKind::Generic;
92 }
93 [[maybe_unused]] Type resultElementType =
94 llvm::cast<RankedTensorType>(Val: op->getResult(idx: 0).getType()).getElementType();
95 // TODO: handle case when result type of the reduction op does not match the
96 // element type of the result tensor.
97 // Would it makes sense at all?
98 assert(resultElementType == reductionOp.value()->getResult(0).getType());
99 return getReductionKind(op: reductionOp.value());
100}
101
102static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
103 ArrayRef<MeshSharding> resultShardings,
104 SymbolTableCollection &symbolTable) {
105 for (const MeshSharding &sharding : operandShardings) {
106 if (sharding) {
107 return mesh::getMesh(op, meshSymbol: sharding.getMeshAttr(), symbolTableCollection&: symbolTable);
108 }
109 }
110
111 for (const MeshSharding &sharding : resultShardings) {
112 if (sharding) {
113 return mesh::getMesh(op, meshSymbol: sharding.getMeshAttr(), symbolTableCollection&: symbolTable);
114 }
115 }
116
117 assert(false);
118 return nullptr;
119}
120
// Choose the operand based on the current process index along the reduction
// mesh axes.
// We need to use the initial value only once to avoid including it in the
// reduction multiple times.
// In each process group only the leading process with linear index 0 would use
// the original operand.
// The other processes would use the reduction operation neutral tensor.
static Value createDestinationPassingStyleInitOperand(
    LinalgOp op, int operandNumber, Value spmdizedOperand,
    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
    ImplicitLocOpBuilder &builder) {
  // Linear index of this process within its group along the reduction axes.
  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
      meshOp.getSymName(), reductionMeshAxes, builder);
  Value zero = builder.create<arith::ConstantIndexOp>(0);
  Value isLeadProcess = builder.create<arith::CmpIOp>(
      builder.getI1Type(), arith::CmpIPredicate::eq,
      processLinearIndexInReductionGroup, zero);
  // scf.if yielding a value of the operand's type; both regions are created.
  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
                                             isLeadProcess, true, true);
  // Then block: the lead process passes through the original operand.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
    builder.create<scf::YieldOp>(spmdizedOperand);
  }

  // Else block: non-lead processes yield a tensor filled with the combiner's
  // neutral element, so their contribution to the reduction is a no-op.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
    SmallVector<OpFoldResult> shape =
        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);

    SmallVector<Operation *> combinerOps;
    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
    assert(combinerOps.size() == 1);
    // NOTE(review): assumes the combiner has a known neutral element; if
    // getNeutralElement fails, .value() below asserts — confirm callers only
    // reach here for supported combiners.
    std::optional<TypedAttr> neutralEl =
        arith::getNeutralElement(combinerOps[0]);

    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
                                                 neutralEl.value().getType());
    Value constant =
        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
                     .getResult(0);

    builder.create<scf::YieldOp>(fill);
  }
  return ifOp.getResult(0);
}
171
// Create the DPS init operands for the spmdized Linalg op.
// Return all the new spmdized operands.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
    ImplicitLocOpBuilder &builder) {
  // TODO: add support for multiple destination passing style initial value
  // operands.
  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
  // Position of the single DPS init within the op's full operand list.
  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
  Value spmdizedInitOperand =
      spmdizationMap.lookup(op->getOperands()[operandIdx]);
  // Replace the init operand: the original value on the group-lead process,
  // the reduction-neutral tensor everywhere else.
  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
  return newOperands;
}
189
// Insert an all-reduce over the reduction mesh axes that are NOT marked as
// partial in the result sharding. Axes listed as partial are deliberately
// left unreduced so the result keeps its partial sharding.
static void createAllReduceForResultWithoutPartialSharding(
    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
    MeshSharding resultSharding, ReductionKind reductionKind,
    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
  // Keep only the reduction axes that are not partial in the result sharding.
  SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
                [&resultSharding](MeshAxis axis) {
                  return !llvm::is_contained(resultSharding.getPartialAxes(),
                                             axis);
                });
  // Nothing to reduce: every reduction axis remains partial.
  if (allReduceMeshAxes.empty()) {
    return;
  }

  Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
  Value reducedValue = builder.create<mesh::AllReduceOp>(
      spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
      reductionKind);
  // Re-map so later lookups of the unsharded result see the reduced value.
  spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
}
210
211static void createAllReduceForResultsWithoutPartialShardings(
212 LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
213 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
214 ImplicitLocOpBuilder &builder) {
215 ReductionKind reductionKind = getReductionKindOfLinalgOp(op: unshardedOp);
216 for (auto [unshardedLinalgOpResult, resultSharding] :
217 llvm::zip_equal(t: unshardedOp->getResults(), u&: resultShardings)) {
218 createAllReduceForResultWithoutPartialSharding(
219 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
220 reductionKind, spmdizationMap, builder);
221 }
222}
223
// Spmdize a linalg op that has at least one sharded reduction loop iterator.
// Three steps: (1) rewrite the DPS init operand so only the lead process of
// each reduction group contributes the original initial value, (2) spmdize
// the op itself trivially, (3) all-reduce the per-process partial results.
static void spmdizeLinalgOpWithShardedReduction(
    LinalgOp op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
    ImplicitLocOpBuilder &builder) {
  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
  // Mesh axes along which reduction loop iterators are sharded.
  SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
  SmallVector<Value> spmdizedLinalgOpOperands =
      createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
                                                reductionMeshAxes,
                                                spmdizationMap, builder);
  // We must not change the operand mappings of the original spmdizationMap as
  // they are the mappings for the whole spmdization blob and may be used by
  // others.
  IRMapping internalSpmdizationMap;
  for (auto [unshardedOperand, spmdizedOperand] :
       llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
    internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
  }
  spmdizeTriviallyShardableOperation(
      *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
      internalSpmdizationMap, symbolTable, builder);
  // Publish only the result mappings back into the shared map.
  for (Value result : op->getResults()) {
    spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
  }

  // Handle partial shardings.
  createAllReduceForResultsWithoutPartialShardings(
      op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
}
258
259namespace {
260
// ShardingInterface for ops that implement LinalgStructuredInterface.
// The supported ops are only those where the indexing maps are projected
// permutations.
template <typename Op>
struct StructuredOpShardingInterface
    : public mesh::ShardingInterface::ExternalModel<
          StructuredOpShardingInterface<Op>, Op> {
  // One iterator type (parallel, reduction, ...) per loop of the linalg op.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
  }

  // Indexing maps for all operands, with the DPS init maps appended once more
  // at the end to stand in for the op results.
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();

    // Results must have the same indexing as destination passing style initial
    // operands.
    for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
      res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
    }

    return res;
  }

  // Returns one reduction kind per reduction loop iterator. All reduction
  // iterators share the kind derived from the op's single combiner.
  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes =
        linalgOp.getIteratorTypesArray();
    // Count how many loop iterators are reductions.
    unsigned reductionItersCount = std::accumulate(
        iteratorTypes.begin(), iteratorTypes.end(), 0,
        [](unsigned count, utils::IteratorType iter) {
          return count + (iter == utils::IteratorType::reduction);
        });
    mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
  }

  // Rewrite the op into its per-process (spmdized) form, inserting collectives
  // when a reduction loop iterator is sharded across mesh axes.
  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);

    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    bool allIndexingMapsAreProjectedPermutation =
        llvm::all_of(indexingMaps, [](AffineMap map) {
          return map.isProjectedPermutation();
        });
    if (!allIndexingMapsAreProjectedPermutation) {
      // TODO: handle non-projected permutations.
      return op->emitOpError()
             << "supports indexing maps that are only projected permutation.";
    }

    SmallVector<utils::IteratorType> loopIteratorTypes =
        linalgOp.getIteratorTypesArray();
    ShardingArray meshAxisAssignmentForLoopIterators =
        getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
                                              loopIteratorTypes, indexingMaps);
    // A sharded reduction iterator requires the collective-based path;
    // otherwise the op is trivially shardable.
    if (mesh::isAtLeastOneReductionIteratorSharded(
            loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
      ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
      spmdizeLinalgOpWithShardedReduction(
          linalgOp, spmdizedOperands, operandShardings, resultShardings,
          loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
          symbolTable, implicitLocBuilder);
    } else {
      spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
                                         operandShardings, resultShardings,
                                         spmdizationMap, symbolTable, builder);
    }

    return success();
  }
};
339
340} // namespace
341
// Attach the StructuredOpShardingInterface external model to one op type.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
}
346
/// Variadic helper function: registers the sharding interface model for every
/// op type in the pack via a fold expression.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
352
// Register the sharding interface external models for linalg.generic and all
// named structured Linalg ops, deferred until the Linalg dialect is loaded.
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    // Deliberately shadows the outer registry: the dialects the spmdization
    // materializes into (affine, arith, scf, tensor) are appended and loaded
    // here, only once Linalg is actually in use.
    DialectRegistry registry;
    registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
                    tensor::TensorDialect>();
    ctx->appendDialectRegistry(registry);
    for (StringRef name : registry.getDialectNames())
      ctx->getOrLoadDialect(name);

    registerOne<linalg::GenericOp>(ctx);
    // GET_OP_LIST expands to the comma-separated list of all named structured
    // Linalg op classes.
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
369
370} // namespace mlir::linalg
371

source code of mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp