1//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16#include "mlir/Dialect/Mesh/IR/MeshOps.h"
17#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
19#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/IR/AffineExpr.h"
24#include "mlir/IR/DialectRegistry.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/OpDefinition.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/SymbolTable.h"
31#include "mlir/IR/Value.h"
32#include "mlir/Interfaces/TilingInterface.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include <iterator>
38#include <numeric>
39#include <optional>
40#include <utility>
41
42namespace mlir::linalg {
43
44using MeshAxis = mesh::MeshAxis;
45using ReductionKind = mesh::ReductionKind;
46using MeshSharding = mesh::MeshSharding;
47using ShardingArray = mesh::ShardingArray;
48using MeshOp = mesh::MeshOp;
49
50// Returns the corresponding mesh reduction kind for the given arith op.
51static ReductionKind getReductionKind(Operation *op) {
52 return llvm::TypeSwitch<Operation *, ReductionKind>(op)
53 // Floating-point operations.
54 .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) { return ReductionKind::Product; })
56 // TODO: handle maxnumf and minnumf.
57 .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
59 // Integer operations.
60 .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) { return ReductionKind::Sum; })
64 // TODO: handle signless, signed and unsigned types properly.
65 // It is assumed that the element type of the collective operands and
66 // result drive the meaning of the reduction kind, whether it is signed
67 // or unsigned.
68 // The reduction op inside the linalg op may have different result type
69 // from the element type of the linalg op's result.
70 // Also signed and unsigned Arith dialect ops may accept signed, unsigned
71 // or signless operands.
72 // Maybe expand the reduction kinds.
73 .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) { return ReductionKind::Product; })
78 .Default([](Operation *op) { return ReductionKind::Generic; });
79}
80
81static std::optional<Operation *> getCombinerOp(LinalgOp op) {
82 SmallVector<Operation *> combinerOps;
83 Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
84 if (!reducedValue || combinerOps.size() != 1) {
85 return std::nullopt;
86 }
87
88 return combinerOps[0];
89}
90
91static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
92 std::optional<Operation *> reductionOp = getCombinerOp(op);
93 if (!reductionOp) {
94 return ReductionKind::Generic;
95 }
96 [[maybe_unused]] Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
98 // TODO: handle case when result type of the reduction op does not match the
99 // element type of the result tensor.
100 // Would it makes sense at all?
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
102 return getReductionKind(reductionOp.value());
103}
104
105static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106 ArrayRef<MeshSharding> resultShardings,
107 SymbolTableCollection &symbolTable) {
108 for (const MeshSharding &sharding : operandShardings) {
109 if (sharding) {
110 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111 }
112 }
113
114 for (const MeshSharding &sharding : resultShardings) {
115 if (sharding) {
116 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117 }
118 }
119
120 assert(false);
121 return nullptr;
122}
123
124// Choose the operand based on the current process index along the reduction
125// mesh axes.
126// We need to use the initial value only once to avoid including it in the
127// reduction multiple times.
128// In each process group only the leading process with linear index 0 would use
129// the original operand.
130// The other processes would use the reduction operation neutral tensor.
131static Value createDestinationPassingStyleInitOperand(
132 LinalgOp op, int operandNumber, Value spmdizedOperand,
133 ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134 ImplicitLocOpBuilder &builder) {
135 Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
136 mesh: meshOp.getSymName(), meshAxes: reductionMeshAxes, builder);
137 Value zero = builder.create<arith::ConstantIndexOp>(args: 0);
138 Value isLeadProcess = builder.create<arith::CmpIOp>(
139 builder.getI1Type(), arith::CmpIPredicate::eq,
140 processLinearIndexInReductionGroup, zero);
141 scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
142 isLeadProcess, true, true);
143 // Then block.
144 {
145 OpBuilder::InsertionGuard insertionGuard(builder);
146 builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
147 builder.create<scf::YieldOp>(spmdizedOperand);
148 }
149
150 // Else block.
151 {
152 OpBuilder::InsertionGuard insertionGuard(builder);
153 builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
154 SmallVector<OpFoldResult> shape =
155 tensor::getMixedSizes(builder, loc: builder.getLoc(), value: spmdizedOperand);
156
157 SmallVector<Operation *> combinerOps;
158 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159 assert(combinerOps.size() == 1);
160 std::optional<TypedAttr> neutralEl =
161 arith::getNeutralElement(combinerOps[0]);
162
163 Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164 neutralEl.value().getType());
165 Value constant =
166 builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167 Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168 .getResult(0);
169
170 builder.create<scf::YieldOp>(fill);
171 }
172 return ifOp.getResult(0);
173}
174
175// Create the DPS init operands for the spmdized Linalg op.
176// Return all the new spmdized operands.
177static SmallVector<Value> createDestinationPassingStyleInitOperands(
178 LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
179 ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
180 ImplicitLocOpBuilder &builder) {
181 // TODO: add support for multiple destination passing style initial value
182 // operands.
183 assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
184 SmallVector<Value> newOperands = llvm::to_vector(Range&: spmdizedOperands);
185 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
186 Value spmdizedInitOperand =
187 spmdizationMap.lookup(op->getOperands()[operandIdx]);
188 newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
189 op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
190 return newOperands;
191}
192
193static void createAllReduceForResultWithoutPartialSharding(
194 Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
195 MeshSharding resultSharding, ReductionKind reductionKind,
196 IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
197 SmallVector<MeshAxis> allReduceMeshAxes;
198 llvm::copy_if(Range&: opReductionMeshAxes, Out: std::back_inserter(x&: allReduceMeshAxes),
199 P: [&resultSharding](MeshAxis axis) {
200 return !llvm::is_contained(Range: resultSharding.getPartialAxes(),
201 Element: axis);
202 });
203 if (allReduceMeshAxes.empty()) {
204 return;
205 }
206
207 Value spmdizedLinalgOpResult = spmdizationMap.lookup(from: unshardedLinalgOpResult);
208 Value reducedValue = builder.create<mesh::AllReduceOp>(
209 spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
210 reductionKind);
211 spmdizationMap.map(from: unshardedLinalgOpResult, to: reducedValue);
212}
213
214static void createAllReduceForResultsWithoutPartialShardings(
215 LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
216 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
217 ImplicitLocOpBuilder &builder) {
218 ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
219 for (auto [unshardedLinalgOpResult, resultSharding] :
220 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
221 createAllReduceForResultWithoutPartialSharding(
222 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
223 reductionKind, spmdizationMap, builder);
224 }
225}
226
227static void spmdizeLinalgOpWithShardedReduction(
228 LinalgOp op, ArrayRef<Value> spmdizedOperands,
229 ArrayRef<MeshSharding> operandShardings,
230 ArrayRef<MeshSharding> resultShardings,
231 ArrayRef<utils::IteratorType> loopIteratorTypes,
232 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
233 IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
234 ImplicitLocOpBuilder &builder) {
235 MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
236 SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
237 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
238 SmallVector<Value> spmdizedLinalgOpOperands =
239 createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
240 reductionMeshAxes,
241 spmdizationMap, builder);
242 // We must not change the operand mappings of the original spmdizationMap as
243 // they are the mappings for the whole spmdization blob and may be used by
244 // others.
245 IRMapping internalSpmdizationMap;
246 for (auto [unshardedOperand, spmdizedOperand] :
247 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
248 internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
249 }
250 spmdizeTriviallyShardableOperation(
251 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
252 internalSpmdizationMap, symbolTable, builder);
253 for (Value result : op->getResults()) {
254 spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
255 }
256
257 // Handle partial shardings.
258 createAllReduceForResultsWithoutPartialShardings(
259 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
260}
261
262namespace {
263
264// ShardingInterface for ops that implement LinalgStructuredInterface.
265// The supported ops are only those where the indexing maps are projected
266// permutations.
267template <typename Op>
268struct StructuredOpShardingInterface
269 : public mesh::ShardingInterface::ExternalModel<
270 StructuredOpShardingInterface<Op>, Op> {
271 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
272 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
273 }
274
275 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
276 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
277 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
278
279 // Results must have the same indexing as destination passing style initial
280 // operands.
281 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
282 res.push_back(Elt: res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
283 }
284
285 return res;
286 }
287
288 SmallVector<ReductionKind>
289 getReductionLoopIteratorKinds(Operation *op) const {
290 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
291 SmallVector<utils::IteratorType> iteratorTypes =
292 linalgOp.getIteratorTypesArray();
293 unsigned reductionItersCount = std::accumulate(
294 iteratorTypes.begin(), iteratorTypes.end(), 0,
295 [](unsigned count, utils::IteratorType iter) {
296 return count + (iter == utils::IteratorType::reduction);
297 });
298 mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
299 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
300 }
301
302 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
303 ArrayRef<MeshSharding> operandShardings,
304 ArrayRef<MeshSharding> resultShardings,
305 IRMapping &spmdizationMap,
306 SymbolTableCollection &symbolTable,
307 OpBuilder &builder) const {
308 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
309
310 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
311 bool allIndexingMapsAreProjectedPermutation =
312 llvm::all_of(indexingMaps, [](AffineMap map) {
313 return map.isProjectedPermutation();
314 });
315 if (!allIndexingMapsAreProjectedPermutation) {
316 // TODO: handle non-projected permutations.
317 return op->emitOpError()
318 << "supports indexing maps that are only projected permutation.";
319 }
320
321 SmallVector<utils::IteratorType> loopIteratorTypes =
322 linalgOp.getIteratorTypesArray();
323 ShardingArray meshAxisAssignmentForLoopIterators =
324 getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
325 loopIteratorTypes, indexingMaps);
326 if (mesh::isAtLeastOneReductionIteratorSharded(
327 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
328 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
329 spmdizeLinalgOpWithShardedReduction(
330 linalgOp, spmdizedOperands, operandShardings, resultShardings,
331 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
332 symbolTable, implicitLocBuilder);
333 } else {
334 spmdizeTriviallyShardableOperation(op&: *op, spmdizedOperands,
335 operandShardings, resultShardings,
336 spmdizationMap, symbolTable, builder);
337 }
338
339 return success();
340 }
341};
342
343} // namespace
344
345template <typename OpType>
346static void registerOne(MLIRContext *ctx) {
347 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
348}
349
350/// Variadic helper function.
351template <typename... OpTypes>
352static void registerAll(MLIRContext *ctx) {
353 (registerOne<OpTypes>(ctx), ...);
354}
355
356void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
357 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
358 DialectRegistry registry;
359 registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
360 tensor::TensorDialect>();
361 ctx->appendDialectRegistry(registry);
362 for (StringRef name : registry.getDialectNames())
363 ctx->getOrLoadDialect(name);
364
365 registerOne<linalg::GenericOp>(ctx);
366 registerAll<
367#define GET_OP_LIST
368#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
369 >(ctx);
370 });
371}
372
373} // namespace mlir::linalg
374

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp