1 | //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
10 | #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
11 | |
12 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
13 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
14 | #include "mlir/IR/AffineMap.h" |
15 | #include "mlir/IR/Value.h" |
16 | |
17 | namespace mlir { |
18 | |
// Forward declarations of MLIR classes that the declarations below refer to
// by pointer or reference.
class IRMapping;
class Operation;
class SymbolTableCollection;
22 | |
23 | namespace mesh { |
24 | |
25 | // Retrieve the mesh axes corresponding to each operation loop iterator based |
26 | // on the provided shardings for the op's operands and results. |
27 | // Assumes that the indexingMaps are projected permutations. |
28 | ShardingArray getMeshAxisAssignmentForLoopIterators( |
29 | ArrayRef<MeshSharding> operandShardings, |
30 | ArrayRef<MeshSharding> resultShardings, |
31 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
32 | ArrayRef<AffineMap> indexingMaps); |
33 | |
34 | bool isAtLeastOneReductionIteratorSharded( |
35 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
36 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); |
37 | |
38 | // Get the set of mesh axes that correspond to reduction loop iterators. |
39 | SmallVector<MeshAxis> getReductionMeshAxes( |
40 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
41 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators); |
42 | |
43 | // Inserts a clone of the operation that has all ranked tensor |
44 | // arguments/results sharded. |
45 | void spmdizeTriviallyShardableOperation(Operation &op, |
46 | ArrayRef<Value> spmdizedOperands, |
47 | ArrayRef<MeshSharding> operandShardings, |
48 | ArrayRef<MeshSharding> resultShardings, |
49 | IRMapping &spmdizationMap, |
50 | SymbolTableCollection &symbolTable, |
51 | OpBuilder &builder); |
52 | |
53 | // All ranked tensor argument and result dimensions have |
54 | // independent parallel loop iterators. |
55 | template <typename Op> |
56 | struct IndependentParallelIteratorDomainShardingInterface |
57 | : public ShardingInterface::ExternalModel< |
58 | IndependentParallelIteratorDomainShardingInterface<Op>, Op> { |
59 | SmallVector<utils::IteratorType> |
60 | getLoopIteratorTypes(Operation *operation) const { |
61 | SmallVector<utils::IteratorType> iterTypes; |
62 | for (Type t : operation->getOperandTypes()) { |
63 | populateIteratorTypes(t, iterTypes); |
64 | } |
65 | for (Type t : operation->getResultTypes()) { |
66 | populateIteratorTypes(t, iterTypes); |
67 | } |
68 | return iterTypes; |
69 | } |
70 | |
71 | SmallVector<AffineMap> getIndexingMaps(Operation *op) const { |
72 | // TODO: implement. |
73 | return SmallVector<AffineMap>(); |
74 | } |
75 | |
76 | LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, |
77 | ArrayRef<MeshSharding> operandShardings, |
78 | ArrayRef<MeshSharding> resultShardings, |
79 | IRMapping &spmdizationMap, |
80 | SymbolTableCollection &symbolTable, |
81 | OpBuilder &builder) const { |
82 | spmdizeTriviallyShardableOperation(op&: *op, spmdizedOperands, operandShardings, |
83 | resultShardings, spmdizationMap, |
84 | symbolTable, builder); |
85 | return success(); |
86 | } |
87 | |
88 | private: |
89 | void |
90 | populateIteratorTypes(Type t, |
91 | SmallVector<utils::IteratorType> &iterTypes) const { |
92 | RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t); |
93 | if (!rankedTensorType) { |
94 | return; |
95 | } |
96 | |
97 | iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank()); |
98 | for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) { |
99 | iterTypes.push_back(utils::IteratorType::parallel); |
100 | } |
101 | } |
102 | }; |
103 | |
104 | // Sharding of elementwise operations like tensor addition and multiplication. |
105 | template <typename ElemwiseOp> |
106 | struct ElementwiseShardingInterface |
107 | : public ShardingInterface::ExternalModel< |
108 | ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> { |
109 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
110 | Value val = op->getOperand(idx: 0); |
111 | auto type = dyn_cast<RankedTensorType>(val.getType()); |
112 | if (!type) |
113 | return {}; |
114 | SmallVector<utils::IteratorType> types(type.getRank(), |
115 | utils::IteratorType::parallel); |
116 | return types; |
117 | } |
118 | |
119 | SmallVector<AffineMap> getIndexingMaps(Operation *op) const { |
120 | MLIRContext *ctx = op->getContext(); |
121 | Value val = op->getOperand(idx: 0); |
122 | auto type = dyn_cast<RankedTensorType>(val.getType()); |
123 | if (!type) |
124 | return {}; |
125 | int64_t rank = type.getRank(); |
126 | int64_t num = op->getNumOperands() + op->getNumResults(); |
127 | SmallVector<AffineMap> maps(num, |
128 | AffineMap::getMultiDimIdentityMap(numDims: rank, context: ctx)); |
129 | return maps; |
130 | } |
131 | |
132 | LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, |
133 | ArrayRef<MeshSharding> operandShardings, |
134 | ArrayRef<MeshSharding> resultShardings, |
135 | IRMapping &spmdizationMap, |
136 | SymbolTableCollection &symbolTable, |
137 | OpBuilder &builder) const { |
138 | spmdizeTriviallyShardableOperation(op&: *op, spmdizedOperands, operandShardings, |
139 | resultShardings, spmdizationMap, |
140 | symbolTable, builder); |
141 | return success(); |
142 | } |
143 | }; |
144 | |
145 | } // namespace mesh |
146 | } // namespace mlir |
147 | |
148 | #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
149 | |