//===- TosaTestPasses.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test passes to exercise TOSA helper functions.
//
//===----------------------------------------------------------------------===//
12 | |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
17 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/Matchers.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | |
23 | #define PASS_NAME "tosa-test-quant-utils" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::tosa; |
27 | |
28 | // This transformation converts quantized uint8 to quantized int8. The |
29 | // construction of the new type invokes buildQTypeFromMinMax. Extracted from |
30 | // TOSA legalization infrastructure. |
31 | struct ConvertTosaNegateOp : public RewritePattern { |
32 | explicit ConvertTosaNegateOp(MLIRContext *context) |
33 | : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} |
34 | LogicalResult matchAndRewrite(Operation *op, |
35 | PatternRewriter &rewriter) const override; |
36 | }; |
37 | |
38 | LogicalResult |
39 | ConvertTosaNegateOp::matchAndRewrite(Operation *op, |
40 | PatternRewriter &rewriter) const { |
41 | |
42 | auto tosaNegateOp = cast<tosa::NegateOp>(op); |
43 | |
44 | auto inputType = |
45 | dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType()); |
46 | // skip if input is not ranked tensor type |
47 | if (!inputType) |
48 | return failure(); |
49 | |
50 | // skip if it's not ranked tensor type. |
51 | auto outputType = |
52 | dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType()); |
53 | if (!outputType) |
54 | return failure(); |
55 | |
56 | // skip if output is not per-tensor quantized type. |
57 | auto outputElementType = |
58 | dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType()); |
59 | if (!outputElementType) |
60 | return failure(); |
61 | |
62 | // skip if output is not uint8. |
63 | if (outputElementType.isSigned() || |
64 | outputElementType.getStorageTypeIntegralWidth() != 8) |
65 | return failure(); |
66 | |
67 | double typeRangeMin = double(outputElementType.getStorageTypeMin() - |
68 | outputElementType.getZeroPoint()) * |
69 | outputElementType.getScale(); |
70 | double typeRangeMax = double(outputElementType.getStorageTypeMax() - |
71 | outputElementType.getZeroPoint()) * |
72 | outputElementType.getScale(); |
73 | bool narrowRange = outputElementType.getStorageTypeMin() == 1; |
74 | |
75 | auto dstQConstType = RankedTensorType::get( |
76 | outputType.getShape(), |
77 | buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), |
78 | rewriter.getF64FloatAttr(typeRangeMin), |
79 | rewriter.getF64FloatAttr(typeRangeMax), |
80 | rewriter.getI32IntegerAttr( |
81 | outputElementType.getStorageTypeIntegralWidth()), |
82 | 0, true /* signed */, |
83 | rewriter.getBoolAttr(narrowRange))); |
84 | |
85 | ElementsAttr inputElems; |
86 | if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems))) |
87 | return failure(); |
88 | |
89 | auto newConstOp = |
90 | rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems); |
91 | auto newNegateOp = rewriter.create<tosa::NegateOp>( |
92 | op->getLoc(), dstQConstType, newConstOp.getResult()); |
93 | |
94 | rewriter.replaceOp(op, {newNegateOp.getResult()}); |
95 | return success(); |
96 | } |
97 | |
98 | // This transformation modifies the quantized output of a test conv2d input and |
99 | // appends a TOSA rescale after it. The rescale op requires the invocation of |
100 | // computeMultiplierAndShift. From TOSA legalization infrastructure. |
101 | struct ConvertTosaConv2DOp : public RewritePattern { |
102 | explicit ConvertTosaConv2DOp(MLIRContext *context) |
103 | : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} |
104 | LogicalResult matchAndRewrite(Operation *op, |
105 | PatternRewriter &rewriter) const override; |
106 | }; |
107 | |
108 | LogicalResult |
109 | ConvertTosaConv2DOp::matchAndRewrite(Operation *op, |
110 | PatternRewriter &rewriter) const { |
111 | |
112 | auto tosaConv2DOp = cast<tosa::Conv2DOp>(op); |
113 | |
114 | auto inputType = |
115 | dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType()); |
116 | |
117 | // skip if input is not ranked tensor type |
118 | if (!inputType) |
119 | return failure(); |
120 | |
121 | auto weightType = |
122 | dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType()); |
123 | |
124 | // skip if wt is not ranked tensor type |
125 | if (!weightType) |
126 | return failure(); |
127 | |
128 | // skip if it's not ranked tensor type. |
129 | auto outputType = |
130 | dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType()); |
131 | if (!outputType) |
132 | return failure(); |
133 | |
134 | auto inputQType = |
135 | dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType()); |
136 | auto weightQType = |
137 | dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType()); |
138 | auto outputQType = |
139 | dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType()); |
140 | |
141 | // Works on quantized type only. |
142 | if (!(inputQType && weightQType && outputQType)) |
143 | return failure(); |
144 | |
145 | auto newTosaConv2DOpType = |
146 | RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); |
147 | |
148 | auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>( |
149 | op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), |
150 | tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), |
151 | tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), |
152 | tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); |
153 | |
154 | // Create rescale to quantized type |
155 | double inputScale = inputQType.getScale(); |
156 | double weightScale = weightQType.getScale(); |
157 | double outputScale = outputQType.getScale(); |
158 | int64_t outputZpVal = outputQType.getZeroPoint(); |
159 | |
160 | auto inputZp = |
161 | createZeroPointTensor(rewriter, op->getLoc(), newTosaConv2DOpType, 0); |
162 | auto outputZp = createZeroPointTensor( |
163 | rewriter, op->getLoc(), tosaConv2DOp.getOutput().getType(), outputZpVal); |
164 | |
165 | if (!inputZp || !outputZp) |
166 | return failure(); |
167 | |
168 | double opTensorScale = (inputScale * weightScale) / outputScale; |
169 | |
170 | int32_t multiplier; |
171 | int32_t shift; |
172 | |
173 | // Obtain the quantized scale = multiplier and shift. |
174 | if (!computeMultiplierAndShift(scale: opTensorScale, multiplier, shift, scaleWidth: 32)) |
175 | return failure(); |
176 | |
177 | bool inputUnsigned = |
178 | newTosaConv2DOp.getResult().getType().isUnsignedInteger(); |
179 | bool outputUnsigned = outputType.isUnsignedInteger(); |
180 | |
181 | auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>( |
182 | op->getLoc(), outputType, newTosaConv2DOp.getResult(), |
183 | getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}), |
184 | getConstTensorInt<int8_t>(rewriter, op->getLoc(), |
185 | {static_cast<int8_t>(shift)}), |
186 | inputZp.value(), outputZp.value(), |
187 | /* scale32 = */ rewriter.getBoolAttr(true), |
188 | /* double_round = */ rewriter.getStringAttr("DOUBLE_ROUND" ), |
189 | /* per_channel = */ rewriter.getBoolAttr(false), |
190 | rewriter.getBoolAttr(inputUnsigned), |
191 | rewriter.getBoolAttr(outputUnsigned)); |
192 | |
193 | rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); |
194 | return success(); |
195 | } |
196 | |
197 | namespace { |
198 | |
199 | struct TosaTestQuantUtilAPI |
200 | : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> { |
201 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI) |
202 | |
203 | StringRef getArgument() const final { return PASS_NAME; } |
204 | StringRef getDescription() const final { |
205 | return "TOSA Test: Exercise the APIs in QuantUtils.cpp." ; |
206 | } |
207 | void runOnOperation() override; |
208 | }; |
209 | |
210 | void TosaTestQuantUtilAPI::runOnOperation() { |
211 | auto *ctx = &getContext(); |
212 | RewritePatternSet patterns(ctx); |
213 | auto func = getOperation(); |
214 | |
215 | patterns.add<ConvertTosaNegateOp>(ctx); |
216 | patterns.add<ConvertTosaConv2DOp>(ctx); |
217 | (void)applyPatternsGreedily(func, std::move(patterns)); |
218 | } |
219 | |
220 | } // namespace |
221 | |
222 | namespace mlir { |
223 | void registerTosaTestQuantUtilAPIPass() { |
224 | PassRegistration<TosaTestQuantUtilAPI>(); |
225 | } |
226 | } // namespace mlir |
227 | |