//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "QuantDialectBytecode.h"
#include "TypeDetail.h"

#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"

namespace mlir {
namespace quant {

namespace {

// Verify the integrity of per-axis quantization information, if present.
//
// - uniformQuantizedPerAxisType
//   A quantized type with per-axis quantization.
//
// - containerType
//   Original input or result type of the operation using the provided quantized
//   type. Used to ensure that the quantized type appears within a tensor and
//   that the tensor is compatible with per-axis quantization information.
//
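// For example (illustrative only), the per-axis type
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// used with a tensor of shape 4x2 is consistent: the quantized dimension (1)
// is less than the tensor rank (2), and the number of scales (2) matches the
// size of dimension 1.
//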
LogicalResult verifyPerAxisQuantization(
    Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
    Type containerType) {
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use per-axis quantization");

  if (!tensorType.hasRank())
    return success();

  int32_t quantizedDimension =
      uniformQuantizedPerAxisType.getQuantizedDimension();
  if ((int64_t)quantizedDimension >= tensorType.getRank())
    return op->emitError("quantized dimension must be less than tensor rank");

  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
  if (quantizedDimensionSize != ShapedType::kDynamic &&
      quantizedDimensionSize !=
          (int64_t)uniformQuantizedPerAxisType.getScales().size())
    return op->emitError(
        "quantized dimension size does not match number of scales");

  return success();
}

// Verifies that the sub-channel quantization parameters are consistent with
// the given container type. The function checks the following:
//
// - The container type must be a ranked tensor type.
// - Each quantized dimension must be less than the rank of the tensor.
// - The tensor dimension size at each quantized dimension must be divisible
//   by the corresponding block size.
// - The size of the scales tensor at each quantized dimension must match the
//   tensor dimension at that index divided by the corresponding block size.
//
// The `uniformQuantizedSubChannelType` argument provides the sub-channel
// quantization parameters, and the `containerType` argument specifies the
// type of the container holding the quantized data.
//
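// For example (illustrative numbers only): for a tensor of shape 6x8 with
// block sizes {0: 3, 1: 2}, the expected shape of the scales tensor is 2x4
// (that is, 6/3 x 8/2).
//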
LogicalResult verifySubChannelQuantization(
    Operation *op,
    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
    Type containerType) {
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use sub-channel quantization");

  if (!tensorType.hasRank())
    return op->emitError(
        "tensor containing the sub-channel quantized type must be ranked");

  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      uniformQuantizedSubChannelType.getBlockSizeInfo();
  auto shape = tensorType.getShape();

  // The size of the scales tensor along any axis that is not a quantized
  // dimension should be 1.
  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    if (quantizedDimension >= tensorType.getRank())
      return op->emitError()
             << "quantized dimension " << quantizedDimension
             << " must be less than tensor rank " << tensorType.getRank();
    if (!tensorType.isDynamicDim(quantizedDimension) &&
        tensorType.getDimSize(quantizedDimension) % blockSize != 0)
      return op->emitError()
             << "tensor dimension size "
             << tensorType.getDimSize(quantizedDimension) << " at axis "
             << quantizedDimension
             << " must be divisible by the corresponding block size "
             << blockSize;
    if (tensorType.isDynamicDim(quantizedDimension))
      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
    else
      expectedScaleShape[quantizedDimension] =
          tensorType.getDimSize(quantizedDimension) / blockSize;
  }

  // Block sizes must be greater than 0 and divide the corresponding dimension
  // size. While a block size b must be less than or equal to the corresponding
  // dimension size d, this constraint is implicitly enforced by requiring that
  // d % b == 0 when d != 0.
  //
  // However, a problem arises when d = 0. The divisibility constraint allows b
  // to be any value, potentially violating the requirement that b <= d.
  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
  // constraint that b > 0.
  //
  // Therefore, we explicitly disallow the case where d = 0 to maintain
  // consistency and avoid these issues.
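  // For illustration: with d = 8 and b = 4, both d % b == 0 and b <= d hold;
  // with d = 0, any b (even b > d) would pass the divisibility check.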
  if (llvm::is_contained(tensorType.getShape(), 0)) {
    return op->emitError() << "tensor dimension size of zero is not allowed "
                              "with sub-channel quantization";
  }

  auto scaleShape =
      uniformQuantizedSubChannelType.getScales().getType().getShape();
  if (scaleShape.size() != shape.size()) {
    return op->emitError() << "rank of scales " << scaleShape.size()
                           << " must match the rank of the tensor "
                           << shape.size();
  }

  for (auto [index, scaleDim] : llvm::enumerate(scaleShape)) {
    if (expectedScaleShape[index] != ShapedType::kDynamic &&
        expectedScaleShape[index] != scaleDim)
      return op->emitError() << "dimension size " << scaleDim
                             << " of scales tensor at axis " << index
                             << " should match (tensor dimension at axis / "
                                "block size at axis) = "
                             << expectedScaleShape[index];
  }

  return success();
}

// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
//
// - quantizedType
//   Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
//   whether as a primitive type or in a tensor.
//
// - floatType
//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
//   whether as a primitive type or in a tensor.
//
// - containerType
//   Type of original input or result.
//
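// For example (illustrative IR), in
//
//   %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!quant.uniform<i8:f32, 1.0>>
//
// quantizedType is !quant.uniform<i8:f32, 1.0>, floatType is f32, and
// containerType is tensor<4xf32>.
//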
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
                                   FloatType floatType, Type containerType) {
  if (quantizedType.getExpressedType() != floatType)
    return op->emitError(
        "expressed type in quantized type expected to match float type");

  // Verify integrity of per-axis quantization information, if present.
  if (auto quantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
    return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
  }

  if (auto quantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
    return verifySubChannelQuantization(op, quantizedSubChannelType,
                                        containerType);
  }

  // At this point the type is UniformQuantizedType
  return success();
}

} // namespace

//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

void QuantDialect::initialize() {
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  detail::addBytecodeInterface(this);
}

//===----------------------------------------------------------------------===//
// DequantizeCastOp
//===----------------------------------------------------------------------===//

LogicalResult DequantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getInput().getType());
}

OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
  // with the value of x. Values x and y are guaranteed to be of the same type
  // in this pattern.
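  //
  // For example (illustrative IR):
  //
  //   %1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 1.0>
  //   %2 = quant.dcast %1 : !quant.uniform<i8:f32, 1.0> to f32
  //
  // Uses of %2 fold to %0.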
  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
  if (!srcQcastOp)
    return {};
  assert(srcQcastOp.getInput().getType() == getType());
  return srcQcastOp.getInput();
}

FloatType DequantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
}

QuantizedType DequantizeCastOp::getQuantizedType() {
  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
}

//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//

LogicalResult QuantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getInput().getType());
}

OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
  // with the value of x if the casts invert each other. Contrary to the folding
  // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
  // x and y are not guaranteed to be of the same type here, as they may use
  // different quantization parameters.
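  //
  // For example (illustrative IR):
  //
  //   %1 = quant.dcast %0 : !quant.uniform<i8:f32, 1.0> to f32
  //   %2 = quant.qcast %1 : f32 to !quant.uniform<i8:f32, 1.0>
  //
  // Uses of %2 fold to %0 because the quantized types match; with different
  // quantization parameters no folding takes place.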
  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
    return {};
  return srcDcastOp.getInput();
}

FloatType QuantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
}

QuantizedType QuantizeCastOp::getQuantizedType() {
  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
}

//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//

LogicalResult StorageCastOp::verify() {
  auto quantizedType = getQuantizedType();
  auto integerType = getIntegerType();
  if (quantizedType.getStorageType() != integerType)
    return emitError(
        "storage type in quantized type expected to match integer type");

  // Verify integrity of per-axis quantization information, if available. While
  // the quantization type may appear in the input or the result, their tensor
  // shapes are guaranteed to be identical at this point.
  if (auto quantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
    return verifyPerAxisQuantization(*this, quantizedPerAxisType,
                                     getInput().getType());
  }

  if (auto quantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
    return verifySubChannelQuantization(*this, quantizedSubChannelType,
                                        getInput().getType());
  }

  // At this point the type is UniformQuantizedType
  return success();
}

OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
  // quant.scast with the value of x if the casts invert each other.
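  //
  // For example (illustrative IR):
  //
  //   %1 = quant.scast %0 : i8 to !quant.uniform<i8:f32, 1.0>
  //   %2 = quant.scast %1 : !quant.uniform<i8:f32, 1.0> to i8
  //
  // Uses of %2 fold to %0, since the result type of %2 matches the type of %0.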
  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
    return {};
  return srcScastOp.getInput();
}

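// Returns the integer storage type involved in the cast, taken from whichever
// side of the op (input or result) is not the quantized type.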
IntegerType StorageCastOp::getIntegerType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
    return integerType;

  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<IntegerType>(resultScalarType);
}

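// Returns the quantized type involved in the cast, taken from whichever side
// of the op (input or result) carries it.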
QuantizedType StorageCastOp::getQuantizedType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
    return quantizedType;

  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<QuantizedType>(resultScalarType);
}

} // namespace quant
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"