//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a regular Conv2D by reversing
//     the kernel along the spatial axes and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a regular Conv2D by padding,
//     reshaping, transposing, and reversing the weights, padding the input,
//     and reshaping/transposing/slicing/padding the convolution result.
//
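// As a rough sketch (shapes and attribute values illustrative only), the
// strided case rewrites
//
//   %out = tosa.transpose_conv2d %input, %weight, %bias {stride = [2, 2], ...}
//
// into pad/reshape/transpose/reverse ops on %weight, a pad of %input, a
// unit-stride tosa.conv2d with a zero bias, reshape/transpose/reshape of the
// result to interleave the stride factors back into height/width, a slice/pad
// to the requested output size, and a final tosa.add of %bias.
//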
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

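// Creates a TOSA op and, when it implements InferShapedTypeOpInterface, runs
// shape inference and tightens the result type to the join of the inferred
// shape and the requested result type.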
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(
              op.getContext(), op.getLoc(), op->getOperands(),
              op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
              op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

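// Rewrites a unit-stride tosa.transpose_conv2d as an ordinary tosa.conv2d:
// the kernel is reversed along the H and W axes and the transpose-conv output
// padding is folded into the convolution padding.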
class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If all strides are 1 we can adjust the padding and reverse the kernel
    // along the x/y directions to make it a regular convolution. This is much
    // simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

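    // For a unit-stride transpose conv, the equivalent conv padding on each
    // edge is (kernel extent - 1) plus the requested output padding; e.g. a
    // 3x3 kernel with out_pad {0, 0, 0, 0} yields conv padding {2, 2, 2, 2}.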
    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

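// Rewrites a strided tosa.transpose_conv2d as a unit-stride tosa.conv2d by
// factoring the stride into extra output channels of the weight, running the
// convolution, and then interleaving the stride factors back into the output
// height/width (a depth-to-space style rearrangement), followed by a slice/pad
// to the requested output size and a final bias add.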
class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If all strides are 1, the non-strided pattern above handles this case,
    // so don't apply this one.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(op, "no non-unit strides");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight spatial dimensions so they are divisible by the stride.
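    // For example, a 5x5 kernel with stride {2, 2} is padded to 6x6 so that
    // each spatial dimension splits evenly into stride-sized groups.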
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal);
    }

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
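    // [OC, H, W, IC] -> [OC, H / stride_y, stride_y, W / stride_x, stride_x,
    // IC].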
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
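    // Permutation {2, 4, 0, 1, 3, 5} yields
    // [stride_y, stride_x, OC, H / stride_y, W / stride_x, IC].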
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
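    // [stride_y, stride_x, OC, H / stride_y, W / stride_x, IC]
    // -> [OC * stride_y * stride_x, H / stride_y, W / stride_x, IC].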
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // We need to pad the input far enough that we can pull all values.
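    // Padding every spatial edge by (kernel extent - 1) turns the unit-stride
    // convolution below into a "full" convolution, so every partial overlap of
    // the kernel with the original input contributes an output pixel.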
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal);
    }

    // Use a zero bias for the intermediate convolution: its output carries
    // outputChannels * stride_y * stride_x channels, so the original
    // outputChannels-sized bias cannot be applied here. The real bias is added
    // after the output is reshaped back to its final form.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
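    // [N, H', W', OC * stride_y * stride_x] -> [N, H', W', stride_y, stride_x,
    // OC], where H' and W' are the unit-stride convolution's output height and
    // width.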
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out stride factors with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims1));

    // Determine the amount to slice / pad from the result start.
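    // A negative out_pad trims rows / columns from the full-convolution result
    // (handled with a slice); a positive out_pad adds rows / columns back
    // (handled with a pad).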
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice the targeted result size, capped to the convolution's
    // width / height.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getDenseI64ArrayAttr(sliceBegin),
                     rewriter.getDenseI64ArrayAttr(sliceSize))
                     .getResult();

    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);

    Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

    Value resultPad = createOpAndInfer<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

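    // Add the original bias now that the result has its final NHWC shape.
    // EqualizeRanks reshapes the rank-1 bias so that it broadcasts against the
    // rank-4 result in the tosa.add below.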
    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}