//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;
using namespace mlir::tensor;
using namespace mlir::mesh;

19namespace {
20
21// Sharding of tensor.empty/tensor.splat
22template <typename OpTy>
23struct CreatorOpShardingInterface
24 : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
25 OpTy> {
26 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
27 auto ndims = mlir::cast<ShapedType>(Val: op->getResult(idx: 0).getType()).getRank();
28 return SmallVector<utils::IteratorType>(ndims,
29 utils::IteratorType::parallel);
30 }
31
32 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
33 MLIRContext *ctx = op->getContext();
34 Value val = op->getResult(idx: 0);
35 auto type = dyn_cast<RankedTensorType>(Val: val.getType());
36 if (!type)
37 return {};
38 return SmallVector<AffineMap>(
39 op->getNumOperands() + op->getNumResults(),
40 {AffineMap::getMultiDimIdentityMap(numDims: type.getRank(), context: ctx)});
41 }
42
43 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
44 ArrayRef<MeshSharding> operandShardings,
45 ArrayRef<MeshSharding> resultShardings,
46 IRMapping &spmdizationMap,
47 SymbolTableCollection &symbolTable,
48 OpBuilder &builder) const {
49 assert(resultShardings.size() == 1);
50 auto resType = cast<RankedTensorType>(Val: op->getResult(idx: 0).getType());
51 mlir::mesh::MeshOp mesh;
52 ShapedType shardType;
53 if (resType.getRank() > 0) {
54 mesh = mesh::getMesh(op, meshSymbol: resultShardings[0].getMeshAttr(), symbolTableCollection&: symbolTable);
55 shardType =
56 cast<ShapedType>(Val: mesh::shardType(type: resType, mesh, sharding: resultShardings[0]));
57 } else {
58 shardType = resType;
59 }
60 Operation *newOp = nullptr;
61 // if the sharding introduces a new dynamic dimension, we take it from
62 // the dynamic sharding info. For now bail out if it's not
63 // provided.
64 if (!shardType.hasStaticShape()) {
65 assert(op->getResult(0).hasOneUse());
66 SmallVector<Value> newOperands;
67 auto oldType = cast<ShapedType>(Val&: resType);
68 assert(oldType.getRank() == shardType.getRank());
69 int currOldOprndNum = -1;
70 mesh::ShardShapeOp shapeForDevice;
71 ValueRange device;
72 Operation *newSharding = nullptr;
73 for (auto i = 0; i < oldType.getRank(); ++i) {
74 if (!oldType.isDynamicDim(idx: i) && shardType.isDynamicDim(idx: i)) {
75 if (!newSharding) {
76 newSharding =
77 builder.create<ShardingOp>(location: op->getLoc(), args: resultShardings[0]);
78 device =
79 builder.create<mesh::ProcessMultiIndexOp>(location: op->getLoc(), args&: mesh)
80 .getResults();
81 shapeForDevice = builder.create<mesh::ShardShapeOp>(
82 location: op->getLoc(), args: oldType.getShape(), args&: spmdizedOperands,
83 args: newSharding->getResult(idx: 0), args&: device);
84 }
85 newOperands.emplace_back(Args: shapeForDevice.getResult()[i]);
86 } else if (oldType.isDynamicDim(idx: i)) {
87 assert(shardType.isDynamicDim(i));
88 newOperands.emplace_back(Args: spmdizedOperands[++currOldOprndNum]);
89 }
90 }
91 newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
92 spmdizationMap.map(from: op->getResult(idx: 0), to: newOp->getResult(idx: 0));
93 } else {
94 // `clone` will populate the mapping of old to new results.
95 newOp = builder.clone(op&: *op, mapper&: spmdizationMap);
96 }
97 newOp->getResult(idx: 0).setType(shardType);
98
99 return success();
100 }
101};
102} // namespace
104void mlir::tensor::registerShardingInterfaceExternalModels(
105 DialectRegistry &registry) {
106
107 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TensorDialect *dialect) {
108 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
109 context&: *ctx);
110 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
111 context&: *ctx);
112 });
113}

// Source file: mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp