1//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15#include "mlir/Dialect/Tosa/Transforms/Passes.h"
16#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17#include "mlir/Pass/Pass.h"
18
19using namespace mlir;
20using namespace mlir::tosa;
21
22namespace {
23
24struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25 explicit DepthwiseConv2DIsMul(MLIRContext *context)
26 : OpRewritePattern(context) {}
27
28 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29 PatternRewriter &rewriter) const override {
30 Value input = op.getInput();
31 Value weight = op.getWeight();
32 ShapedType inputType = cast<ShapedType>(input.getType());
33 ShapedType weightType = cast<ShapedType>(weight.getType());
34 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
35
36 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37 resultType.hasStaticShape())) {
38 return failure();
39 }
40
41 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
42 return failure();
43
44 // Only works for a 1x1 kernel.
45 ArrayRef<int64_t> weightShape = weightType.getShape();
46 if (weightShape[0] != 1 || weightShape[1] != 1) {
47 return failure();
48 }
49
50 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
51 ArrayRef<int64_t> inputShape = inputType.getShape();
52 llvm::SmallVector<int64_t, 2> revisedInputShape{
53 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
54 inputType = RankedTensorType::get(
55 revisedInputShape,
56 dyn_cast<RankedTensorType>(input.getType()).getElementType());
57 input = rewriter
58 .create<tosa::ReshapeOp>(
59 op.getLoc(), inputType, input,
60 rewriter.getDenseI64ArrayAttr(revisedInputShape))
61 .getResult();
62
63 if (inputType.getElementType() != resultType.getElementType()) {
64 inputType = inputType.clone(resultType.getElementType());
65 input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
66 }
67
68 if (weightType.getElementType() != resultType.getElementType()) {
69 weightType = weightType.clone(resultType.getElementType());
70 weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
71 }
72
73 if (auto quantizationInfo = op.getQuantizationInfo()) {
74 auto iZp = quantizationInfo->getInputZp();
75 auto wZp = quantizationInfo->getWeightZp();
76
77 auto applyZp = [&](Value val, int64_t zp) -> Value {
78 if (zp == 0)
79 return val;
80 auto ety = cast<ShapedType>(val.getType()).getElementType();
81 std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
82 1);
83 auto zpTy = RankedTensorType::get(shape, ety);
84 auto zpAttr =
85 DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
86 auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
87 return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
88 zpVal);
89 };
90
91 input = applyZp(input, iZp);
92 weight = applyZp(weight, wZp);
93 }
94
95 ArrayRef<int64_t> padAttr = op.getPad();
96 llvm::SmallVector<int64_t> pad(10, 0);
97 for (const auto &it : llvm::enumerate(padAttr))
98 pad[it.index() + 2] = it.value();
99
100 if (llvm::any_of(Range&: pad, P: [](int64_t p) { return p != 0; })) {
101 Type inputETy = inputType.getElementType();
102 Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
103
104 llvm::SmallVector<int64_t> newShape(inputType.getShape());
105 for (int i = 0, s = pad.size(); i < s; ++i) {
106 if (newShape[i / 2] != ShapedType::kDynamic) {
107 newShape[i / 2] += pad[i];
108 }
109 }
110
111 auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
112 auto padSize =
113 DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
114 Value padSizeVal =
115 rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
116
117 auto padTy = RankedTensorType::get({}, inputETy);
118 auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
119 Value padVal =
120 rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
121 inputType = RankedTensorType::get(newShape, inputETy);
122 input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
123 padSizeVal, padVal);
124 }
125
126 // Perform an elementwise mul over the reshaped input and weight.
127 llvm::SmallVector<int64_t, 2> mulShape{
128 inputType.getDimSize(0), inputType.getDimSize(1),
129 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
130 auto mulShapeType = RankedTensorType::get(
131 mulShape,
132 dyn_cast<RankedTensorType>(weight.getType()).getElementType());
133
134 if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
135 return failure();
136 }
137
138 Value mulValue = rewriter
139 .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
140 weight, /*shift=*/0)
141 .getResult();
142
143 // Reshape output to [N, H, W, C * M].
144 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
145 auto outputShapeType = RankedTensorType::get(
146 outputShape,
147 dyn_cast<RankedTensorType>(input.getType()).getElementType());
148 Value outputValue = rewriter.create<tosa::ReshapeOp>(
149 op.getLoc(), outputShapeType, mulValue,
150 rewriter.getDenseI64ArrayAttr(outputShape));
151
152 Value bias = op.getBias();
153 if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
154 return failure();
155 }
156
157 // Add in the bias.
158 rewriter
159 .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
160 .getResult();
161 return success();
162 }
163};
164
165} // namespace
166
167void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
168 RewritePatternSet &patterns) {
169 patterns.add<DepthwiseConv2DIsMul>(arg&: ctx);
170}
171

// Source: mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp