//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA Conv2D operation into a series of TOSA ops, specifically:
// (1) Convert a 1x1 convolution to a Reshape->FullyConnected->Reshape sequence.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

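// Convert an MLIR shape to the TOSA attribute convention: MLIR marks an
// unknown dimension with ShapedType::kDynamic, while TOSA shape attributes
// such as tosa.reshape's new_shape encode it as -1.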
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return ShapedType::isDynamic(dim) ? -1 : dim;
  }));
}

struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
  explicit Conv2DIsFullyConnected(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = cast<ShapedType>(input.getType());
    ShapedType weightType = cast<ShapedType>(weight.getType());
    ShapedType resultType = cast<ShapedType>(op.getType());

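    // The rewrite collapses N, IH, and IW into a single dimension and later
    // reshapes back; since a reshape can infer at most one unknown (-1)
    // dimension, more than one dynamic input dim cannot be expressed.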
    auto numDynamic =
        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
    if (numDynamic > 1)
      return rewriter.notifyMatchFailure(
          op, "at most one dim in input may be dynamic");
    if (!weightType.hasRank())
      return rewriter.notifyMatchFailure(op, "unranked weight input");

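    // With stride 1, a 1x1 convolution reads every input pixel exactly once,
    // making it a pointwise matmul over channels; strided variants would
    // require subsampling and are not handled here.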
    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
      return failure();

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[1] != 1 || weightShape[2] != 1)
      return failure();

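    // The conv2d pad attribute holds [top, bottom, left, right] for the H and
    // W dims; expand it to a full rank-4 [dim][before, after] table with zero
    // padding on the N and C dims.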
    llvm::ArrayRef<int64_t> padAttr = op.getPad();
    llvm::SmallVector<int64_t> pad(8, 0);
    for (const auto &it : llvm::enumerate(padAttr))
      pad[it.index() + 2] = it.value();

    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      Type inputETy = inputType.getElementType();
      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
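      // For quantized types "zero" is the input zero point: pad with it (after
      // checking it fits the storage type) so padded pixels dequantize to
      // zero.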
      if (op.getQuantizationInfo()) {
        auto quantizationInfo = op.getQuantizationInfo();
        int64_t iZp = quantizationInfo->getInputZp();

        if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
          return rewriter.notifyMatchFailure(
              op, "tosa.conv op quantization has zp outside of input range");

        zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
      }

      llvm::SmallVector<int64_t> newShape(inputType.getShape());

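      // Fold the padding into the shape: each static dimension grows by its
      // before/after pad amounts, while dynamic dimensions stay dynamic.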
      for (int i = 0, s = newShape.size(); i < s; ++i) {
        if (newShape[i] != ShapedType::kDynamic) {
          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
        }
      }

      auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
      auto padSize =
          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
      Value padSizeVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);

      auto padTy = RankedTensorType::get({}, inputETy);
      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
      Value padVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
      inputType = RankedTensorType::get(newShape, inputETy);
      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
                                           padSizeVal, padVal);
    }

    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
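    // The collapsed dimension is computed only for fully static shapes;
    // otherwise it stays kDynamic and lowers to -1 in the reshape attribute.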
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t combined = ShapedType::kDynamic;
    if (numDynamic == 0)
      combined = inputShape[0] * inputShape[1] * inputShape[2];
    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
    auto revisedInputShapeType =
        RankedTensorType::get(revisedInputShape, inputType.getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getDenseI64ArrayAttr(
                                     convertFromMlirShape(revisedInputShape)))
                             .getResult();

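    // The unit spatial dims of the 1x1 kernel carry no data, so dropping them
    // is a pure relayout into the [OC, IC] matrix fully_connected expects.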
    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        dyn_cast<RankedTensorType>(weight.getType()).getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getDenseI64ArrayAttr(
                                      convertFromMlirShape(revisedWeightShape)))
                              .getResult();

    // Perform a fully connected network over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined,
                                                      weightShape[0]};
    auto fullyConnectedShapeType =
        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());

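    // Forward any quantization info so the fully connected op keeps the
    // convolution's input and weight zero points.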
    Value fullyConnectedValue;
    if (op.getQuantizationInfo()) {
      fullyConnectedValue =
          rewriter
              .create<tosa::FullyConnectedOp>(
                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
                  reshapedWeight, op.getBias(), *op.getQuantizationInfo())
              .getResult();
    } else {
      fullyConnectedValue = rewriter
                                .create<tosa::FullyConnectedOp>(
                                    op.getLoc(), fullyConnectedShapeType,
                                    reshapedInput, reshapedWeight, op.getBias())
                                .getResult();
    }

    // Reshape output to [N, IH, IW, OC].
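    // With a 1x1 kernel, stride 1, and padding already folded into the input,
    // the output spatial dims equal the (padded) input spatial dims.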
    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                              inputShape[2], weightShape[0]};
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultType, fullyConnectedValue,
        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
    return success();
  }
};

} // namespace

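// Populate the pattern that rewrites a 1x1, stride-1 tosa.conv2d into a
// reshape -> fully_connected -> reshape sequence.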
void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  patterns.add<Conv2DIsFullyConnected>(ctx);
}