1//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
12#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
13#include "mlir/IR/DialectRegistry.h"
14#include "llvm/Support/Debug.h"
15
16using namespace mlir;
17using namespace mlir::arith;
18using namespace mlir::mesh;
19
20namespace {
21
22// Sharding of arith.constant
23// RankedTensor constants can be sharded like any other tensor.
24// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
25// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
26// Scalar constants are always replicated and need no sharding annotation.
27
28struct ConstantShardingInterface
29 : public ShardingInterface::ExternalModel<ConstantShardingInterface,
30 ConstantOp> {
31 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
32 auto ndims = 0;
33 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
34 ndims = type.getRank();
35 }
36 return SmallVector<utils::IteratorType>(ndims,
37 utils::IteratorType::parallel);
38 }
39
40 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
41 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
42 return SmallVector<AffineMap>(1, {AffineMap::getMultiDimIdentityMap(
43 numDims: type.getRank(), context: op->getContext())});
44 }
45 return {};
46 }
47
48 // Indicate failure if no result sharding exists.
49 // Otherwise mirror result sharding if it is a tensor constant.
50 // Otherwise return replication option.
51 FailureOr<ShardingOption>
52 getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
53 ArrayRef<MeshSharding> resultShardings) const {
54 assert(resultShardings.size() == 1 &&
55 "Expecting exactly one result sharding for arith.constant");
56 auto resultSharding = resultShardings[0];
57 if (!resultSharding) {
58 return failure();
59 }
60 if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
61 ShardingArray axesArray(resultSharding.getSplitAxes().size());
62 for (auto [i, axes] : llvm::enumerate(First: resultSharding.getSplitAxes())) {
63 axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
64 }
65 return ShardingOption(axesArray, resultSharding.getMeshAttr());
66 }
67 return ShardingOption({}, resultSharding.getMeshAttr());
68 }
69
70 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
71 ArrayRef<MeshSharding> operandShardings,
72 ArrayRef<MeshSharding> resultShardings,
73 IRMapping &spmdizationMap,
74 SymbolTableCollection &symbolTable,
75 OpBuilder &builder) const {
76 auto cOp = cast<ConstantOp>(op);
77 if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
78 if (!value.isSplat() || !resultShardings[0]) {
79 // Currently non-splat constants are not supported.
80 return failure();
81 }
82 auto sharding = resultShardings[0];
83 auto newType = cast<RankedTensorType>(shardType(
84 cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
85 sharding));
86 auto newValue = value.resizeSplat(newType);
87 auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
88 spmdizationMap.map(op->getResult(idx: 0), newOp.getResult());
89 spmdizationMap.map(op, newOp.getOperation());
90 } else {
91 // `clone` will populate the mapping of old to new results.
92 (void)builder.clone(op&: *op, mapper&: spmdizationMap);
93 }
94 return success();
95 }
96};
97} // namespace
98
99void mlir::arith::registerShardingInterfaceExternalModels(
100 DialectRegistry &registry) {
101
102 registry.addExtension(extensionFn: +[](MLIRContext *ctx, ArithDialect *dialect) {
103 ConstantOp::template attachInterface<ConstantShardingInterface>(*ctx);
104 });
105}
106

source code of mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp