1//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "QuantDialectBytecode.h"
10#include "TypeDetail.h"
11
12#include "mlir/Dialect/Quant/IR/Quant.h"
13#include "mlir/Dialect/Quant/IR/QuantTypes.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/IR/TypeUtilities.h"
17
18#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
19
20namespace mlir {
21namespace quant {
22
23namespace {
24
25// Verify the integrity of per-axis quantization information, if present.
26//
27// - uniformQuantizedPerAxisType
28// A quantized type with per-axis quantization.
29//
30// - containerType
31// Original input or result type of the operation using the provided quantized
32// type. Used to ensure that the quantized type appears within a tensor and
33// that the tensor is compatible with per-axis quantization information.
34//
35LogicalResult verifyPerAxisQuantization(
36 Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
37 Type containerType) {
38 auto tensorType = dyn_cast<TensorType>(containerType);
39 if (!tensorType)
40 return op->emitError(message: "scalar types may not use per-axis quantization");
41
42 if (!tensorType.hasRank())
43 return success();
44
45 int32_t quantizedDimension =
46 uniformQuantizedPerAxisType.getQuantizedDimension();
47 if ((int64_t)quantizedDimension >= tensorType.getRank())
48 return op->emitError(message: "quantized dimension must be less than tensor rank");
49
50 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
51 if (quantizedDimensionSize != ShapedType::kDynamic &&
52 quantizedDimensionSize !=
53 (int64_t)uniformQuantizedPerAxisType.getScales().size())
54 return op->emitError(
55 message: "quantized dimension size does not match number of scales");
56
57 return success();
58}
59
60// Verifies that the sub-channel quantization parameters are consistent with
61// the given container type. The function checks the following:
62//
63// - The container type must be a ranked tensor type.
64// - Each quantized dimension must be less than the rank of the tensor.
65// - The size of each dimension at the quantized dimension must be divisible
66// by the corresponding block size.
67// - The scale dimension size at each axis index should match the tensor
68// dimension at the index divided by the corresponding block size.
69//
70// The `uniformQuantizedSubChannelType` argument provides the sub-channel
71// quantization parameters, and the `containerType` argument specifies the
72// type of the container holding the quantized data.
73//
74LogicalResult verifySubChannelQuantization(
75 Operation *op,
76 UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
77 Type containerType) {
78 auto tensorType = dyn_cast<TensorType>(containerType);
79 if (!tensorType)
80 return op->emitError(message: "scalar types may not use sub-channel quantization");
81
82 if (!tensorType.hasRank())
83 return op->emitError(
84 message: "tensor containing the sub-channel quantized type must be ranked");
85
86 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
87 uniformQuantizedSubChannelType.getBlockSizeInfo();
88 auto shape = tensorType.getShape();
89
90 // The dimension size of scale for an axis which is not specified as quantized
91 // dimension should be 1.
92 SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
93 for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
94 if (quantizedDimension >= tensorType.getRank())
95 return op->emitError()
96 << "quantized dimension " << quantizedDimension
97 << " must be less than tensor rank " << tensorType.getRank();
98 if (!tensorType.isDynamicDim(quantizedDimension) &&
99 tensorType.getDimSize(quantizedDimension) % blockSize != 0)
100 return op->emitError()
101 << "tensor dimension size "
102 << tensorType.getDimSize(quantizedDimension) << " at axis "
103 << quantizedDimension
104 << " must be divisible by the corresponding block size "
105 << blockSize;
106 if (tensorType.isDynamicDim(quantizedDimension))
107 expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
108 else
109 expectedScaleShape[quantizedDimension] =
110 tensorType.getDimSize(quantizedDimension) / blockSize;
111 }
112
113 // Block sizes must be greater than 0 and divide the corresponding dimension
114 // size. While a block size b must be less than or equal to the corresponding
115 // dimension size d, this constraint is implicitly enforced by requiring that
116 // d % b == 0 when d != 0.
117 //
118 // However, a problem arises when d = 0. The divisibility constraint allows b
119 // to be any value, potentially violating the requirement that b <= d.
120 // Furthermore, if b is unspecified (implicitly equal to d), it violates the
121 // constraint that b > 0.
122 //
123 // Therefore, we explicitly disallow the case where d = 0 to maintain
124 // consistency and avoid these issues.
125 if (llvm::is_contained(tensorType.getShape(), 0)) {
126 return op->emitError() << "tensor dimension size of zero is not allowed "
127 "with sub-channel quantization";
128 }
129
130 auto scaleShape =
131 uniformQuantizedSubChannelType.getScales().getType().getShape();
132 if (scaleShape.size() != shape.size()) {
133 return op->emitError() << "Rank of scales " << scaleShape.size()
134 << " must match "
135 << "the rank of the tensor " << shape.size();
136 }
137
138 for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
139 if (expectedScaleShape[index] != ShapedType::kDynamic &&
140 expectedScaleShape[index] != scaleShape[index])
141 return op->emitError() << "dimension size " << scaleDim
142 << " of scales tensor at axis " << index
143 << " should match (tensor dimension at axis / "
144 "block sizes at axis) = "
145 << expectedScaleShape[index];
146 }
147
148 return success();
149}
150
151// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
152//
153// - quantizedType
154// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
155// whether as a primitive type or in a tensor.
156//
157// - floatType
158// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
159// whether as a primitive type or in a tensor.
160//
161// - containerType
162// Type of original input or result.
163//
164LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
165 FloatType floatType, Type containerType) {
166 if (quantizedType.getExpressedType() != floatType)
167 return op->emitError(
168 message: "expressed type in quantized type expected to match float type");
169
170 // Verify integrity of per-axis quantization information, if present.
171 if (auto quantizedPerAxisType =
172 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
173 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
174 }
175
176 if (auto quantizedSubChannelType =
177 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
178 return verifySubChannelQuantization(op, quantizedSubChannelType,
179 containerType);
180 }
181
182 // At this point the type is UniformQuantizedType
183 return success();
184}
185
186} // namespace
187
188//===----------------------------------------------------------------------===//
189// Dialect
190//===----------------------------------------------------------------------===//
191
// Register the dialect's types, ops, and bytecode support.
void QuantDialect::initialize() {
  // All quantized types implemented by this dialect.
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
  // Ops generated from the dialect's ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  // Bytecode (de)serialization support for the dialect's custom types.
  detail::addBytecodeInterface(this);
}
201
202//===----------------------------------------------------------------------===//
203// DequantizeCastOp
204//===----------------------------------------------------------------------===//
205
206LogicalResult DequantizeCastOp::verify() {
207 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
208 getInput().getType());
209}
210
211OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
212 // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
213 // with the value of x. Values x and y are guaranteed to be of the same type
214 // in this pattern.
215 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
216 if (!srcQcastOp)
217 return {};
218 assert(srcQcastOp.getInput().getType() == getType());
219 return srcQcastOp.getInput();
220}
221
222FloatType DequantizeCastOp::getFloatType() {
223 return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
224}
225
226QuantizedType DequantizeCastOp::getQuantizedType() {
227 return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
228}
229
230//===----------------------------------------------------------------------===//
231// QuantizeCastOp
232//===----------------------------------------------------------------------===//
233
234LogicalResult QuantizeCastOp::verify() {
235 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
236 getInput().getType());
237}
238
239OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
240 // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
241 // with the value of x if the casts invert each other. Contrary to the folding
242 // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
243 // x and y are not guaranteed to be of the same type here, as they may use
244 // different quantization parameters.
245 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
246 if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
247 return {};
248 return srcDcastOp.getInput();
249}
250
251FloatType QuantizeCastOp::getFloatType() {
252 return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
253}
254
255QuantizedType QuantizeCastOp::getQuantizedType() {
256 return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
257}
258
259//===----------------------------------------------------------------------===//
260// StorageCastOp
261//===----------------------------------------------------------------------===//
262
263LogicalResult StorageCastOp::verify() {
264 auto quantizedType = getQuantizedType();
265 auto integerType = getIntegerType();
266 if (quantizedType.getStorageType() != integerType)
267 return emitError(
268 "storage type in quantized type expected to match integer type");
269
270 // Verify integrity of per-axis quantization information, if available. While
271 // the quantization type may appear in the input or the result, their tensor
272 // shapes are guaranteed to be identical at this point.
273 if (auto quantizedPerAxisType =
274 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
275 return verifyPerAxisQuantization(*this, quantizedPerAxisType,
276 getInput().getType());
277 }
278
279 if (auto quantizedSunChannelType =
280 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
281 return verifySubChannelQuantization(*this, quantizedSunChannelType,
282 getInput().getType());
283 }
284
285 // At this point the type is UniformQuantizedType
286 return success();
287}
288
289OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
290 // Matches x -> quant.scast -> quant.scast -> y, replacing the second
291 // quant.scast with the value of x if the casts invert each other.
292 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
293 if (!srcScastOp || srcScastOp.getInput().getType() != getType())
294 return {};
295 return srcScastOp.getInput();
296}
297
298IntegerType StorageCastOp::getIntegerType() {
299 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
300 if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
301 return integerType;
302
303 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
304 return cast<IntegerType>(resultScalarType);
305}
306
307QuantizedType StorageCastOp::getQuantizedType() {
308 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
309 if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
310 return quantizedType;
311
312 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
313 return cast<QuantizedType>(resultScalarType);
314}
315
316} // namespace quant
317} // namespace mlir
318
319#define GET_OP_CLASSES
320#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
321

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Quant/IR/QuantOps.cpp