//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a Conv2D by reversing the
//     weights along the spatial axes and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a Conv2D by padding, reshaping,
//     transposing, and reversing the weights, and by padding, slicing, and
//     reshaping the input/output tensors.
//
//===----------------------------------------------------------------------===//
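//
// As a rough sketch (attribute syntax abbreviated and dependent on the exact
// TOSA dialect version), the non-strided case rewrites
//
//   %out = tosa.transpose_conv2d %in, %w, %bias
//            {out_pad = [pt, pb, pl, pr], stride = [1, 1]}
//
// into
//
//   %w1  = tosa.reverse %w  {axis = 1}
//   %w2  = tosa.reverse %w1 {axis = 2}
//   %out = tosa.conv2d %in, %w2, %bias
//            {pad = [kh-1+pt, kh-1+pb, kw-1+pl, kw-1+pr],
//             stride = [1, 1], dilation = [1, 1]}
//
// where kh / kw are the kernel height / width. The strided case is handled by
// TransposeConvStridedConverter below.
//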

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

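// Helper that creates a TOSA op and, if the op implements
// InferShapedTypeOpInterface, tightens the type of its first result by joining
// the inferred shape components with the requested result type. This lets the
// patterns below build ops with unranked result types and still recover ranked
// shapes wherever inference succeeds.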
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(
              op.getContext(), op.getLoc(), op->getOperands(),
              op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
              op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If all strides are 1 we can adjust the padding and reverse the kernel
    // along the spatial axes to turn the transpose convolution into a regular
    // convolution, which is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

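    // A transpose convolution spreads each input pixel over a
    // (kernelHeight x kernelWidth) window of the output, so a regular
    // convolution over an input padded by (kernel size - 1) plus the requested
    // out_pad on each side produces the same result. For example, a 3x3 kernel
    // with out_pad = [0, 0, 0, 0] yields a conv2d pad of [2, 2, 2, 2].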
    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

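// Decompose a strided TransposeConv2D into a unit-stride Conv2D:
//   1. Pad the weights so their spatial dimensions are multiples of the
//      stride, split them into stride_y * stride_x sub-kernels folded into
//      the output-channel dimension, and reverse them spatially.
//   2. Pad the input and run a unit-stride Conv2D with a zero bias, producing
//      a [N, H, W, stride_y * stride_x * C] result.
//   3. Rearrange that result depth-to-space style into
//      [N, H * stride_y, W * stride_x, C], slice / pad it to honor out_pad,
//      and finally add the original bias.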
class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If all strides are 1 the non-strided pattern handles this case, so this
    // pattern only applies when at least one stride is greater than 1.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(
          op, "all strides are 1: handled by the non-strided pattern");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

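    // As a concrete (illustrative) example, take a 3x3 kernel with
    // stride = [2, 2] and weights of shape [OC, 3, 3, IC]:
    //   * pad the kernel to [OC, 4, 4, IC] so each spatial dim is a multiple
    //     of the stride,
    //   * reshape to [OC, 2, 2, 2, 2, IC], splitting each spatial dim into
    //     (size / stride, stride),
    //   * transpose with permutation {2, 4, 0, 1, 3, 5} to
    //     [2, 2, OC, 2, 2, IC], moving the stride factors outermost,
    //   * reshape to [4 * OC, 2, 2, IC], folding the stride factors into the
    //     output channels, and
    //   * reverse the (now smaller) spatial dims so the convolution below
    //     computes the transpose-convolution contributions.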
    // Pad the weight so its spatial dimensions are multiples of the stride.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal);
    }

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

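    // Padding the input by (restrided kernel size - 1) on every side turns the
    // unit-stride convolution below into a "full" correlation; together with
    // the reversed kernel this reproduces the contribution every input pixel
    // makes to the transpose-convolution output.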
    // We need to pad the input far enough that we can pull all values.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal);
    }

    // Use a zero bias for the convolution; the original bias is broadcast and
    // added back once the output has been rearranged to its final shape.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

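    // The conv2d output now has shape [N, H, W, stride_y * stride_x * C]. The
    // three steps below perform a depth-to-space style rearrangement into
    // [N, H * stride_y, W * stride_x, C]: split the channel dimension,
    // interleave the stride factors with the spatial dimensions, and collapse
    // the result.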
    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out stride factors with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims1));

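    // out_pad values may be positive or negative. A negative value means the
    // rearranged result is larger than the requested output and must be
    // sliced, while a positive value means it must be padded; the slice and
    // pad below together produce a tensor matching resultTy before the bias
    // is added.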
    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice the targeted result size, capped to the convolution's
    // height / width.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getDenseI64ArrayAttr(sliceBegin),
                     rewriter.getDenseI64ArrayAttr(sliceSize))
                     .getResult();

    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);

    Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

    Value resultPad = createOpAndInfer<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    // Bring the padded result and the bias to equal ranks before adding.
    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

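// A minimal usage sketch (assuming the usual greedy pattern-driven rewrite;
// the surrounding pass wiring is illustrative and not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));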
void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}
