//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a regular Conv2D by reversing
//     the kernel along its spatial axes and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a regular Conv2D by additionally
//     padding/reshaping/transposing the weights and padding/reshaping/slicing
//     the input and output tensors.
//
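// For illustration only (shapes and attributes assumed, not taken from the
// TOSA tests), a stride-1 transpose convolution such as
//
//   %0 = tosa.transpose_conv2d %input, %weight, %bias, ...
//          {out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, ...}
//          : (tensor<1x8x8x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, ...)
//          -> tensor<1x10x10x16xf32>
//
// is rewritten into a conv2d over the kernel reversed along both spatial
// axes, with the padding grown to (kernel size - 1) on every edge:
//
//   %r0 = tosa.reverse %weight {axis = 1 : i32}
//   %r1 = tosa.reverse %r0 {axis = 2 : i32}
//   %1 = tosa.conv2d %input, %r1, %bias, ...
//          {pad = array<i64: 2, 2, 2, 2>, stride = array<i64: 1, 1>, ...}
//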
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If the stride is all 1s we can modify the padding and reverse the kernel
    // along the x/y directions to turn this into a regular convolution. This
    // is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];
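    // For example, with a 3x3 kernel and out_pad of all zeros this yields
    // convPad = {2, 2, 2, 2}: the "full" padding (kernel size - 1 per edge)
    // under which a stride-1 conv2d over the reversed kernel computes exactly
    // the transpose convolution.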

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d = rewriter.create<tosa::Conv2DOp>(
        loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
        rewriter.getDenseI64ArrayAttr(convPad),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr({1, 1}),
        /* acc_type = */ op.getAccType());

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // Strides of all 1s are handled by the simpler non-strided pattern above;
    // this pattern only applies when at least one stride is greater than 1.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(op, "stride is all ones");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that each spatial dimension is a multiple of the
    // corresponding stride.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
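    // For example, a 3x3 kernel with stride {2, 2} gains one trailing row and
    // one trailing column of padding (filled with the weight zero point),
    // giving a 4x4 kernel whose spatial dimensions divide evenly by the
    // stride in the reshape below.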

    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    int64_t inputZpVal = *maybeIZp;
    int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Construct pad_const values from the zero-point values.
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    const Value inputPadConst =
        createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
    const Value weightPadConst =
        createPadConstTensor(builder, op->getLoc(), weight, weightZpVal);

    weight = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        weightPaddingVal, weightPadConst);

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
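    // For example, a padded 16x4x4x8 weight (OC x H x W x IC) with stride
    // {2, 2} is viewed as 16x2x2x2x2x8, splitting each spatial axis into
    // (size / stride) x stride.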

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        builder, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims1));
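    // Continuing the example, the {2, 4, 0, 1, 3, 5} transpose yields
    // 2x2x16x2x2x8 and this reshape collapses it to 64x2x2x8: one 2x2
    // sub-kernel per (stride offset, output channel) pair.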
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // Pad the input far enough that the stride-1 convolution can read every
    // value it needs.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
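    // For the running example the restrided sub-kernel is 2x2, so each
    // spatial edge of the input is padded by 2 - 1 = 1 element of the input
    // zero point: the "full" padding that lets the stride-1 convolution below
    // produce a value for every output position.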

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    input = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(inputETy), input,
        inputPaddingVal, inputPadConst);

    // Use a zero bias for the convolution; the real bias is broadcast-added
    // after the result has been reshaped back to the output shape.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    auto inputZp =
        createZeroPointTensor(rewriter, loc, input.getType(), inputZpVal);
    auto weightZp =
        createZeroPointTensor(rewriter, loc, weight.getType(), weightZpVal);

    if (!inputZp.has_value() || !weightZp.has_value()) {
      return rewriter.notifyMatchFailure(
          op, "failed to create a const zero point tensor");
    }

    // Perform the convolution using the zero bias.
    Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                       rewriter, loc, UnrankedTensorType::get(resultETy), input,
                       weight, zeroBias, inputZp.value(), weightZp.value(),
                       /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                       /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /* acc_type = */ op.getAccType())
                       .getResult();

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
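    // For the running example (batch 1, 16 output channels, stride {2, 2})
    // the convolution result carries 16 * 2 * 2 = 64 channels; viewing it as
    // {batch, convHeight, convWidth, 2, 2, 16} exposes the two stride axes so
    // they can be interleaved back into height and width below.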

    auto convReshapeDims0Value =
        getTosaConstShape(rewriter, loc, convReshapeDims0);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims0Value);

    // Transpose the factored-out stride to the output channels.
    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

    auto convReshapeDims1Value =
        getTosaConstShape(rewriter, loc, convReshapeDims1);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims1Value);

    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice to the targeted result size, capped at the convolution's
    // width / height.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);
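    // For example, out_pad = {-1, 0, -1, 0} trims one row and one column from
    // the top / left (resultSliceTop = resultSliceLeft = 1) and pads nothing,
    // while a positive out_pad slices nothing here and instead pads the
    // result below.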

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad,
                                             bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}