//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tosa-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tosa;
using namespace mlir::mesh;

namespace {

// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
// (d0, d1, d2, d3) -> (d0, d1, d3)
// (d0, d1, d2, d3) -> (d0, d3, d2)
// (d0, d1, d2, d3) -> (d0, d1, d2)
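// The two scalar zero-point operands carry no loop dimensions, so
// getIndexingMaps() below also emits an empty rank-0 map for each of them,
// between the two input maps and the result map.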
struct MatMulOpSharding
    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};

    // One loop per result dimension, all parallel, plus a trailing
    // sum-reduction loop for the contracted dimension.
    SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
                                           utils::IteratorType::parallel);
    types[tensorType.getRank()] = utils::IteratorType::reduction;
    return types;
  }

  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    // One entry per reduction loop: partial results combine with a sum.
    return SmallVector<ReductionKind>(1, ReductionKind::Sum);
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};
    MLIRContext *ctx = op->getContext();
    SmallVector<AffineMap> maps;
    // Inputs a and b.
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
    // Scalar zero-point operands.
    maps.push_back(AffineMap::get(0, 0, {}, ctx));
    maps.push_back(AffineMap::get(0, 0, {}, ctx));
    // Result.
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
    return maps;
  }
};
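
// With the maps above, a sharding along d0, d1, or d2 of an input propagates
// to the matching result dimension, while sharding along the contracted d3
// dimension leaves each shard holding a partial product that the mesh
// lowering must combine via the sum reduction declared above (for example,
// with an all-reduce).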

// Negate is elementwise over its ranked input: all loops are parallel, the
// input and result maps are identities, and the scalar zero-point operands
// get empty maps.
struct NegateOpSharding
    : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    SmallVector<utils::IteratorType> types(type.getRank(),
                                           utils::IteratorType::parallel);
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    int64_t rank = type.getRank();
    // Identity maps for the input and the result; empty maps for the scalar
    // zero-point operands.
    SmallVector<AffineMap> maps = {
        AffineMap::getMultiDimIdentityMap(rank, ctx),
        AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
        AffineMap::getMultiDimIdentityMap(rank, ctx)};
    return maps;
  }

  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                        ArrayRef<MeshSharding> operandShardings,
                        ArrayRef<MeshSharding> resultShardings,
                        IRMapping &spmdizationMap,
                        SymbolTableCollection &symbolTable,
                        OpBuilder &builder) const {
    // Propagate the result rather than discarding it, so a failed
    // spmdization is reported to the caller.
    return spmdizeTriviallyShardableOperation(
        *op, spmdizedOperands, operandShardings, resultShardings,
        spmdizationMap, symbolTable, builder);
  }
};
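
// Note: MatMulOpSharding above does not override spmdize and falls back to
// the interface's default implementation; NegateOp routes through
// spmdizeTriviallyShardableOperation, which (roughly) re-creates the op on
// the shard-local operand types.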

template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
}

/// Variadic helper: attaches the elementwise sharding interface to each of
/// the given op types.
template <typename... OpTypes>
static void registerElemwiseAll(MLIRContext *ctx) {
  (registerElemwiseOne<OpTypes>(ctx), ...);
}
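// For illustration, registerElemwiseAll<AddOp, SubOp>(ctx) unfolds, via the
// comma fold expression, into registerElemwiseOne<AddOp>(ctx) followed by
// registerElemwiseOne<SubOp>(ctx).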

} // namespace

void mlir::tosa::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {

  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
    registerElemwiseAll<
        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
        BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
        MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
        LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
        GreaterOp, GreaterEqualOp>(ctx);

    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
    NegateOp::attachInterface<NegateOpSharding>(*ctx);
  });
}
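
// Typical use (sketch, not part of this file): a tool registers these models
// on its DialectRegistry before creating the MLIRContext, e.g.
//
//   DialectRegistry registry;
//   registry.insert<tosa::TosaDialect, mesh::MeshDialect>();
//   tosa::registerShardingInterfaceExternalModels(registry);
//   MLIRContext context(registry);
//
// Afterwards the models are reachable on TOSA ops via
// dyn_cast<mesh::ShardingInterface>(op).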