1//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16#include "mlir/Dialect/Mesh/IR/MeshOps.h"
17#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
19#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/IR/AffineExpr.h"
24#include "mlir/IR/DialectRegistry.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/OpDefinition.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/SymbolTable.h"
31#include "mlir/IR/Value.h"
32#include "mlir/Interfaces/TilingInterface.h"
33#include "mlir/Support/LogicalResult.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include <iterator>
39#include <optional>
40#include <utility>
41
42namespace mlir::linalg {
43
44using MeshAxis = mesh::MeshAxis;
45using ReductionKind = mesh::ReductionKind;
46using MeshShardingAttr = mesh::MeshShardingAttr;
47using ShardingArray = mesh::ShardingArray;
48using MeshOp = mesh::MeshOp;
49
50// Returns the corresponding mesh reduction kind for the given arith op.
51static ReductionKind getReductionKind(Operation *op) {
52 return llvm::TypeSwitch<Operation *, ReductionKind>(op)
53 // Floating-point operations.
54 .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56 // TODO: handle maxnumf and minnumf.
57 .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59 // Integer operations.
60 .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64 // TODO: handle signless, signed and unsigned types properly.
65 // It is assumed that the element type of the collective operands and
66 // result drive the meaning of the reduction kind, whether it is signed
67 // or unsigned.
68 // The reduction op inside the linalg op may have different result type
69 // from the element type of the linalg op's result.
70 // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71 // or signless operands.
72 // Maybe expand the reduction kinds.
73 .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78 .Default([](Operation *op) { return ReductionKind::Generic; });
79}
80
81static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82 SmallVector<Operation *> combinerOps;
83 Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84 if (!reducedValue || combinerOps.size() != 1) {
85 return std::nullopt;
86 }
87
88 return combinerOps[0];
89}
90
91static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
92 std::optional<Operation *> reductionOp = getCombinerOp(op);
93 if (!reductionOp) {
94 return ReductionKind::Generic;
95 }
96 [[maybe_unused]] Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98 // TODO: handle case when result type of the reduction op does not match the
99 // element type of the result tensor.
100 // Would it makes sense at all?
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
102 return getReductionKind(reductionOp.value());
103}
104
105static MeshOp getMesh(Operation *op,
106 ArrayRef<MeshShardingAttr> operandShardings,
107 ArrayRef<MeshShardingAttr> resultShardings,
108 SymbolTableCollection &symbolTable) {
109 for (MeshShardingAttr sharding : operandShardings) {
110 if (sharding) {
111 return mesh::getMesh(op, sharding.getMesh(), symbolTable);
112 }
113 }
114
115 for (MeshShardingAttr sharding : resultShardings) {
116 if (sharding) {
117 return mesh::getMesh(op, sharding.getMesh(), symbolTable);
118 }
119 }
120
121 assert(false);
122 return nullptr;
123}
124
125// Choose the operand based on the current process index along the reduction
126// mesh axes.
127// We need to use the initial value only once to avoid including it in the
128// reduction multiple times.
129// In each process group only the leading process with linear index 0 would use
130// the original operand.
131// The other processes would use the reduction operation neutral tensor.
132static Value createDestinationPassingStyleInitOperand(
133 LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
134 MeshOp meshOp, ImplicitLocOpBuilder &builder) {
135 Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136 mesh: meshOp.getSymName(), meshAxes: reductionMeshAxes, builder);
137 Value zero = builder.create<arith::ConstantIndexOp>(args: 0);
138 Value isLeadProcess = builder.create<arith::CmpIOp>(
139 builder.getI1Type(), arith::CmpIPredicate::eq,
140 processLinearIndexInReductionGroup, zero);
141 scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142 isLeadProcess, true, true);
143 // Then block.
144 {
145 OpBuilder::InsertionGuard insertionGuard(builder);
146 builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147 builder.create<scf::YieldOp>(spmdizedOperand);
148 }
149
150 // Else block.
151 {
152 OpBuilder::InsertionGuard insertionGuard(builder);
153 builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
154 SmallVector<OpFoldResult> shape =
155 tensor::getMixedSizes(builder, loc: builder.getLoc(), value: spmdizedOperand);
156 PartialReductionOpInterface partialReductionIface =
157 llvm::cast<PartialReductionOpInterface>(op.getOperation());
158 FailureOr<Operation *> reductionNeutralTensorOp =
159 partialReductionIface.generateInitialTensorForPartialReduction(
160 builder, builder.getLoc(), shape, {});
161 assert(succeeded(reductionNeutralTensorOp));
162 builder.create<scf::YieldOp>(
163 reductionNeutralTensorOp.value()->getResult(0));
164 }
165 return ifOp.getResult(0);
166}
167
168// Create the DPS init operands for the spmdized Linalg op.
169// Return all the new spmdized operands.
170static SmallVector<Value> createDestinationPassingStyleInitOperands(
171 LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
172 ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
173 ImplicitLocOpBuilder &builder) {
174 // TODO: add support for multiple destination passing style initial value
175 // operands.
176 // PartialReductionOpInterface::generateInitialTensorForPartialReduction
177 // needs to also support multiple DPS initial operands.
178 SmallVector<Value> newOperands = llvm::to_vector(Range&: spmdizedOperands);
179 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
180 Value spmdizedInitOperand =
181 spmdizationMap.lookup(op->getOperands()[operandIdx]);
182 newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
183 op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
184 return newOperands;
185}
186
187static void createAllReduceForResultWithoutPartialSharding(
188 Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
189 MeshShardingAttr resultSharding, ReductionKind reductionKind,
190 IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
191 SmallVector<MeshAxis> allReduceMeshAxes;
192 llvm::copy_if(Range&: opReductionMeshAxes, Out: std::back_inserter(x&: allReduceMeshAxes),
193 P: [&resultSharding](MeshAxis axis) {
194 return !llvm::is_contained(resultSharding.getPartialAxes(),
195 axis);
196 });
197 if (allReduceMeshAxes.empty()) {
198 return;
199 }
200
201 Value spmdizedLinalgOpResult = spmdizationMap.lookup(from: unshardedLinalgOpResult);
202 Value reducedValue = builder.create<mesh::AllReduceOp>(
203 spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
204 allReduceMeshAxes, reductionKind);
205 spmdizationMap.map(from: unshardedLinalgOpResult, to: reducedValue);
206}
207
208static void createAllReduceForResultsWithoutPartialShardings(
209 LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
210 ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
211 ImplicitLocOpBuilder &builder) {
212 ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
213 for (auto [unshardedLinalgOpResult, resultSharding] :
214 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
215 createAllReduceForResultWithoutPartialSharding(
216 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
217 reductionKind, spmdizationMap, builder);
218 }
219}
220
221static void spmdizeLinalgOpWithShardedReduction(
222 LinalgOp op, ArrayRef<Value> spmdizedOperands,
223 ArrayRef<MeshShardingAttr> operandShardings,
224 ArrayRef<MeshShardingAttr> resultShardings,
225 ArrayRef<utils::IteratorType> loopIteratorTypes,
226 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
227 IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
228 ImplicitLocOpBuilder &builder) {
229 MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
230 SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
231 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
232 SmallVector<Value> spmdizedLinalgOpOperands =
233 createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
234 reductionMeshAxes,
235 spmdizationMap, builder);
236 // We must not change the operand mappings of the original spmdizationMap as
237 // they are the mappings for the whole spmdization blob and may be used by
238 // others.
239 IRMapping internalSpmdizationMap;
240 for (auto [unshardedOperand, spmdizedOperand] :
241 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
242 internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
243 }
244 spmdizeTriviallyShardableOperation(
245 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
246 internalSpmdizationMap, symbolTable, builder);
247 for (Value result : op->getResults()) {
248 spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
249 }
250
251 // Handle partial shardings.
252 createAllReduceForResultsWithoutPartialShardings(
253 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
254}
255
256namespace {
257
258// ShardingInterface for ops that implement LinalgStructuredInterface.
259// The supported ops are only those where the indexing maps are projected
260// permutations.
261template <typename Op>
262struct StructuredOpShardingInterface
263 : public mesh::ShardingInterface::ExternalModel<
264 StructuredOpShardingInterface<Op>, Op> {
265 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
266 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
267 }
268
269 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
270 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
271 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
272
273 // Results must have the same indexing as destination passing style initial
274 // operands.
275 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
276 res.push_back(Elt: res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
277 }
278
279 return res;
280 }
281
282 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
283 ArrayRef<MeshShardingAttr> operandShardings,
284 ArrayRef<MeshShardingAttr> resultShardings,
285 IRMapping &spmdizationMap,
286 SymbolTableCollection &symbolTable,
287 OpBuilder &builder) const {
288 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
289
290 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
291 bool allIndexingMapsAreProjectedPermutation =
292 llvm::all_of(indexingMaps, [](AffineMap map) {
293 return map.isProjectedPermutation();
294 });
295 if (!allIndexingMapsAreProjectedPermutation) {
296 // TODO: handle non-projected permutations.
297 return op->emitOpError()
298 << "supports indexing maps that are only projected permutation.";
299 }
300
301 SmallVector<utils::IteratorType> loopIteratorTypes =
302 linalgOp.getIteratorTypesArray();
303 ShardingArray meshAxisAssignmentForLoopIterators =
304 getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
305 loopIteratorTypes, indexingMaps);
306 if (mesh::isAtLeastOneReductionIteratorSharded(
307 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
308 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
309 spmdizeLinalgOpWithShardedReduction(
310 linalgOp, spmdizedOperands, operandShardings, resultShardings,
311 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
312 symbolTable, implicitLocBuilder);
313 } else {
314 spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
315 operandShardings, resultShardings,
316 spmdizationMap, symbolTable, builder);
317 }
318
319 return success();
320 }
321};
322
323} // namespace
324
325template <typename OpType>
326static void registerOne(MLIRContext *ctx) {
327 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
328}
329
330/// Variadic helper function.
331template <typename... OpTypes>
332static void registerAll(MLIRContext *ctx) {
333 (registerOne<OpTypes>(ctx), ...);
334}
335
336void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
337 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
338 DialectRegistry registry;
339 registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
340 tensor::TensorDialect>();
341 ctx->appendDialectRegistry(registry);
342 for (StringRef name : registry.getDialectNames())
343 ctx->getOrLoadDialect(name);
344
345 registerOne<linalg::GenericOp>(ctx);
346 registerAll<
347#define GET_OP_LIST
348#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
349 >(ctx);
350 });
351}
352
353} // namespace mlir::linalg
354

source code of mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp