//===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
10#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
11
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/IR/DialectRegistry.h"
14#include "mlir/Support/LogicalResult.h"
15
16namespace mlir {
17namespace mesh {
18
19// Insert resharding spmdization of the value `sourceShardValue`
20// from sharding `source` to sharding `target`.
21// `sourceShardValue` is the already sharded value according to `source`.
22//
23// Example
24//
25// ```mlir
26// mesh.mesh @mesh_1d(shape = 2)
27// ...
28// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
29// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
30// ```
31//
32// Will result in
33//
34// ```mlir
35// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
36// tensor<1xi8> -> tensor<2xi8>
37// ```
38TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
39 ShardOp target,
40 TypedValue<ShapedType> sourceShardValue);
41TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
42 ShardOp target,
43 TypedValue<ShapedType> sourceShardValue,
44 SymbolTableCollection &symbolTableCollection);
45
46void reshardingRegisterDependentDialects(DialectRegistry &registry);
47
48} // namespace mesh
49} // namespace mlir
50
51#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
52

source code of mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h