1 | //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
10 | #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
11 | |
12 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
13 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
14 | #include "mlir/IR/AffineMap.h" |
15 | #include "mlir/IR/Value.h" |
16 | |
17 | namespace mlir { |
18 | |
19 | class Operation; |
20 | class IRMapping; |
21 | class SymbolTableCollection; |
22 | |
23 | namespace mesh { |
24 | |
// Retrieve the mesh axes corresponding to each operation loop iterator based
// on the provided shardings for the op's operands and results.
//
// The returned ShardingArray presumably has one entry per loop iterator
// (i.e. it is indexed like `loopIteratorTypes`), each entry listing the mesh
// axes that iterator is distributed over -- TODO confirm against the
// out-of-line definition.
//
// `indexingMaps` is presumably ordered as one map per operand followed by one
// map per result, matching the convention used by
// ElementwiseShardingInterface::getIndexingMaps in this header -- confirm.
//
// Assumes that the indexingMaps are projected permutations.
ShardingArray getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps);
33 | |
// Query over a loop-iterator-to-mesh-axes assignment.
//
// NOTE(review): judging by the name, this presumably returns true when at least
// one iterator whose type in `loopIteratorTypes` is a reduction has a non-empty
// set of mesh axes in `meshAxisAssignmentForLoopIterators`. The definition is
// out of line -- confirm there. The two arrays are presumably parallel (same
// length, indexed by loop iterator), with the assignment typically coming from
// getMeshAxisAssignmentForLoopIterators -- TODO confirm.
bool isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37 | |
// Get the set of mesh axes that correspond to reduction loop iterators.
//
// `loopIteratorTypes` and `meshAxisAssignmentForLoopIterators` are presumably
// parallel arrays indexed by loop iterator (the latter as produced by
// getMeshAxisAssignmentForLoopIterators) -- TODO confirm. Whether the returned
// axes are deduplicated or sorted is not visible from this declaration; check
// the out-of-line definition before relying on either.
SmallVector<MeshAxis> getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42 | |
// Inserts, via `builder`, a clone of `op` for an operation whose ranked tensor
// operands/results are all sharded, so the clone can operate directly on the
// already-spmdized (per-device) operands in `spmdizedOperands`.
//
// `operandShardings` / `resultShardings` give the sharding of each original
// operand / result. `spmdizationMap` presumably receives the mapping from the
// original results to the clone's results -- TODO confirm against the
// out-of-line definition. `symbolTable` is presumably used to resolve the mesh
// symbols referenced by the shardings -- confirm.
void spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder);
50 | |
51 | // All ranked tensor argument and result dimensions have |
52 | // independent parallel loop iterators. |
53 | template <typename Op> |
54 | struct IndependentParallelIteratorDomainShardingInterface |
55 | : public ShardingInterface::ExternalModel< |
56 | IndependentParallelIteratorDomainShardingInterface<Op>, Op> { |
57 | SmallVector<utils::IteratorType> |
58 | getLoopIteratorTypes(Operation *operation) const { |
59 | SmallVector<utils::IteratorType> iterTypes; |
60 | for (Type t : operation->getOperandTypes()) { |
61 | populateIteratorTypes(t, iterTypes); |
62 | } |
63 | for (Type t : operation->getResultTypes()) { |
64 | populateIteratorTypes(t, iterTypes); |
65 | } |
66 | return iterTypes; |
67 | } |
68 | |
69 | SmallVector<AffineMap> getIndexingMaps(Operation *op) const { |
70 | // TODO: implement. |
71 | return SmallVector<AffineMap>(); |
72 | } |
73 | |
74 | LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, |
75 | ArrayRef<MeshShardingAttr> operandShardings, |
76 | ArrayRef<MeshShardingAttr> resultShardings, |
77 | IRMapping &spmdizationMap, |
78 | SymbolTableCollection &symbolTable, |
79 | OpBuilder &builder) const { |
80 | spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, |
81 | resultShardings, spmdizationMap, |
82 | symbolTable, builder); |
83 | return success(); |
84 | } |
85 | |
86 | private: |
87 | void |
88 | populateIteratorTypes(Type t, |
89 | SmallVector<utils::IteratorType> &iterTypes) const { |
90 | RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t); |
91 | if (!rankedTensorType) { |
92 | return; |
93 | } |
94 | |
95 | iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank()); |
96 | for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) { |
97 | iterTypes.push_back(utils::IteratorType::parallel); |
98 | } |
99 | } |
100 | }; |
101 | |
102 | // Sharding of elementwise operations like tensor addition and multiplication. |
103 | template <typename ElemwiseOp> |
104 | struct ElementwiseShardingInterface |
105 | : public ShardingInterface::ExternalModel< |
106 | ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> { |
107 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
108 | Value val = op->getOperand(idx: 0); |
109 | auto type = dyn_cast<RankedTensorType>(val.getType()); |
110 | if (!type) |
111 | return {}; |
112 | SmallVector<utils::IteratorType> types(type.getRank(), |
113 | utils::IteratorType::parallel); |
114 | return types; |
115 | } |
116 | |
117 | SmallVector<AffineMap> getIndexingMaps(Operation *op) const { |
118 | MLIRContext *ctx = op->getContext(); |
119 | Value val = op->getOperand(idx: 0); |
120 | auto type = dyn_cast<RankedTensorType>(val.getType()); |
121 | if (!type) |
122 | return {}; |
123 | int64_t rank = type.getRank(); |
124 | int64_t num = op->getNumOperands() + op->getNumResults(); |
125 | SmallVector<AffineMap> maps(num, |
126 | AffineMap::getMultiDimIdentityMap(numDims: rank, context: ctx)); |
127 | return maps; |
128 | } |
129 | |
130 | LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, |
131 | ArrayRef<MeshShardingAttr> operandShardings, |
132 | ArrayRef<MeshShardingAttr> resultShardings, |
133 | IRMapping &spmdizationMap, |
134 | SymbolTableCollection &symbolTable, |
135 | OpBuilder &builder) const { |
136 | spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, |
137 | resultShardings, spmdizationMap, |
138 | symbolTable, builder); |
139 | return success(); |
140 | } |
141 | }; |
142 | |
143 | } // namespace mesh |
144 | } // namespace mlir |
145 | |
146 | #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_ |
147 | |