1//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10#define MLIR_DIALECT_MESH_IR_MESHOPS_H
11
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14#include "mlir/IR/BuiltinTypeInterfaces.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/IR/SymbolTable.h"
18#include "mlir/Interfaces/InferTypeOpInterface.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/MathExtras.h"
21
22namespace mlir {
23namespace mesh {
24
25using MeshAxis = int16_t;
26using MeshAxesAttr = DenseI16ArrayAttr;
27
28} // namespace mesh
29} // namespace mlir
30
31#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
32
33#define GET_ATTRDEF_CLASSES
34#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
35
36#define GET_OP_CLASSES
37#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
38
39namespace mlir {
40namespace mesh {
41
42inline bool isReductionLoop(utils::IteratorType iType) {
43 return iType == utils::IteratorType::reduction;
44}
45
46template <typename T>
47void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
48 while (!array.empty() && array.back().empty())
49 array.pop_back();
50}
51
52// Is the same tensor replicated on all processes.
53inline bool isFullReplication(MeshShardingAttr attr) {
54 return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
55}
56
57inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
58 SymbolTableCollection &symbolTableCollection) {
59 return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
60 op, meshSymbol);
61}
62
63// Get the corresponding mesh op using the standard attribute nomenclature.
64template <typename Op>
65mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
66 return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
67}
68
69template <>
70inline mesh::MeshOp
71getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
72 return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
73 symbolTableCollection);
74}
75
76// Get the number of processes that participate in each group
77// induced by `meshAxes`.
78template <typename MeshAxesRange, typename MeshShapeRange>
79int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
80 MeshShapeRange &&meshShape) {
81 int64_t res = 1;
82
83 for (MeshAxis axis : meshAxes) {
84 auto axisSize = *(std::begin(meshShape) + axis);
85 if (ShapedType::isDynamic(axisSize)) {
86 return ShapedType::kDynamic;
87 }
88 res *= axisSize;
89 }
90
91 return res;
92}
93
94template <typename MeshAxesRange>
95int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
96 return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
97 mesh.getShape());
98}
99
100// Get the size of a sharded dimension.
101inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
102 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
103 return ShapedType::kDynamic;
104
105 assert(dimSize % shardCount == 0);
106 return ceilDiv(lhs: dimSize, rhs: shardCount);
107}
108
109// Get the size of an unsharded dimension.
110inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
111 if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
112 return ShapedType::kDynamic;
113
114 return dimSize * shardCount;
115}
116
// Return the shape `shape` sharded according to sharding `sharding`,
// i.e. the shape of the tensor on each device in the mesh.
119// Example:
120// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
121// result in a shape for each shard of ?x2x?.
122ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
123 MeshShardingAttr sharding);
124
// If `type` is a ranked tensor type, return its sharded counterpart.
//
// If `type` is not a ranked tensor type, return `type`.
// `sharding` must be null in that case.
129Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
130
131} // namespace mesh
132} // namespace mlir
133
134#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
135

// Source code of mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h