1 | //===- ShardingInterface.h --------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ |
10 | #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ |
11 | |
12 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
13 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
14 | #include "mlir/IR/Value.h" |
15 | #include "mlir/Support/LLVM.h" |
16 | |
17 | namespace mlir { |
18 | |
19 | class Operation; |
20 | class IRMapping; |
21 | class SymbolTableCollection; |
22 | |
23 | namespace mesh { |
24 | |
25 | using ShardingArray = SmallVector<SmallVector<MeshAxis>>; |
26 | using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>; |
27 | |
28 | struct ShardingOption { |
29 | // An array of int array. The sub-array at the i-th position signifies the |
30 | // mesh axes the i-th loop will be sharded on. |
31 | ShardingArray shardingArray = {}; |
32 | FlatSymbolRefAttr mesh = nullptr; |
33 | // `empty` being true indicates that no sharding information can be inferred |
34 | // at present. Note that it is different from the case where an operation is |
35 | // not sharded. |
36 | bool empty = false; |
37 | ShardingOption() = default; |
38 | ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh) |
39 | : shardingArray(std::move(shardingArray)), mesh(mesh) {} |
40 | }; |
41 | |
/// Retrieves the 'MeshShardingAttr' attribute associated with the given
/// operation result, together with its 'annotate_for_users' information.
/// NOTE(review): the `bool` in the returned pair presumably carries that
/// 'annotate_for_users' flag, and failure presumably means no sharding
/// attribute could be found for `result` — confirm against the definition.
FailureOr<std::pair<bool, MeshShardingAttr>>
getMeshShardingAttr(OpResult result);

/// Retrieves the 'MeshShardingAttr' attribute associated with the given
/// operation operand, together with its 'annotate_for_users' information.
/// NOTE(review): the `bool` in the returned pair presumably carries that
/// 'annotate_for_users' flag, and failure presumably means no sharding
/// attribute could be found for `opOperand` — confirm against the definition.
FailureOr<std::pair<bool, MeshShardingAttr>>
getMeshShardingAttr(OpOperand &opOperand);
51 | |
namespace detail {

/// Computes a `ShardingOption` for `op` from the sharding annotations of its
/// operands (`operandShardings`) and results (`resultShardings`).
/// NOTE(review): judging by the `default` prefix and the `detail` namespace,
/// this is presumably the default implementation backing the sharding
/// interface's get-sharding-option method — confirm in the ODS definition.
FailureOr<ShardingOption>
defaultGetShardingOption(Operation *op,
                         ArrayRef<MeshShardingAttr> operandShardings,
                         ArrayRef<MeshShardingAttr> resultShardings);

/// Adds sharding annotations to `op` according to `shardingOption`, creating
/// any new IR through builder `b`.
/// NOTE(review): presumably the default implementation of the interface's
/// add-sharding-annotations method — confirm in the ODS definition.
LogicalResult
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
                              const ShardingOption &shardingOption);

} // namespace detail
64 | |
/// Spmdizes `op` assuming full replication on all of its ranked tensor
/// arguments and results.
///
/// `spmdizedOperands` are the already-spmdized counterparts of `op`'s
/// operands; `operandShardings` / `resultShardings` are the sharding
/// annotations of `op`'s operands and results.
/// NOTE(review): `spmdizationMap` presumably receives the mapping from the
/// original values to their spmdized replacements, and `builder` is where the
/// new IR is created — confirm against the implementation.
void spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshShardingAttr> operandShardings,
    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder);
71 | |
72 | } // namespace mesh |
73 | } // namespace mlir |
74 | |
75 | /// Include the ODS generated interface header files. |
76 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc" |
77 | |
78 | #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_ |
79 | |