//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a sequence of simpler
// TOSA ops:
//  (1) Convert a non-strided TransposeConv2D to a regular Conv2D by adjusting
//      the padding and reversing the weights along the spatial axes.
//  (2) Convert a strided TransposeConv2D to a Conv2D by padding, reshaping,
//      transposing, and reversing the weights, and by padding/reshaping the
//      input and output tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If the stride is 1 in both dimensions we can adjust the padding and
    // reverse the kernel along the spatial axes to turn this into a regular
    // convolution; this is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];
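
    // Illustrative shape check (heights shown; widths are analogous): TOSA
    // defines the stride-1 transpose-conv output height as
    //   OH = IH + KH - 1 + out_pad_top + out_pad_bottom,
    // while a stride-1 Conv2D padded by (KH - 1 + out_pad) on each side gives
    //   OH = IH + 2 * (KH - 1) + out_pad_top + out_pad_bottom - KH + 1,
    // which simplifies to the same expression; with the kernel reversed below,
    // the values agree as well.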

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d = rewriter.create<tosa::Conv2DOp>(
        loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
        rewriter.getDenseI64ArrayAttr(convPad),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr({1, 1}),
        /* acc_type = */ op.getAccType());

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // The stride-1 case is handled by TransposeConvNonStridedConverter above;
    // this pattern only applies when at least one stride is greater than 1.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(op, "all strides are 1");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight spatial dimensions up to a multiple of the stride. The
    // padding vector is laid out as {before, after} pairs per dimension.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
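
    // For example, with stride = {2, 2} and a 5x5 kernel, one trailing row and
    // one trailing column of padding (indices 3 and 5 are the "after" entries
    // of the H and W dimensions) grow the kernel to 6x6, which divides evenly
    // into 2x2 stride phases below.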

    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    int64_t inputZpVal = *maybeIZp;
    int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Construct pad_const values from the zero-point values.
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    const Value inputPadConst =
        createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
    const Value weightPadConst =
        createPadConstTensor(builder, op->getLoc(), weight, weightZpVal);

    weight = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        weightPaddingVal, weightPadConst);

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        builder, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims1));
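
    // Tracing the example (stride {2, 2}, kernel padded to 6x6):
    //   [OC, 6, 6, IC] --reshape--> [OC, 3, 2, 3, 2, IC]
    //                  --transpose {2, 4, 0, 1, 3, 5}--> [2, 2, OC, 3, 3, IC]
    //                  --reshape--> [4 * OC, 3, 3, IC]
    // Each of the 2x2 stride phases becomes its own bank of output channels,
    // so the single stride-1 convolution below computes all phases at once.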
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));
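
    // As in the non-strided pattern, the spatial reversal accounts for a
    // transpose convolution applying the kernel flipped relative to a forward
    // convolution, so a regular Conv2D can reproduce its values.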

    // Pad the input on each spatial edge by (phase kernel height/width - 1) so
    // the stride-1 convolution can read every contributing input value.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
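
    // Continuing the example, the 3x3 phase kernels pad the input by 2 on each
    // spatial edge; inputPadConst supplies the input zero point as pad value.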

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    input = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(inputETy), input,
        inputPaddingVal, inputPadConst);

    // Use a zero bias for the intermediate convolution: the stride phases
    // multiply the channel count, so the real bias cannot be applied here and
    // is instead added back after the final reshape.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    auto inputZp =
        createZeroPointTensor(rewriter, loc, input.getType(), inputZpVal);
    auto weightZp =
        createZeroPointTensor(rewriter, loc, weight.getType(), weightZpVal);

    if (!inputZp.has_value() || !weightZp.has_value()) {
      return rewriter.notifyMatchFailure(
          op, "failed to create a constant zero-point tensor");
    }

    // Perform the convolution using the zero bias.
    Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                       rewriter, loc, UnrankedTensorType::get(resultETy), input,
                       weight, zeroBias, inputZp.value(), weightZp.value(),
                       /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                       /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /* acc_type = */ op.getAccType())
                       .getResult();

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

    auto convReshapeDims0Value =
        getTosaConstShape(rewriter, loc, convReshapeDims0);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims0Value);

    // Transpose the factored-out stride to the output channels.
    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

    auto convReshapeDims1Value =
        getTosaConstShape(rewriter, loc, convReshapeDims1);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims1Value);
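
    // Tracing the example result (2x2 stride phases, i.e. 4 * OC channels):
    //   [N, H', W', 4 * OC] --reshape--> [N, H', W', 2, 2, OC]
    //                       --transpose {0, 1, 3, 2, 4, 5}-->
    //                       [N, H', 2, W', 2, OC]
    //                       --reshape--> [N, 2 * H', 2 * W', OC]
    // interleaving the stride phases back into the spatial dimensions, i.e. a
    // depth-to-space rearrangement.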

    // Determine the amount to slice / pad from the start of the result.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Slice to the targeted result size, capped to the convolution's
    // width / height.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);
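
    // For example, out_pad top = -1 drops one row from the top of the
    // upsampled result via the slice, while out_pad top = 1 re-adds one row
    // through the trailing tosa.pad below.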

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}