//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::mesh;
18
19namespace {
20
21// Sharding of arith.constant
22// RankedTensor constants can be sharded like any other tensor.
23// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
24// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
25// Scalar constants are always replicated and need no sharding annotation.
26
27struct ConstantShardingInterface
28 : public ShardingInterface::ExternalModel<ConstantShardingInterface,
29 ConstantOp> {
30 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31 auto ndims = 0;
32 if (auto type = dyn_cast<RankedTensorType>(Val: op->getResult(idx: 0).getType())) {
33 ndims = type.getRank();
34 }
35 return SmallVector<utils::IteratorType>(ndims,
36 utils::IteratorType::parallel);
37 }
38
39 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
40 if (auto type = dyn_cast<RankedTensorType>(Val: op->getResult(idx: 0).getType())) {
41 return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
42 numDims: type.getRank(), context: op->getContext())});
43 }
44 return {};
45 }
46
47 // Indicate failure if no result sharding exists.
48 // Otherwise mirror result sharding if it is a tensor constant.
49 // Otherwise return replication option.
50 FailureOr<ShardingOption>
51 getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
52 ArrayRef<MeshSharding> resultShardings) const {
53 assert(resultShardings.size() == 1 &&
54 "Expecting exactly one result sharding for arith.constant");
55 auto resultSharding = resultShardings[0];
56 if (!resultSharding) {
57 return failure();
58 }
59 if (auto type = dyn_cast<RankedTensorType>(Val: op->getResult(idx: 0).getType())) {
60 ShardingArray axesArray(resultSharding.getSplitAxes().size());
61 for (auto [i, axes] : llvm::enumerate(First: resultSharding.getSplitAxes())) {
62 axesArray[i].append(in_start: axes.asArrayRef().begin(), in_end: axes.asArrayRef().end());
63 }
64 return ShardingOption(axesArray, resultSharding.getMeshAttr());
65 }
66 return ShardingOption({}, resultSharding.getMeshAttr());
67 }
68
69 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
70 ArrayRef<MeshSharding> operandShardings,
71 ArrayRef<MeshSharding> resultShardings,
72 IRMapping &spmdizationMap,
73 SymbolTableCollection &symbolTable,
74 OpBuilder &builder) const {
75 auto cOp = cast<ConstantOp>(Val: op);
76 if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(Val: cOp.getValue())) {
77 if (!value.isSplat() || !resultShardings[0]) {
78 // Currently non-splat constants are not supported.
79 return failure();
80 }
81 auto sharding = resultShardings[0];
82 auto newType = cast<RankedTensorType>(Val: shardType(
83 type: cOp.getType(), mesh: getMesh(op, meshSymbol: sharding.getMeshAttr(), symbolTableCollection&: symbolTable),
84 sharding));
85 auto newValue = value.resizeSplat(newType);
86 auto newOp = builder.create<ConstantOp>(location: op->getLoc(), args&: newType, args&: newValue);
87 spmdizationMap.map(from: op->getResult(idx: 0), to: newOp.getResult());
88 spmdizationMap.map(from: op, to: newOp.getOperation());
89 } else {
90 // `clone` will populate the mapping of old to new results.
91 (void)builder.clone(op&: *op, mapper&: spmdizationMap);
92 }
93 return success();
94 }
95};
96} // namespace
97
98void mlir::arith::registerShardingInterfaceExternalModels(
99 DialectRegistry &registry) {
100
101 registry.addExtension(extensionFn: +[](MLIRContext *ctx, ArithDialect *dialect) {
102 ConstantOp::template attachInterface<ConstantShardingInterface>(context&: *ctx);
103 });
104}
// Source: mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp