//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"

namespace mlir {

class Operation;
class IRMapping;
class SymbolTableCollection;

namespace mesh {

25// Retrieve the mesh axes corresponding to each operation loop iterator based
26// on the provided shardings for the op's operands and results.
27// Assumes that the indexingMaps are projected permutations.
28ShardingArray getMeshAxisAssignmentForLoopIterators(
29 ArrayRef<MeshShardingAttr> operandShardings,
30 ArrayRef<MeshShardingAttr> resultShardings,
31 ArrayRef<utils::IteratorType> loopIteratorTypes,
32 ArrayRef<AffineMap> indexingMaps);
33
34bool isAtLeastOneReductionIteratorSharded(
35 ArrayRef<utils::IteratorType> loopIteratorTypes,
36 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37
38// Get the set of mesh axes that correspond to reduction loop iterators.
39SmallVector<MeshAxis> getReductionMeshAxes(
40 ArrayRef<utils::IteratorType> loopIteratorTypes,
41 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42
43// Inserts a clone of the operation that has all ranked tensor
44// arguments/results sharded.
45void spmdizeTriviallyShardableOperation(
46 Operation &op, ArrayRef<Value> spmdizedOperands,
47 ArrayRef<MeshShardingAttr> operandShardings,
48 ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
49 SymbolTableCollection &symbolTable, OpBuilder &builder);
50
51// All ranked tensor argument and result dimensions have
52// independent parallel loop iterators.
53template <typename Op>
54struct IndependentParallelIteratorDomainShardingInterface
55 : public ShardingInterface::ExternalModel<
56 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
57 SmallVector<utils::IteratorType>
58 getLoopIteratorTypes(Operation *operation) const {
59 SmallVector<utils::IteratorType> iterTypes;
60 for (Type t : operation->getOperandTypes()) {
61 populateIteratorTypes(t, iterTypes);
62 }
63 for (Type t : operation->getResultTypes()) {
64 populateIteratorTypes(t, iterTypes);
65 }
66 return iterTypes;
67 }
68
69 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
70 // TODO: implement.
71 return SmallVector<AffineMap>();
72 }
73
74 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
75 ArrayRef<MeshShardingAttr> operandShardings,
76 ArrayRef<MeshShardingAttr> resultShardings,
77 IRMapping &spmdizationMap,
78 SymbolTableCollection &symbolTable,
79 OpBuilder &builder) const {
80 spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
81 resultShardings, spmdizationMap,
82 symbolTable, builder);
83 return success();
84 }
85
86private:
87 void
88 populateIteratorTypes(Type t,
89 SmallVector<utils::IteratorType> &iterTypes) const {
90 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
91 if (!rankedTensorType) {
92 return;
93 }
94
95 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
96 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
97 iterTypes.push_back(utils::IteratorType::parallel);
98 }
99 }
100};
101
102// Sharding of elementwise operations like tensor addition and multiplication.
103template <typename ElemwiseOp>
104struct ElementwiseShardingInterface
105 : public ShardingInterface::ExternalModel<
106 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
107 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
108 Value val = op->getOperand(idx: 0);
109 auto type = dyn_cast<RankedTensorType>(val.getType());
110 if (!type)
111 return {};
112 SmallVector<utils::IteratorType> types(type.getRank(),
113 utils::IteratorType::parallel);
114 return types;
115 }
116
117 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
118 MLIRContext *ctx = op->getContext();
119 Value val = op->getOperand(idx: 0);
120 auto type = dyn_cast<RankedTensorType>(val.getType());
121 if (!type)
122 return {};
123 int64_t rank = type.getRank();
124 int64_t num = op->getNumOperands() + op->getNumResults();
125 SmallVector<AffineMap> maps(num,
126 AffineMap::getMultiDimIdentityMap(numDims: rank, context: ctx));
127 return maps;
128 }
129
130 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
131 ArrayRef<MeshShardingAttr> operandShardings,
132 ArrayRef<MeshShardingAttr> resultShardings,
133 IRMapping &spmdizationMap,
134 SymbolTableCollection &symbolTable,
135 OpBuilder &builder) const {
136 spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
137 resultShardings, spmdizationMap,
138 symbolTable, builder);
139 return success();
140 }
141};
142
} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_