//===- Spmdization.h - Mesh SPMDization -------------------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |
10 | #define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |
11 | |
12 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
13 | #include "mlir/IR/DialectRegistry.h" |
14 | #include "mlir/Support/LogicalResult.h" |
15 | |
16 | namespace mlir { |
17 | namespace mesh { |
18 | |
19 | // Insert resharding spmdization of the value `sourceShardValue` |
20 | // from sharding `source` to sharding `target`. |
21 | // `sourceShardValue` is the already sharded value according to `source`. |
22 | // |
23 | // Example |
24 | // |
25 | // ```mlir |
26 | // mesh.mesh @mesh_1d(shape = 2) |
27 | // ... |
28 | // %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8> |
29 | // %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> |
30 | // ``` |
31 | // |
32 | // Will result in |
33 | // |
34 | // ```mlir |
35 | // %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : |
36 | // tensor<1xi8> -> tensor<2xi8> |
37 | // ``` |
38 | TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, |
39 | ShardOp target, |
40 | TypedValue<ShapedType> sourceShardValue); |
41 | TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, |
42 | ShardOp target, |
43 | TypedValue<ShapedType> sourceShardValue, |
44 | SymbolTableCollection &symbolTableCollection); |
45 | |
46 | void reshardingRegisterDependentDialects(DialectRegistry ®istry); |
47 | |
48 | } // namespace mesh |
49 | } // namespace mlir |
50 | |
51 | #endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |
52 | |