//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
10#include "mlir/Dialect/Mesh/IR/MeshOps.h"
11#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
12#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
13#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/DialectRegistry.h"
16
17#define DEBUG_TYPE "tosa-sharding-impl"
18#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19
20using namespace mlir;
21using namespace mlir::tosa;
22using namespace mlir::mesh;
23
24namespace {
25
26// loop types: [parallel, parallel, parallel, reduction_sum]
27// indexing maps:
28// (d0, d1, d2, d3) -> (d0, d1, d3)
29// (d0, d1, d2, d3) -> (d0, d3, d2)
30// (d0, d1, d2, d3) -> (d0, d1, d2)
31struct MatMulOpSharding
32 : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
33 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
34 auto tensorType = dyn_cast<RankedTensorType>(Val: op->getResult(idx: 0).getType());
35 if (!tensorType)
36 return {};
37
38 SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
39 utils::IteratorType::parallel);
40 types[tensorType.getRank()] = utils::IteratorType::reduction;
41 return types;
42 }
43
44 SmallVector<ReductionKind>
45 getReductionLoopIteratorKinds(Operation *op) const {
46 return SmallVector<ReductionKind>(1, ReductionKind::Sum);
47 }
48
49 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
50 auto tensorType = dyn_cast<RankedTensorType>(Val: op->getResult(idx: 0).getType());
51 if (!tensorType)
52 return {};
53 MLIRContext *ctx = op->getContext();
54 SmallVector<AffineMap> maps;
55 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 1, 3}, context: ctx));
56 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 3, 2}, context: ctx));
57 maps.push_back(Elt: AffineMap::get(dimCount: 0, symbolCount: 0, results: {}, context: ctx));
58 maps.push_back(Elt: AffineMap::get(dimCount: 0, symbolCount: 0, results: {}, context: ctx));
59 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 1, 2}, context: ctx));
60 return maps;
61 }
62};
63
64struct NegateOpSharding
65 : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
66 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
67 Value val = op->getOperand(idx: 0);
68 auto type = dyn_cast<RankedTensorType>(Val: val.getType());
69 if (!type)
70 return {};
71 SmallVector<utils::IteratorType> types(type.getRank(),
72 utils::IteratorType::parallel);
73 return types;
74 }
75
76 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
77 MLIRContext *ctx = op->getContext();
78 Value val = op->getOperand(idx: 0);
79 auto type = dyn_cast<RankedTensorType>(Val: val.getType());
80 if (!type)
81 return {};
82 int64_t rank = type.getRank();
83 SmallVector<AffineMap> maps = {
84 AffineMap::getMultiDimIdentityMap(numDims: rank, context: ctx),
85 AffineMap::get(dimCount: 0, symbolCount: 0, results: {}, context: ctx), AffineMap::get(dimCount: 0, symbolCount: 0, results: {}, context: ctx),
86 AffineMap::getMultiDimIdentityMap(numDims: rank, context: ctx)};
87 return maps;
88 }
89
90 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
91 ArrayRef<MeshSharding> operandShardings,
92 ArrayRef<MeshSharding> resultShardings,
93 IRMapping &spmdizationMap,
94 SymbolTableCollection &symbolTable,
95 OpBuilder &builder) const {
96 spmdizeTriviallyShardableOperation(op&: *op, spmdizedOperands, operandShardings,
97 resultShardings, spmdizationMap,
98 symbolTable, builder);
99 return success();
100 }
101};
102
103template <typename OpType>
104static void registerElemwiseOne(MLIRContext *ctx) {
105 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
106}
107
108/// Variadic helper function.
109template <typename... OpTypes>
110static void registerElemwiseAll(MLIRContext *ctx) {
111 (registerElemwiseOne<OpTypes>(ctx), ...);
112}
113
114} // namespace
115
116void mlir::tosa::registerShardingInterfaceExternalModels(
117 DialectRegistry &registry) {
118
119 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TosaDialect *dialect) {
120 registerElemwiseAll<
121 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
122 BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
123 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
124 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
125 LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
126 GreaterOp, GreaterEqualOp>(ctx);
127
128 MatMulOp::attachInterface<MatMulOpSharding>(context&: *ctx);
129 NegateOp::attachInterface<NegateOpSharding>(context&: *ctx);
130 });
131}
132

source code of mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp