1//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
10#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
11#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/IR/DialectRegistry.h"
14#include "llvm/Support/Debug.h"
15
16#define DEBUG_TYPE "tensor-sharding-impl"
17#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
18
19using namespace mlir;
20using namespace mlir::tensor;
21using namespace mlir::mesh;
22
23namespace {
24
25// Sharding of tensor.empty/tensor.splat
26template <typename OpTy>
27struct CreatorOpShardingInterface
28 : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
29 OpTy> {
30 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31 auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
32 return SmallVector<utils::IteratorType>(ndims,
33 utils::IteratorType::parallel);
34 }
35
36 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
37 MLIRContext *ctx = op->getContext();
38 Value val = op->getResult(idx: 0);
39 auto type = dyn_cast<RankedTensorType>(val.getType());
40 if (!type)
41 return {};
42 return SmallVector<AffineMap>(
43 op->getNumOperands() + op->getNumResults(),
44 {AffineMap::getMultiDimIdentityMap(numDims: type.getRank(), context: ctx)});
45 }
46
47 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
48 ArrayRef<MeshSharding> operandShardings,
49 ArrayRef<MeshSharding> resultShardings,
50 IRMapping &spmdizationMap,
51 SymbolTableCollection &symbolTable,
52 OpBuilder &builder) const {
53 assert(resultShardings.size() == 1);
54 auto resType = cast<RankedTensorType>(op->getResult(idx: 0).getType());
55 mlir::mesh::MeshOp mesh;
56 ShapedType shardType;
57 if (resType.getRank() > 0) {
58 mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
59 shardType =
60 cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
61 } else {
62 shardType = resType;
63 }
64 Operation *newOp = nullptr;
65 // if the sharding introduces a new dynamic dimension, we take it from
66 // the dynamic sharding info. For now bail out if it's not
67 // provided.
68 if (!shardType.hasStaticShape()) {
69 assert(op->getResult(0).hasOneUse());
70 SmallVector<Value> newOperands;
71 auto oldType = cast<ShapedType>(resType);
72 assert(oldType.getRank() == shardType.getRank());
73 int currOldOprndNum = -1;
74 mesh::ShardShapeOp shapeForDevice;
75 ValueRange device;
76 Operation *newSharding = nullptr;
77 for (auto i = 0; i < oldType.getRank(); ++i) {
78 if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
79 if (!newSharding) {
80 newSharding =
81 builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
82 device =
83 builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
84 .getResults();
85 shapeForDevice = builder.create<mesh::ShardShapeOp>(
86 op->getLoc(), oldType.getShape(), spmdizedOperands,
87 newSharding->getResult(0), device);
88 }
89 newOperands.emplace_back(shapeForDevice.getResult()[i]);
90 } else if (oldType.isDynamicDim(i)) {
91 assert(shardType.isDynamicDim(i));
92 newOperands.emplace_back(Args: spmdizedOperands[++currOldOprndNum]);
93 }
94 }
95 newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
96 spmdizationMap.map(from: op->getResult(idx: 0), to: newOp->getResult(idx: 0));
97 } else {
98 // `clone` will populate the mapping of old to new results.
99 newOp = builder.clone(op&: *op, mapper&: spmdizationMap);
100 }
101 newOp->getResult(idx: 0).setType(shardType);
102
103 return success();
104 }
105};
106} // namespace
107
108void mlir::tensor::registerShardingInterfaceExternalModels(
109 DialectRegistry &registry) {
110
111 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TensorDialect *dialect) {
112 EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
113 *ctx);
114 SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
115 *ctx);
116 });
117}
118

source code of mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp