1//===- TosaTestPasses.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Test passes to exercise TOSA helper functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Matchers.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23#define PASS_NAME "tosa-test-quant-utils"
24
25using namespace mlir;
26using namespace mlir::tosa;
27
28// This transformation converts quantized uint8 to quantized int8. The
29// construction of the new type invokes buildQTypeFromMinMax. Extracted from
30// TOSA legalization infrastructure.
31struct ConvertTosaNegateOp : public RewritePattern {
32 explicit ConvertTosaNegateOp(MLIRContext *context)
33 : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
34 LogicalResult matchAndRewrite(Operation *op,
35 PatternRewriter &rewriter) const override;
36};
37
38LogicalResult
39ConvertTosaNegateOp::matchAndRewrite(Operation *op,
40 PatternRewriter &rewriter) const {
41
42 auto tosaNegateOp = cast<tosa::NegateOp>(op);
43
44 auto inputType =
45 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
46 // skip if input is not ranked tensor type
47 if (!inputType)
48 return failure();
49
50 // skip if it's not ranked tensor type.
51 auto outputType =
52 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
53 if (!outputType)
54 return failure();
55
56 // skip if output is not per-tensor quantized type.
57 auto outputElementType =
58 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
59 if (!outputElementType)
60 return failure();
61
62 // skip if output is not uint8.
63 if (outputElementType.isSigned() ||
64 outputElementType.getStorageTypeIntegralWidth() != 8)
65 return failure();
66
67 double typeRangeMin = double(outputElementType.getStorageTypeMin() -
68 outputElementType.getZeroPoint()) *
69 outputElementType.getScale();
70 double typeRangeMax = double(outputElementType.getStorageTypeMax() -
71 outputElementType.getZeroPoint()) *
72 outputElementType.getScale();
73 bool narrowRange = outputElementType.getStorageTypeMin() == 1;
74
75 auto dstQConstType = RankedTensorType::get(
76 outputType.getShape(),
77 buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
78 rewriter.getF64FloatAttr(typeRangeMin),
79 rewriter.getF64FloatAttr(typeRangeMax),
80 rewriter.getI32IntegerAttr(
81 outputElementType.getStorageTypeIntegralWidth()),
82 0, true /* signed */,
83 rewriter.getBoolAttr(narrowRange)));
84
85 ElementsAttr inputElems;
86 if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems)))
87 return failure();
88
89 auto newConstOp =
90 rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
91 auto newNegateOp = rewriter.create<tosa::NegateOp>(
92 op->getLoc(), dstQConstType, newConstOp.getResult());
93
94 rewriter.replaceOp(op, {newNegateOp.getResult()});
95 return success();
96}
97
98// This transformation modifies the quantized output of a test conv2d input and
99// appends a TOSA rescale after it. The rescale op requires the invocation of
100// computeMultiplierAndShift. From TOSA legalization infrastructure.
101struct ConvertTosaConv2DOp : public RewritePattern {
102 explicit ConvertTosaConv2DOp(MLIRContext *context)
103 : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
104 LogicalResult matchAndRewrite(Operation *op,
105 PatternRewriter &rewriter) const override;
106};
107
108LogicalResult
109ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
110 PatternRewriter &rewriter) const {
111
112 auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
113
114 auto inputType =
115 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());
116
117 // skip if input is not ranked tensor type
118 if (!inputType)
119 return failure();
120
121 auto weightType =
122 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());
123
124 // skip if wt is not ranked tensor type
125 if (!weightType)
126 return failure();
127
128 // skip if it's not ranked tensor type.
129 auto outputType =
130 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
131 if (!outputType)
132 return failure();
133
134 auto inputQType =
135 dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
136 auto weightQType =
137 dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
138 auto outputQType =
139 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
140
141 // Works on quantized type only.
142 if (!(inputQType && weightQType && outputQType))
143 return failure();
144
145 auto newTosaConv2DOpType =
146 RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
147
148 auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
149 op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
150 tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(),
151 tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(),
152 tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr());
153
154 // Create rescale to quantized type
155 double inputScale = inputQType.getScale();
156 double weightScale = weightQType.getScale();
157 double outputScale = outputQType.getScale();
158 int64_t outputZpVal = outputQType.getZeroPoint();
159
160 auto inputZp =
161 createZeroPointTensor(rewriter, op->getLoc(), newTosaConv2DOpType, 0);
162 auto outputZp = createZeroPointTensor(
163 rewriter, op->getLoc(), tosaConv2DOp.getOutput().getType(), outputZpVal);
164
165 if (!inputZp || !outputZp)
166 return failure();
167
168 double opTensorScale = (inputScale * weightScale) / outputScale;
169
170 int32_t multiplier;
171 int32_t shift;
172
173 // Obtain the quantized scale = multiplier and shift.
174 if (!computeMultiplierAndShift(scale: opTensorScale, multiplier, shift, scaleWidth: 32))
175 return failure();
176
177 bool inputUnsigned =
178 newTosaConv2DOp.getResult().getType().isUnsignedInteger();
179 bool outputUnsigned = outputType.isUnsignedInteger();
180
181 auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
182 op->getLoc(), outputType, newTosaConv2DOp.getResult(),
183 getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
184 getConstTensorInt<int8_t>(rewriter, op->getLoc(),
185 {static_cast<int8_t>(shift)}),
186 inputZp.value(), outputZp.value(),
187 /* scale32 = */ rewriter.getBoolAttr(true),
188 /* double_round = */ rewriter.getStringAttr("DOUBLE_ROUND"),
189 /* per_channel = */ rewriter.getBoolAttr(false),
190 rewriter.getBoolAttr(inputUnsigned),
191 rewriter.getBoolAttr(outputUnsigned));
192
193 rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
194 return success();
195}
196
197namespace {
198
199struct TosaTestQuantUtilAPI
200 : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> {
201 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI)
202
203 StringRef getArgument() const final { return PASS_NAME; }
204 StringRef getDescription() const final {
205 return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
206 }
207 void runOnOperation() override;
208};
209
210void TosaTestQuantUtilAPI::runOnOperation() {
211 auto *ctx = &getContext();
212 RewritePatternSet patterns(ctx);
213 auto func = getOperation();
214
215 patterns.add<ConvertTosaNegateOp>(ctx);
216 patterns.add<ConvertTosaConv2DOp>(ctx);
217 (void)applyPatternsGreedily(func, std::move(patterns));
218}
219
220} // namespace
221
222namespace mlir {
223void registerTosaTestQuantUtilAPIPass() {
224 PassRegistration<TosaTestQuantUtilAPI>();
225}
226} // namespace mlir
227

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp