//===- TosaTestPasses.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test passes to exercise TOSA helper functions.
//
//===----------------------------------------------------------------------===//

13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Matchers.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23#define PASS_NAME "tosa-test-quant-utils"
24
25using namespace mlir;
26using namespace mlir::tosa;
27
28// This transformation converts quantized uint8 to quantized int8. The
29// construction of the new type invokes buildQTypeFromMinMax. Extracted from
30// TOSA legalization infrastructure.
31struct ConvertTosaNegateOp : public RewritePattern {
32 explicit ConvertTosaNegateOp(MLIRContext *context)
33 : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
34 LogicalResult matchAndRewrite(Operation *op,
35 PatternRewriter &rewriter) const override;
36};
37
38LogicalResult
39ConvertTosaNegateOp::matchAndRewrite(Operation *op,
40 PatternRewriter &rewriter) const {
41
42 auto tosaNegateOp = cast<tosa::NegateOp>(Val: op);
43
44 auto inputType =
45 dyn_cast<mlir::RankedTensorType>(Val: tosaNegateOp.getInput1().getType());
46 // skip if input is not ranked tensor type
47 if (!inputType)
48 return failure();
49
50 // skip if it's not ranked tensor type.
51 auto outputType =
52 dyn_cast<mlir::RankedTensorType>(Val: tosaNegateOp.getResult().getType());
53 if (!outputType)
54 return failure();
55
56 // skip if output is not per-tensor quantized type.
57 auto outputElementType =
58 dyn_cast<mlir::quant::UniformQuantizedType>(Val: outputType.getElementType());
59 if (!outputElementType)
60 return failure();
61
62 // skip if output is not uint8.
63 if (outputElementType.isSigned() ||
64 outputElementType.getStorageTypeIntegralWidth() != 8)
65 return failure();
66
67 double typeRangeMin = double(outputElementType.getStorageTypeMin() -
68 outputElementType.getZeroPoint()) *
69 outputElementType.getScale();
70 double typeRangeMax = double(outputElementType.getStorageTypeMax() -
71 outputElementType.getZeroPoint()) *
72 outputElementType.getScale();
73 bool narrowRange = outputElementType.getStorageTypeMin() == 1;
74
75 auto dstQConstType = RankedTensorType::get(
76 shape: outputType.getShape(),
77 elementType: buildQTypeFromMinMax(builder: rewriter, inputDType: outputElementType.getExpressedType(),
78 minAttr: rewriter.getF64FloatAttr(value: typeRangeMin),
79 maxAttr: rewriter.getF64FloatAttr(value: typeRangeMax),
80 quantBits: rewriter.getI32IntegerAttr(
81 value: outputElementType.getStorageTypeIntegralWidth()),
82 filterQuantDim: 0, isSigned: true /* signed */,
83 narrowRange: rewriter.getBoolAttr(value: narrowRange)));
84
85 ElementsAttr inputElems;
86 if (!matchPattern(value: tosaNegateOp.getInput1(), pattern: m_Constant(bind_value: &inputElems)))
87 return failure();
88
89 auto newConstOp =
90 rewriter.create<tosa::ConstOp>(location: op->getLoc(), args&: dstQConstType, args&: inputElems);
91 auto newNegateOp = rewriter.create<tosa::NegateOp>(
92 location: op->getLoc(), args&: dstQConstType, args: newConstOp.getResult());
93
94 rewriter.replaceOp(op, newValues: {newNegateOp.getResult()});
95 return success();
96}
97
98// This transformation modifies the quantized output of a test conv2d input and
99// appends a TOSA rescale after it. The rescale op requires the invocation of
100// computeMultiplierAndShift. From TOSA legalization infrastructure.
101struct ConvertTosaConv2DOp : public RewritePattern {
102 explicit ConvertTosaConv2DOp(MLIRContext *context)
103 : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
104 LogicalResult matchAndRewrite(Operation *op,
105 PatternRewriter &rewriter) const override;
106};
107
108LogicalResult
109ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
110 PatternRewriter &rewriter) const {
111
112 auto tosaConv2DOp = cast<tosa::Conv2DOp>(Val: op);
113
114 auto inputType =
115 dyn_cast<mlir::RankedTensorType>(Val: tosaConv2DOp.getInput().getType());
116
117 // skip if input is not ranked tensor type
118 if (!inputType)
119 return failure();
120
121 auto weightType =
122 dyn_cast<mlir::RankedTensorType>(Val: tosaConv2DOp.getWeight().getType());
123
124 // skip if wt is not ranked tensor type
125 if (!weightType)
126 return failure();
127
128 // skip if it's not ranked tensor type.
129 auto outputType =
130 dyn_cast<mlir::RankedTensorType>(Val: tosaConv2DOp.getResult().getType());
131 if (!outputType)
132 return failure();
133
134 auto inputQType =
135 dyn_cast<mlir::quant::UniformQuantizedType>(Val: inputType.getElementType());
136 auto weightQType =
137 dyn_cast<mlir::quant::UniformQuantizedType>(Val: weightType.getElementType());
138 auto outputQType =
139 dyn_cast<mlir::quant::UniformQuantizedType>(Val: outputType.getElementType());
140
141 // Works on quantized type only.
142 if (!(inputQType && weightQType && outputQType))
143 return failure();
144
145 auto newTosaConv2DOpType =
146 RankedTensorType::get(shape: outputType.getShape(), elementType: rewriter.getIntegerType(width: 32));
147
148 auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
149 location: op->getLoc(), args&: newTosaConv2DOpType, args: tosaConv2DOp.getInput(),
150 args: tosaConv2DOp.getWeight(), args: tosaConv2DOp.getBias(),
151 args: tosaConv2DOp.getPadAttr(), args: tosaConv2DOp.getStrideAttr(),
152 args: tosaConv2DOp.getDilationAttr(), args: tosaConv2DOp.getAccTypeAttr());
153
154 // Create rescale to quantized type
155 double inputScale = inputQType.getScale();
156 double weightScale = weightQType.getScale();
157 double outputScale = outputQType.getScale();
158 int64_t outputZpVal = outputQType.getZeroPoint();
159
160 auto inputZp =
161 createZeroPointTensor(builder&: rewriter, loc: op->getLoc(), srcElemType: newTosaConv2DOpType, zp: 0);
162 auto outputZp = createZeroPointTensor(
163 builder&: rewriter, loc: op->getLoc(), srcElemType: tosaConv2DOp.getOutput().getType(), zp: outputZpVal);
164
165 if (!inputZp || !outputZp)
166 return failure();
167
168 double opTensorScale = (inputScale * weightScale) / outputScale;
169
170 int32_t multiplier;
171 int32_t shift;
172
173 // Obtain the quantized scale = multiplier and shift.
174 if (!computeMultiplierAndShift(scale: opTensorScale, multiplier, shift, scaleWidth: 32))
175 return failure();
176
177 bool inputUnsigned =
178 newTosaConv2DOp.getResult().getType().isUnsignedInteger();
179 bool outputUnsigned = outputType.isUnsignedInteger();
180
181 auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
182 location: op->getLoc(), args&: outputType, args: newTosaConv2DOp.getResult(),
183 args: getConstTensorInt<int32_t>(builder&: rewriter, loc: op->getLoc(), vec: {multiplier}),
184 args: getConstTensorInt<int8_t>(builder&: rewriter, loc: op->getLoc(),
185 vec: {static_cast<int8_t>(shift)}),
186 args&: inputZp.value(), args&: outputZp.value(),
187 /* scale32 = */ args: rewriter.getBoolAttr(value: true),
188 /* double_round = */ args: rewriter.getStringAttr(bytes: "DOUBLE_ROUND"),
189 /* per_channel = */ args: rewriter.getBoolAttr(value: false),
190 args: rewriter.getBoolAttr(value: inputUnsigned),
191 args: rewriter.getBoolAttr(value: outputUnsigned));
192
193 rewriter.replaceOp(op, newValues: {newTosaRescaleOp.getResult()});
194 return success();
195}
196
197namespace {
198
199struct TosaTestQuantUtilAPI
200 : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> {
201 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI)
202
203 StringRef getArgument() const final { return PASS_NAME; }
204 StringRef getDescription() const final {
205 return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
206 }
207 void runOnOperation() override;
208};
209
210void TosaTestQuantUtilAPI::runOnOperation() {
211 auto *ctx = &getContext();
212 RewritePatternSet patterns(ctx);
213 auto func = getOperation();
214
215 patterns.add<ConvertTosaNegateOp>(arg&: ctx);
216 patterns.add<ConvertTosaConv2DOp>(arg&: ctx);
217 (void)applyPatternsGreedily(op: func, patterns: std::move(patterns));
218}
219
220} // namespace
221
222namespace mlir {
223void registerTosaTestQuantUtilAPIPass() {
224 PassRegistration<TosaTestQuantUtilAPI>();
225}
226} // namespace mlir
227

// Source: mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp