1//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a unit-stride (dilated) TransposeConv2D to Conv2D, including
// reversing/reshaping of the weights.
// (2) Convert a strided TransposeConv2D to Conv2D, including
// transposing/reversing/reshaping of the weights and the input/output
// tensors.
15//
16//===----------------------------------------------------------------------===//
17
18#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19#include "mlir/Dialect/Tosa/Transforms/Passes.h"
20#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
21
22using namespace mlir;
23using namespace mlir::tosa;
24
25namespace {
26
27class TransposeConvNonStridedConverter
28 : public OpRewritePattern<tosa::TransposeConv2DOp> {
29public:
30 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
31 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
32 PatternRewriter &rewriter) const final {
33 Location loc = op->getLoc();
34 Value input = op->getOperand(idx: 0);
35 Value weight = op->getOperand(idx: 1);
36 Value bias = op->getOperand(idx: 2);
37
38 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
39 ShapedType weightTy = cast<ShapedType>(Val: weight.getType());
40 ShapedType biasTy = cast<ShapedType>(Val: bias.getType());
41 ShapedType resultTy = cast<ShapedType>(Val: op->getResult(idx: 0).getType());
42
43 llvm::ArrayRef<int64_t> stride = op.getStride();
44 llvm::ArrayRef<int64_t> pad = op.getOutPad();
45
46 // If striding is all 1 we can modify padding and reverse the kernel along
47 // the x/y direction to make it a regular convolution. This is much simpler
48 // then handling striding....
49 if (llvm::any_of(Range&: stride, P: [](int64_t v) { return v != 1; }))
50 return failure();
51
52 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
53 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
54 return failure();
55
56 int64_t kernelHeight = weightTy.getDimSize(idx: 1);
57 int64_t kernelWidth = weightTy.getDimSize(idx: 2);
58
59 llvm::SmallVector<int64_t> convPad(4, 0);
60 convPad[0] = kernelHeight - 1 + pad[0];
61 convPad[1] = kernelHeight - 1 + pad[1];
62 convPad[2] = kernelWidth - 1 + pad[2];
63 convPad[3] = kernelWidth - 1 + pad[3];
64
65 auto reverse1 = rewriter.create<tosa::ReverseOp>(
66 location: loc, args&: weightTy, args&: weight, /* axis = */ args: rewriter.getI32IntegerAttr(value: 1));
67 auto reverse2 = rewriter.create<tosa::ReverseOp>(
68 location: loc, args&: weightTy, args&: reverse1, /* axis = */ args: rewriter.getI32IntegerAttr(value: 2));
69
70 Value conv2d = rewriter.create<tosa::Conv2DOp>(
71 location: loc, args&: resultTy, args&: input, args&: reverse2, args&: bias, args: op.getInputZp(), args: op.getWeightZp(),
72 args: rewriter.getDenseI64ArrayAttr(values: convPad),
73 args: rewriter.getDenseI64ArrayAttr(values: stride),
74 args: rewriter.getDenseI64ArrayAttr(values: {1, 1}),
75 /* acc_type = */ args: op.getAccType());
76
77 rewriter.replaceOp(op, newValues: conv2d);
78 return success();
79 }
80};
81
82class TransposeConvStridedConverter
83 : public OpRewritePattern<tosa::TransposeConv2DOp> {
84public:
85 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
86 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
87 PatternRewriter &rewriter) const final {
88 Location loc = op->getLoc();
89 Value input = op->getOperand(idx: 0);
90 Value weight = op->getOperand(idx: 1);
91 Value bias = op->getOperand(idx: 2);
92
93 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
94 ShapedType weightTy = cast<ShapedType>(Val: weight.getType());
95 ShapedType biasTy = cast<ShapedType>(Val: bias.getType());
96 ShapedType resultTy = cast<ShapedType>(Val: op->getResult(idx: 0).getType());
97
98 Type inputETy = inputTy.getElementType();
99 Type weightETy = weightTy.getElementType();
100 Type biasETy = biasTy.getElementType();
101 Type resultETy = resultTy.getElementType();
102
103 llvm::ArrayRef<int64_t> pad = op.getOutPad();
104 llvm::ArrayRef<int64_t> stride = op.getStride();
105
106 // If striding is all 1 we can modify padding and reverse the kernel along
107 // the x/y direction to make it a regular convolution. This is much simpler
108 // then handling striding....
109
110 // If strides are all 1 we dont need to use this one.
111 if (llvm::all_of(Range&: stride, P: [](int64_t v) { return v == 1; }))
112 return rewriter.notifyMatchFailure(arg&: op, msg: "non-one stride found.");
113
114 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
115 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
116 return failure();
117
118 int64_t batch = inputTy.getDimSize(idx: 0);
119
120 int64_t outputChannels = weightTy.getDimSize(idx: 0);
121 int64_t weightHeight = weightTy.getDimSize(idx: 1);
122 int64_t weightWidth = weightTy.getDimSize(idx: 2);
123 int64_t inputChannels = weightTy.getDimSize(idx: 3);
124
125 // Pad the weight so that it is modulo of the striding.
126 llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
127 weightPadding[3] =
128 (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
129 weightPadding[5] =
130 (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
131
132 Value weightPaddingVal =
133 getTosaConstShape(rewriter, loc: op->getLoc(), shape: weightPadding);
134
135 // Get and verify zero points.
136 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
137 if (failed(Result: maybeIZp))
138 return rewriter.notifyMatchFailure(
139 arg&: op, msg: "input zero point cannot be statically determined");
140
141 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
142 if (failed(Result: maybeWZp))
143 return rewriter.notifyMatchFailure(
144 arg&: op, msg: "weight zero point cannot be statically determined");
145
146 int64_t inputZpVal = *maybeIZp;
147 int64_t weightZpVal = *maybeWZp;
148
149 if (op.verifyInputZeroPoint(zp: inputZpVal).failed())
150 return rewriter.notifyMatchFailure(
151 arg&: op, msg: "input zero point must be zero for non-int8 integer types");
152
153 if (op.verifyWeightZeroPoint(zp: weightZpVal).failed())
154 return rewriter.notifyMatchFailure(
155 arg&: op, msg: "weight zero point must be zero for non-int8 integer types");
156
157 // construct pad_const values from zp values
158 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
159 const Value inputPadConst =
160 createPadConstTensor(builder, loc: op->getLoc(), src: input, val: inputZpVal);
161 const Value weightPadConst =
162 createPadConstTensor(builder, loc: op->getLoc(), src: input, val: weightZpVal);
163
164 weight = CreateOpAndInferShape<tosa::PadOp>(
165 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
166 args&: weightPaddingVal, args: weightPadConst);
167
168 weightTy = cast<ShapedType>(Val: weight.getType());
169 weightHeight = weightTy.getDimSize(idx: 1);
170 weightWidth = weightTy.getDimSize(idx: 2);
171
172 // Split out the width / height by the stride dimensions.
173 llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
174 outputChannels, weightHeight / stride[0],
175 stride[0], weightWidth / stride[1],
176 stride[1], inputChannels};
177
178 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
179 builder, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
180 args: getTosaConstShape(rewriter, loc, shape: weightReshapeDims0));
181
182 // Transpose the factored-out stride to the output channels.
183 weight = CreateOpAndInferShape<tosa::TransposeOp>(
184 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
185 args: rewriter.getDenseI32ArrayAttr(values: {2, 4, 0, 1, 3, 5}));
186
187 // Collapse the strides and output channels into a single dimension.
188 llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
189 outputChannels * stride[0] * stride[1], weightHeight / stride[0],
190 weightWidth / stride[1], inputChannels};
191
192 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
193 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
194 args: getTosaConstShape(rewriter, loc, shape: weightReshapeDims1));
195 ShapedType restridedWeightTy = cast<ShapedType>(Val: weight.getType());
196
197 weight = CreateOpAndInferShape<tosa::ReverseOp>(
198 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
199 /* axis = */ args: rewriter.getI32IntegerAttr(value: 1));
200 weight = CreateOpAndInferShape<tosa::ReverseOp>(
201 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: weightETy), args&: weight,
202 /* axis = */ args: rewriter.getI32IntegerAttr(value: 2));
203
204 // We need to pad the input far enough that we can pull all values.
205 llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
206 inputPadding[2] += restridedWeightTy.getDimSize(idx: 1) - 1;
207 inputPadding[3] += restridedWeightTy.getDimSize(idx: 1) - 1;
208 inputPadding[4] += restridedWeightTy.getDimSize(idx: 2) - 1;
209 inputPadding[5] += restridedWeightTy.getDimSize(idx: 2) - 1;
210
211 Value inputPaddingVal =
212 getTosaConstShape(rewriter, loc: op->getLoc(), shape: inputPadding);
213
214 input = CreateOpAndInferShape<tosa::PadOp>(
215 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: inputETy), args&: input,
216 args&: inputPaddingVal, args: inputPadConst);
217
218 // We use a zero bias as we need to broadcast the bias.
219 auto zeroBias = rewriter.create<tosa::ConstOp>(
220 location: loc,
221 args: RankedTensorType::get(shape: {outputChannels * stride[0] * stride[1]},
222 elementType: biasETy),
223 args: DenseElementsAttr::get(
224 type: RankedTensorType::get(shape: {outputChannels * stride[0] * stride[1]},
225 elementType: biasETy),
226 values: rewriter.getZeroAttr(type: biasETy)));
227
228 auto inputZp =
229 createZeroPointTensor(builder&: rewriter, loc, srcElemType: input.getType(), zp: inputZpVal);
230 auto weightZp =
231 createZeroPointTensor(builder&: rewriter, loc, srcElemType: weight.getType(), zp: weightZpVal);
232
233 if (!inputZp.has_value() || !weightZp.has_value()) {
234 return rewriter.notifyMatchFailure(
235 arg&: op, msg: "fail to create a const zero point tensor");
236 }
237
238 // Perform the convolution using the zero bias.
239 Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
240 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: resultETy), args&: input,
241 args&: weight, args&: zeroBias, args&: inputZp.value(), args&: weightZp.value(),
242 /*pad=*/args: rewriter.getDenseI64ArrayAttr(values: {0, 0, 0, 0}),
243 /*stride=*/args: rewriter.getDenseI64ArrayAttr(values: {1, 1}),
244 /*dilation=*/args: rewriter.getDenseI64ArrayAttr(values: {1, 1}),
245 /* acc_type = */ args: op.getAccType())
246 .getResult();
247
248 // Factor the resulting width / height.
249 ShapedType convTy = cast<ShapedType>(Val: conv2d.getType());
250 Type convETy = convTy.getElementType();
251
252 int64_t convHeight = convTy.getDimSize(idx: 1);
253 int64_t convWidth = convTy.getDimSize(idx: 2);
254
255 // Factor striding out of the convolution result.
256 llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
257 batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
258
259 auto convReshapeDims0Value =
260 getTosaConstShape(rewriter, loc, shape: convReshapeDims0);
261
262 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
263 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: resultETy), args&: conv2d,
264 args&: convReshapeDims0Value);
265
266 // Transpose the factored-out stride to the output channels.
267 conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
268 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: convETy), args&: conv2d,
269 args: rewriter.getDenseI32ArrayAttr(values: {0, 1, 3, 2, 4, 5}));
270
271 // Fuse striding behavior back into width / height.
272 llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
273 batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
274
275 auto convReshapeDims1Value =
276 getTosaConstShape(rewriter, loc, shape: convReshapeDims1);
277
278 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
279 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: resultETy), args&: conv2d,
280 args&: convReshapeDims1Value);
281
282 // Determine the amount to slice / pad from the result start.
283 int64_t resultSliceTop = std::max<int64_t>(a: 0, b: -pad[0]);
284 int64_t resultSliceLeft = std::max<int64_t>(a: 0, b: -pad[2]);
285 int64_t resultPadTop = std::max<int64_t>(a: 0, b: pad[0]);
286 int64_t resultPadLeft = std::max<int64_t>(a: 0, b: pad[2]);
287
288 // Try to slice the targetted result size, cap to the convolutions width.
289 int64_t resultSliceHeight =
290 std::min<int64_t>(a: convReshapeDims1[1] - resultSliceTop,
291 b: resultTy.getDimSize(idx: 1) - resultPadTop);
292 int64_t resultSliceWidth =
293 std::min<int64_t>(a: convReshapeDims1[2] - resultSliceLeft,
294 b: resultTy.getDimSize(idx: 2) - resultPadLeft);
295
296 llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
297 resultSliceLeft, 0};
298 llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
299 convReshapeDims1.end());
300 sliceSize[1] = resultSliceHeight;
301 sliceSize[2] = resultSliceWidth;
302
303 auto slice = CreateOpAndInferShape<tosa::SliceOp>(
304 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: resultETy), args&: conv2d,
305 args: getTosaConstShape(rewriter, loc, shape: sliceBegin),
306 args: getTosaConstShape(rewriter, loc, shape: sliceSize))
307 .getResult();
308
309 llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
310 resultPadding[2] = resultPadTop;
311 resultPadding[3] = resultTy.getDimSize(idx: 1) - resultPadTop - sliceSize[1];
312 resultPadding[4] = resultPadLeft;
313 resultPadding[5] = resultTy.getDimSize(idx: 2) - resultPadLeft - sliceSize[2];
314
315 Value resultPaddingVal =
316 getTosaConstShape(rewriter, loc: op->getLoc(), shape: resultPadding);
317
318 Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
319 rewriter, loc, resultTy: UnrankedTensorType::get(elementType: resultETy), args&: slice,
320 args&: resultPaddingVal);
321
322 if (EqualizeRanks(rewriter, loc: op.getLoc(), input1&: resultPad, input2&: bias).failed()) {
323 return failure();
324 }
325
326 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, args: op.getType(), args&: resultPad, args&: bias);
327 return success();
328 }
329};
330
331} // namespace
332
333void mlir::tosa::populateTosaDecomposeTransposeConv(
334 MLIRContext *ctx, RewritePatternSet &patterns) {
335 patterns.add<TransposeConvNonStridedConverter>(arg&: ctx);
336 patterns.add<TransposeConvStridedConverter>(arg&: ctx);
337}
338

source code of mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp