1//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16
17namespace mlir {
18
19class Operation;
20class IRMapping;
21class SymbolTableCollection;
22
23namespace mesh {
24
25using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
26using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
27
28struct ShardingOption {
29 // An array of int array. The sub-array at the i-th position signifies the
30 // mesh axes the i-th loop will be sharded on.
31 ShardingArray shardingArray = {};
32 FlatSymbolRefAttr mesh = nullptr;
33 // `empty` being true indicates that no sharding information can be inferred
34 // at present. Note that it is different from the case where an operation is
35 // not sharded.
36 bool empty = false;
37 ShardingOption() = default;
38 ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
39 : shardingArray(std::move(shardingArray)), mesh(mesh) {}
40};
41
42// This method retrieves the 'MeshShardingAttr' attribute from a given operation
43// result and includes the 'annotate_for_users' information.
44FailureOr<std::pair<bool, MeshShardingAttr>>
45getMeshShardingAttr(OpResult result);
46
47// This method retrieves the 'MeshShardingAttr' attribute from a given operation
48// operand and includes the 'annotate_for_users' information.
49FailureOr<std::pair<bool, MeshShardingAttr>>
50getMeshShardingAttr(OpOperand &opOperand);
51
52namespace detail {
53
54FailureOr<ShardingOption>
55defaultGetShardingOption(Operation *op,
56 ArrayRef<MeshShardingAttr> operandShardings,
57 ArrayRef<MeshShardingAttr> resultShardings);
58
59LogicalResult
60defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
61 const ShardingOption &shardingOption);
62
63} // namespace detail
64
65// Assumes full replication on all ranked tensor arguments and results.
66void spmdizeFullyReplicatedOperation(
67 Operation &op, ArrayRef<Value> spmdizedOperands,
68 ArrayRef<MeshShardingAttr> operandShardings,
69 ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
70 SymbolTableCollection &symbolTable, OpBuilder &builder);
71
72} // namespace mesh
73} // namespace mlir
74
75/// Include the ODS generated interface header files.
76#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
77
78#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
79

source code of mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h