//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tosa-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tosa;
using namespace mlir::mesh;

namespace {

// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
//   (d0, d1, d2, d3) -> (d0, d1, d3)
//   (d0, d1, d2, d3) -> (d0, d3, d2)
//   (d0, d1, d2, d3) -> (d0, d1, d2)
struct MatMulOpSharding
    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};

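    // Every result dimension (batch, M, N) iterates in parallel; one extra
    // trailing loop is appended for the contracted K dimension.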
    SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
                                           utils::IteratorType::parallel);
    types[tensorType.getRank()] = utils::IteratorType::reduction;
    return types;
  }

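  // Shards computing partial products along the reduction (K) axis are
  // combined by summation.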
  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    return SmallVector<ReductionKind>(1, ReductionKind::Sum);
  }

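  // With the loop space (d0, d1, d2, d3) = (batch, M, N, K): the first map
  // indexes the LHS (batch, M, K), the second the RHS (batch, K, N), and the
  // third the result (batch, M, N).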
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};
    MLIRContext *ctx = op->getContext();
    SmallVector<AffineMap> maps;
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
    return maps;
  }
};

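/// Attaches the elementwise sharding external model to a single op type.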
template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
}

/// Variadic helper: registers the elementwise model for every op type in the
/// pack via a fold expression.
template <typename... OpTypes>
static void registerElemwiseAll(MLIRContext *ctx) {
  (registerElemwiseOne<OpTypes>(ctx), ...);
}

} // namespace

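/// Registers the TOSA sharding external models. A minimal usage sketch
/// (assuming the usual DialectRegistry flow; which dialects you insert
/// depends on your pipeline):
///   DialectRegistry registry;
///   registry.insert<tosa::TosaDialect, mesh::MeshDialect>();
///   tosa::registerShardingInterfaceExternalModels(registry);
///   MLIRContext context(registry);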
void mlir::tosa::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {

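  // The extension body runs only when TosaDialect is actually loaded into a
  // context, so the external models attach lazily.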
  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
    registerElemwiseAll<
        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
        BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
        MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
        GreaterOp, GreaterEqualOp>(ctx);

    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
  });
}