//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"

namespace mlir {

class Operation;
class IRMapping;
class SymbolTableCollection;

namespace mesh {

// Retrieve the mesh axes corresponding to each operation loop iterator based
// on the provided shardings for the op's operands and results.
// Assumes that the indexingMaps are projected permutations.
ShardingArray getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps);
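
// Example (hypothetical, for illustration only): for a matmul-like op with
// indexing maps
//   (d0, d1, d2) -> (d0, d2)
//   (d0, d1, d2) -> (d2, d1)
//   (d0, d1, d2) -> (d0, d1)
// where the first operand is sharded on mesh axis 0 along its first tensor
// dimension, the parallel iterator d0 is assigned mesh axis 0 and the
// remaining iterators get no axes.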

// Returns true if any reduction loop iterator has at least one mesh axis
// assigned to it, i.e. the reduction is performed across multiple mesh
// processes.
bool isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);

// Get the set of mesh axes that correspond to reduction loop iterators.
SmallVector<MeshAxis> getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
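
// Continuing the hypothetical matmul example above: if the reduction iterator
// d2 is assigned mesh axis 1 (e.g. both operands are split along their shared
// dimension), isAtLeastOneReductionIteratorSharded returns true and
// getReductionMeshAxes returns {1}.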

// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);

// All ranked tensor argument and result dimensions have
// independent parallel loop iterators.
template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
    : public ShardingInterface::ExternalModel<
          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
  SmallVector<utils::IteratorType>
  getLoopIteratorTypes(Operation *operation) const {
    SmallVector<utils::IteratorType> iterTypes;
    for (Type t : operation->getOperandTypes()) {
      populateIteratorTypes(t, iterTypes);
    }
    for (Type t : operation->getResultTypes()) {
      populateIteratorTypes(t, iterTypes);
    }
    return iterTypes;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    // TODO: implement.
    return SmallVector<AffineMap>();
  }

  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
                                       resultShardings, spmdizationMap,
                                       symbolTable, builder);
    return success();
  }

private:
  void
  populateIteratorTypes(Type t,
                        SmallVector<utils::IteratorType> &iterTypes) const {
    RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
    if (!rankedTensorType) {
      return;
    }

    iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
    for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
      iterTypes.push_back(utils::IteratorType::parallel);
    }
  }
};
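
// Example (hypothetical op FooOp, for illustration only): the model is meant
// to be attached as an external interface implementation, e.g.:
//
//   FooOp::attachInterface<
//       IndependentParallelIteratorDomainShardingInterface<FooOp>>(*ctx);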

// Sharding of elementwise operations like tensor addition and multiplication.
template <typename ElemwiseOp>
struct ElementwiseShardingInterface
    : public ShardingInterface::ExternalModel<
          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    SmallVector<utils::IteratorType> types(type.getRank(),
                                           utils::IteratorType::parallel);
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    int64_t rank = type.getRank();
    int64_t num = op->getNumOperands() + op->getNumResults();
    SmallVector<AffineMap> maps(num,
                                AffineMap::getMultiDimIdentityMap(rank, ctx));
    return maps;
  }

  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
                                       resultShardings, spmdizationMap,
                                       symbolTable, builder);
    return success();
  }
};
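
// For illustration: for a hypothetical binary elementwise op on rank-2
// tensors (two operands, one result), getIndexingMaps returns three identity
// maps (d0, d1) -> (d0, d1) and getLoopIteratorTypes returns two parallel
// iterators, so operand shardings carry over dimension-wise to the result.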

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_