//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tosa-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tosa;
using namespace mlir::mesh;

namespace {

// Sharding model for tosa.matmul (operands: a, b, and the scalar zero
// points a_zp and b_zp).
// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
//   a:      (d0, d1, d2, d3) -> (d0, d1, d3)
//   b:      (d0, d1, d2, d3) -> (d0, d3, d2)
//   a_zp:   () -> ()
//   b_zp:   () -> ()
//   result: (d0, d1, d2, d3) -> (d0, d1, d2)
struct MatMulOpSharding
    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};

    SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
                                           utils::IteratorType::parallel);
    types[tensorType.getRank()] = utils::IteratorType::reduction;
    return types;
  }

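  // The single reduction loop (d3) combines partial products by summation.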
  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    return SmallVector<ReductionKind>(1, ReductionKind::Sum);
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};
    MLIRContext *ctx = op->getContext();
    SmallVector<AffineMap> maps;
    // Batched matrix operands a and b.
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
    // Empty maps for the scalar zero-point operands.
    maps.push_back(AffineMap::get(0, 0, {}, ctx));
    maps.push_back(AffineMap::get(0, 0, {}, ctx));
    // Result.
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
    return maps;
  }
};

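// Sharding model for tosa.negate: identity maps for the ranked input and
// result, empty maps for the scalar zero-point operands.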
struct NegateOpSharding
    : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    SmallVector<utils::IteratorType> types(type.getRank(),
                                           utils::IteratorType::parallel);
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    int64_t rank = type.getRank();
    SmallVector<AffineMap> maps = {
        // Identity map for the input tensor.
        AffineMap::getMultiDimIdentityMap(rank, ctx),
        // Empty maps for the scalar zero-point operands.
        AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
        // Identity map for the result.
        AffineMap::getMultiDimIdentityMap(rank, ctx)};
    return maps;
  }

  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
                                       resultShardings, spmdizationMap,
                                       symbolTable, builder);
    return success();
  }
};

template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
}

/// Variadic helper: registers the elementwise sharding model for every op
/// type in the pack.
template <typename... OpTypes>
static void registerElemwiseAll(MLIRContext *ctx) {
  (registerElemwiseOne<OpTypes>(ctx), ...);
}

} // namespace

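/// Registers the sharding external models above with `registry`. A minimal
/// usage sketch (hypothetical client code; it assumes the mesh dialect and
/// the sharding passes are set up separately):
///
///   DialectRegistry registry;
///   mlir::tosa::registerShardingInterfaceExternalModels(registry);
///   MLIRContext context(registry);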
void mlir::tosa::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {

  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
    registerElemwiseAll<
        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
        BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
        MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
        LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
        GreaterOp, GreaterEqualOp>(ctx);

    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
    NegateOp::attachInterface<NegateOpSharding>(*ctx);
  });
}