1//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
10#include "mlir/Dialect/Mesh/IR/MeshOps.h"
11#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
12#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
13#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/DialectRegistry.h"
16#include "llvm/Support/Debug.h"
17
18#define DEBUG_TYPE "tosa-sharding-impl"
19#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
20
21using namespace mlir;
22using namespace mlir::tosa;
23using namespace mlir::mesh;
24
25namespace {
26
27// loop types: [parallel, parallel, parallel, reduction_sum]
28// indexing maps:
29// (d0, d1, d2, d3) -> (d0, d1, d3)
30// (d0, d1, d2, d3) -> (d0, d3, d2)
31// (d0, d1, d2, d3) -> (d0, d1, d2)
32struct MatMulOpSharding
33 : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
34 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
35 auto tensorType = dyn_cast<RankedTensorType>(op->getResult(idx: 0).getType());
36 if (!tensorType)
37 return {};
38
39 SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
40 utils::IteratorType::parallel);
41 types[tensorType.getRank()] = utils::IteratorType::reduction;
42 return types;
43 }
44
45 SmallVector<ReductionKind>
46 getReductionLoopIteratorKinds(Operation *op) const {
47 return SmallVector<ReductionKind>(1, ReductionKind::Sum);
48 }
49
50 SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
51 auto tensorType = dyn_cast<RankedTensorType>(op->getResult(idx: 0).getType());
52 if (!tensorType)
53 return {};
54 MLIRContext *ctx = op->getContext();
55 SmallVector<AffineMap> maps;
56 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 1, 3}, context: ctx));
57 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 3, 2}, context: ctx));
58 maps.push_back(Elt: AffineMap::getMultiDimMapWithTargets(numDims: 4, targets: {0, 1, 2}, context: ctx));
59 return maps;
60 }
61};
62
63template <typename OpType>
64static void registerElemwiseOne(MLIRContext *ctx) {
65 OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
66}
67
68/// Variadic helper function.
69template <typename... OpTypes>
70static void registerElemwiseAll(MLIRContext *ctx) {
71 (registerElemwiseOne<OpTypes>(ctx), ...);
72}
73
74} // namespace
75
76void mlir::tosa::registerShardingInterfaceExternalModels(
77 DialectRegistry &registry) {
78
79 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TosaDialect *dialect) {
80 registerElemwiseAll<
81 ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
82 BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
83 LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84 MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85 LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86 GreaterOp, GreaterEqualOp>(ctx);
87
88 MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
89 });
90}
91

source code of mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp