1//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Decompose the TOSA DepthwiseConv2D operation into a series of simpler TOSA
// ops. Specifically:
// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15#include "mlir/Dialect/Tosa/Transforms/Passes.h"
16#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/Pass/Pass.h"
19
20using namespace mlir;
21using namespace mlir::tosa;
22
23namespace {
24
25struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
26 explicit DepthwiseConv2DIsMul(MLIRContext *context)
27 : OpRewritePattern(context) {}
28
29 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
30 PatternRewriter &rewriter) const override {
31 Value input = op.getInput();
32 Value weight = op.getWeight();
33 ShapedType inputType = cast<ShapedType>(input.getType());
34 ShapedType weightType = cast<ShapedType>(weight.getType());
35 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
36
37 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
38 resultType.hasStaticShape())) {
39 return failure();
40 }
41
42 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
43 return failure();
44
45 // Only works for a 1x1 kernel.
46 ArrayRef<int64_t> weightShape = weightType.getShape();
47 if (weightShape[0] != 1 || weightShape[1] != 1) {
48 return failure();
49 }
50
51 Type inputETy = inputType.getElementType();
52 Type weightETy = weightType.getElementType();
53 if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
54 return rewriter.notifyMatchFailure(op, "unsupported type");
55
56 // Get and verify zero points.
57 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
58 if (failed(Result: maybeIZp))
59 return rewriter.notifyMatchFailure(
60 op, "input zero point cannot be statically determined");
61
62 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
63 if (failed(Result: maybeWZp))
64 return rewriter.notifyMatchFailure(
65 op, "weight zero point cannot be statically determined");
66
67 int64_t iZp = *maybeIZp;
68 int64_t wZp = *maybeWZp;
69 if (op.verifyInputZeroPoint(iZp).failed())
70 return rewriter.notifyMatchFailure(
71 op, "input zero point must be zero for non-int8 integer types");
72 if (op.verifyWeightZeroPoint(wZp).failed())
73 return rewriter.notifyMatchFailure(
74 op, "weight zero point must be zero for non-int8 integer types");
75
76 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
77 ArrayRef<int64_t> inputShape = inputType.getShape();
78 llvm::SmallVector<int64_t, 2> revisedInputShape{
79 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
80 inputType = RankedTensorType::get(
81 revisedInputShape,
82 dyn_cast<RankedTensorType>(input.getType()).getElementType());
83 auto revisedInputShapeValue =
84 getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
85 input = rewriter
86 .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
87 revisedInputShapeValue)
88 .getResult();
89
90 Type resultETy = resultType.getElementType();
91
92 if (inputETy != resultETy) {
93 inputType = inputType.clone(resultETy);
94 input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
95 }
96
97 if (weightETy != resultETy) {
98 weightType = weightType.clone(resultETy);
99 weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
100 }
101
102 if (iZp != 0 || wZp != 0) {
103
104 auto applyZp = [&](Value val, int64_t zp) -> Value {
105 if (zp == 0)
106 return val;
107 auto ety = cast<ShapedType>(val.getType()).getElementType();
108 std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
109 1);
110 auto zpTy = RankedTensorType::get(shape, ety);
111 auto zpAttr =
112 DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
113 auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
114 return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
115 zpVal);
116 };
117
118 input = applyZp(input, iZp);
119 weight = applyZp(weight, wZp);
120 }
121
122 ArrayRef<int64_t> padAttr = op.getPad();
123 llvm::SmallVector<int64_t> pad(10, 0);
124 for (const auto &it : llvm::enumerate(padAttr))
125 pad[it.index() + 2] = it.value();
126
127 if (llvm::any_of(Range&: pad, P: [](int64_t p) { return p != 0; })) {
128 Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
129
130 llvm::SmallVector<int64_t> newShape(inputType.getShape());
131 for (int i = 0, s = pad.size(); i < s; ++i) {
132 if (newShape[i / 2] != ShapedType::kDynamic) {
133 newShape[i / 2] += pad[i];
134 }
135 }
136
137 Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
138
139 auto padTy = RankedTensorType::get({1}, inputETy);
140 auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
141 Value padVal =
142 rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
143 inputType = RankedTensorType::get(newShape, inputETy);
144 input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
145 padSizeVal, padVal);
146 }
147
148 // Perform an elementwise mul over the reshaped input and weight.
149 llvm::SmallVector<int64_t, 2> mulShape{
150 inputType.getDimSize(0), inputType.getDimSize(1),
151 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
152 auto mulShapeType = RankedTensorType::get(
153 mulShape,
154 dyn_cast<RankedTensorType>(weight.getType()).getElementType());
155
156 if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
157 return failure();
158 }
159
160 auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
161 auto shiftType = RankedTensorType::get({1}, shiftElementType);
162 auto shiftZeroAttr = DenseElementsAttr::get(
163 shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
164 Value constZero =
165 rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
166 Value mulValue = rewriter
167 .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
168 weight, constZero)
169 .getResult();
170
171 // Reshape output to [N, H, W, C * M].
172 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
173 auto outputShapeType = RankedTensorType::get(
174 outputShape,
175 dyn_cast<RankedTensorType>(input.getType()).getElementType());
176 auto outputShapeValue =
177 getTosaConstShape(rewriter, op->getLoc(), outputShape);
178 Value outputValue = rewriter.create<tosa::ReshapeOp>(
179 op.getLoc(), outputShapeType, mulValue, outputShapeValue);
180
181 Value bias = op.getBias();
182 if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
183 return failure();
184 }
185
186 // Add in the bias.
187 rewriter
188 .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
189 .getResult();
190 return success();
191 }
192};
193
194} // namespace
195
196void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
197 RewritePatternSet &patterns) {
198 patterns.add<DepthwiseConv2DIsMul>(arg&: ctx);
199}
200

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp