//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tensor-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tensor;
using namespace mlir::mesh;

namespace {

// Sharding of tensor.empty/tensor.splat
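//
// A sketch of the intended rewrite (exact IR syntax may differ between MLIR
// versions): given a 1D mesh of 4 devices and a result sharded on its second
// dimension,
//   %0 = tensor.empty(%d) : tensor<?x16xf32>
// is spmdized to the per-device shard
//   %0 = tensor.empty(%d) : tensor<?x4xf32>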
template <typename OpTy>
struct CreatorOpShardingInterface
    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
                                              OpTy> {
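  // Creator ops compute their result element-wise, so every result dimension
  // is a parallel loop.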
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
    return SmallVector<utils::IteratorType>(ndims,
                                            utils::IteratorType::parallel);
  }

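  // Assign the multi-dimensional identity map to every operand and result.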
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getResult(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    return SmallVector<AffineMap>(
        op->getNumOperands() + op->getNumResults(),
        {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
  }

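  // Rewrite the creator op into its per-device (SPMD) form: compute the
  // shard-local result type from the result sharding and recreate the op with
  // shard-local dimension sizes where needed.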
  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    assert(resultShardings.size() == 1);
    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
    mlir::mesh::MeshOp mesh;
    ShapedType shardType;
    if (resType.getRank() > 0) {
      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
      shardType =
          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
    } else {
      shardType = resType;
    }
    Operation *newOp = nullptr;
    // If the sharding introduces a new dynamic dimension, we take its size
    // from the dynamic sharding info. For now we assert if that info is not
    // provided.
    if (!shardType.hasStaticShape()) {
      assert(op->getResult(0).hasOneUse());
      SmallVector<Value> newOperands;
      auto oldType = cast<ShapedType>(resType);
      assert(oldType.getRank() == shardType.getRank());
      int currOldOprndNum = -1;
      mesh::ShardShapeOp shapeForDevice;
      ValueRange device;
      Operation *newSharding = nullptr;
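      // Walk the result dimensions: a dimension that was static but becomes
      // dynamic under sharding gets its per-device size from a single
      // mesh.shard_shape op (created lazily on first use); an already dynamic
      // dimension keeps its spmdized size operand.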
      for (auto i = 0; i < oldType.getRank(); ++i) {
        if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
          if (!newSharding) {
            newSharding =
                builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
            device =
                builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
                    .getResults();
            shapeForDevice = builder.create<mesh::ShardShapeOp>(
                op->getLoc(), oldType.getShape(), spmdizedOperands,
                newSharding->getResult(0), device);
          }
          newOperands.emplace_back(shapeForDevice.getResult()[i]);
        } else if (oldType.isDynamicDim(i)) {
          assert(shardType.isDynamicDim(i));
          newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
        }
      }
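      // Recreate the op with the shard-local type and adjusted size operands.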
      newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
      spmdizationMap.map(op->getResult(0), newOp->getResult(0));
    } else {
      // `clone` will populate the mapping of old to new results.
      newOp = builder.clone(*op, spmdizationMap);
    }
    newOp->getResult(0).setType(shardType);

    return success();
  }
};
} // namespace

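// Attach the creator-op sharding models to tensor.empty and tensor.splat.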
void mlir::tensor::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {

  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
        *ctx);
    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
        *ctx);
  });
}