//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

23SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24 return to_vector(Range: llvm::map_range(C&: shape, F: [](int64_t dim) {
25 return ShapedType::isDynamic(dim) ? -1 : dim;
26 }));
27}
29struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
30 explicit Conv2DIsFullyConnected(MLIRContext *context)
31 : OpRewritePattern(context) {}
32
33 LogicalResult matchAndRewrite(tosa::Conv2DOp op,
34 PatternRewriter &rewriter) const override {
35 Value input = op.getInput();
36 Value weight = op.getWeight();
37 ShapedType inputType = cast<ShapedType>(input.getType());
38 ShapedType weightType = cast<ShapedType>(weight.getType());
39 ShapedType resultType = cast<ShapedType>(op.getType());
40
41 auto numDynamic =
42 llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
43 if (numDynamic > 1)
44 return rewriter.notifyMatchFailure(
45 op, "at most one dim in input may be dynamic");
46 if (!weightType.hasRank())
47 return rewriter.notifyMatchFailure(op, "unranked weight input");
48
49 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
50 return failure();
51
52 // Only works for a 1x1 kernel.
53 ArrayRef<int64_t> weightShape = weightType.getShape();
54 if (weightShape[1] != 1 || weightShape[2] != 1)
55 return failure();
56
57 llvm::ArrayRef<int64_t> padAttr = op.getPad();
58 llvm::SmallVector<int64_t> pad(8, 0);
59 for (const auto &it : llvm::enumerate(padAttr))
60 pad[it.index() + 2] = it.value();
61
62 if (llvm::any_of(Range&: pad, P: [](int64_t p) { return p != 0; })) {
63 Type inputETy = inputType.getElementType();
64 Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
65 if (op.getQuantizationInfo()) {
66 auto quantizationInfo = op.getQuantizationInfo();
67 int64_t iZp = quantizationInfo->getInputZp();
68
69 if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
70 return rewriter.notifyMatchFailure(
71 op, "tosa.conv op quantization has zp outside of input range");
72
73 zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
74 }
75
76 llvm::SmallVector<int64_t> newShape(inputType.getShape());
77
78 for (int i = 0, s = newShape.size(); i < s; ++i) {
79 if (newShape[i] != ShapedType::kDynamic) {
80 newShape[i] += pad[i * 2] + pad[i * 2 + 1];
81 }
82 }
83
84 auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
85 auto padSize =
86 DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
87 Value padSizeVal =
88 rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
89
90 auto padTy = RankedTensorType::get({}, inputETy);
91 auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
92 Value padVal =
93 rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
94 inputType = RankedTensorType::get(newShape, inputETy);
95 input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
96 padSizeVal, padVal);
97 }
98
99 // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
100 ArrayRef<int64_t> inputShape = inputType.getShape();
101 int64_t combined = ShapedType::kDynamic;
102 if (numDynamic == 0)
103 combined = inputShape[0] * inputShape[1] * inputShape[2];
104 llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
105 auto revisedInputShapeType =
106 RankedTensorType::get(revisedInputShape, inputType.getElementType());
107 auto reshapedInput = rewriter
108 .create<tosa::ReshapeOp>(
109 op.getLoc(), revisedInputShapeType, input,
110 rewriter.getDenseI64ArrayAttr(
111 convertFromMlirShape(revisedInputShape)))
112 .getResult();
113
114 // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
115 llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
116 weightShape[3]};
117 auto revisedWeightShapeType = RankedTensorType::get(
118 revisedWeightShape,
119 dyn_cast<RankedTensorType>(weight.getType()).getElementType());
120 auto reshapedWeight = rewriter
121 .create<tosa::ReshapeOp>(
122 op.getLoc(), revisedWeightShapeType, weight,
123 rewriter.getDenseI64ArrayAttr(
124 convertFromMlirShape(revisedWeightShape)))
125 .getResult();
126
127 // Perform a fully connected network over the reshaped input and weight.
128 llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
129 auto fullyConnectedShapeType =
130 RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
131
132 Value fullyConnectedValue;
133 if (op.getQuantizationInfo()) {
134 fullyConnectedValue =
135 rewriter
136 .create<tosa::FullyConnectedOp>(
137 op.getLoc(), fullyConnectedShapeType, reshapedInput,
138 reshapedWeight, op.getBias(), *op.getQuantizationInfo())
139 .getResult();
140 } else {
141 fullyConnectedValue = rewriter
142 .create<tosa::FullyConnectedOp>(
143 op.getLoc(), fullyConnectedShapeType,
144 reshapedInput, reshapedWeight, op.getBias())
145 .getResult();
146 }
147
148 // Reshape output to [N, IH, IW, OC].
149 llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
150 inputShape[2], weightShape[0]};
151 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
152 op, resultType, fullyConnectedValue,
153 rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
154 return success();
155 }
156};

} // namespace

160void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
161 RewritePatternSet &patterns) {
162 patterns.add<Conv2DIsFullyConnected>(arg&: ctx);
163}