1 | //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H |
10 | #define MLIR_DIALECT_MESH_IR_MESHOPS_H |
11 | |
12 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
13 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
14 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
15 | #include "mlir/IR/OpDefinition.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/IR/SymbolTable.h" |
18 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
19 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
20 | #include "mlir/Support/MathExtras.h" |
21 | |
22 | namespace mlir { |
23 | namespace mesh { |
24 | |
// Integer type identifying a single axis of a mesh. It is used as an offset
// into a mesh shape (see `collectiveProcessGroupSize`).
using MeshAxis = int16_t;
// Attribute type used to carry a list of mesh axes; its 16-bit element width
// matches `MeshAxis`.
using MeshAxesAttr = DenseI16ArrayAttr;
27 | |
28 | } // namespace mesh |
29 | } // namespace mlir |
30 | |
31 | #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc" |
32 | |
33 | #define GET_ATTRDEF_CLASSES |
34 | #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc" |
35 | |
36 | #define GET_OP_CLASSES |
37 | #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc" |
38 | |
39 | namespace mlir { |
40 | namespace mesh { |
41 | |
42 | inline bool isReductionLoop(utils::IteratorType iType) { |
43 | return iType == utils::IteratorType::reduction; |
44 | } |
45 | |
46 | template <typename T> |
47 | void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) { |
48 | while (!array.empty() && array.back().empty()) |
49 | array.pop_back(); |
50 | } |
51 | |
52 | // Is the same tensor replicated on all processes. |
53 | inline bool isFullReplication(MeshShardingAttr attr) { |
54 | return attr.getPartialAxes().empty() && attr.getSplitAxes().empty(); |
55 | } |
56 | |
57 | inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, |
58 | SymbolTableCollection &symbolTableCollection) { |
59 | return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( |
60 | op, meshSymbol); |
61 | } |
62 | |
63 | // Get the corresponding mesh op using the standard attribute nomenclature. |
64 | template <typename Op> |
65 | mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) { |
66 | return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection); |
67 | } |
68 | |
69 | template <> |
70 | inline mesh::MeshOp |
71 | getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) { |
72 | return getMesh(op.getOperation(), op.getShardAttr().getMesh(), |
73 | symbolTableCollection); |
74 | } |
75 | |
76 | // Get the number of processes that participate in each group |
77 | // induced by `meshAxes`. |
78 | template <typename MeshAxesRange, typename MeshShapeRange> |
79 | int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, |
80 | MeshShapeRange &&meshShape) { |
81 | int64_t res = 1; |
82 | |
83 | for (MeshAxis axis : meshAxes) { |
84 | auto axisSize = *(std::begin(meshShape) + axis); |
85 | if (ShapedType::isDynamic(axisSize)) { |
86 | return ShapedType::kDynamic; |
87 | } |
88 | res *= axisSize; |
89 | } |
90 | |
91 | return res; |
92 | } |
93 | |
94 | template <typename MeshAxesRange> |
95 | int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) { |
96 | return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes), |
97 | mesh.getShape()); |
98 | } |
99 | |
100 | // Get the size of a sharded dimension. |
101 | inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) { |
102 | if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount)) |
103 | return ShapedType::kDynamic; |
104 | |
105 | assert(dimSize % shardCount == 0); |
106 | return ceilDiv(lhs: dimSize, rhs: shardCount); |
107 | } |
108 | |
109 | // Get the size of an unsharded dimension. |
110 | inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) { |
111 | if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount)) |
112 | return ShapedType::kDynamic; |
113 | |
114 | return dimSize * shardCount; |
115 | } |
116 | |
// Return the sharded shape of `shape` according to sharding `sharding` on
// mesh `mesh`, i.e. the shape for the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
//
// Only declared in this header; the definition lives elsewhere.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshShardingAttr sharding);
124 | |
// If `type` is a ranked tensor type, return its sharded counterpart for
// `mesh` and `sharding`.
//
// If `type` is not a ranked tensor type, return `type` unchanged.
// `sharding` in that case must be null.
//
// Only declared in this header; the definition lives elsewhere.
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
130 | |
131 | } // namespace mesh |
132 | } // namespace mlir |
133 | |
134 | #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H |
135 | |