| 1 | //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "QuantDialectBytecode.h" |
| 10 | #include "TypeDetail.h" |
| 11 | |
| 12 | #include "mlir/Dialect/Quant/IR/Quant.h" |
| 13 | #include "mlir/Dialect/Quant/IR/QuantTypes.h" |
| 14 | #include "mlir/IR/BuiltinTypes.h" |
| 15 | #include "mlir/IR/PatternMatch.h" |
| 16 | #include "mlir/IR/TypeUtilities.h" |
| 17 | |
| 18 | #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" |
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace quant { |
| 22 | |
| 23 | namespace { |
| 24 | |
| 25 | // Verify the integrity of per-axis quantization information, if present. |
| 26 | // |
| 27 | // - uniformQuantizedPerAxisType |
| 28 | // A quantized type with per-axis quantization. |
| 29 | // |
| 30 | // - containerType |
| 31 | // Original input or result type of the operation using the provided quantized |
| 32 | // type. Used to ensure that the quantized type appears within a tensor and |
| 33 | // that the tensor is compatible with per-axis quantization information. |
| 34 | // |
| 35 | LogicalResult verifyPerAxisQuantization( |
| 36 | Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType, |
| 37 | Type containerType) { |
| 38 | auto tensorType = dyn_cast<TensorType>(containerType); |
| 39 | if (!tensorType) |
| 40 | return op->emitError(message: "scalar types may not use per-axis quantization" ); |
| 41 | |
| 42 | if (!tensorType.hasRank()) |
| 43 | return success(); |
| 44 | |
| 45 | int32_t quantizedDimension = |
| 46 | uniformQuantizedPerAxisType.getQuantizedDimension(); |
| 47 | if ((int64_t)quantizedDimension >= tensorType.getRank()) |
| 48 | return op->emitError(message: "quantized dimension must be less than tensor rank" ); |
| 49 | |
| 50 | int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension); |
| 51 | if (quantizedDimensionSize != ShapedType::kDynamic && |
| 52 | quantizedDimensionSize != |
| 53 | (int64_t)uniformQuantizedPerAxisType.getScales().size()) |
| 54 | return op->emitError( |
| 55 | message: "quantized dimension size does not match number of scales" ); |
| 56 | |
| 57 | return success(); |
| 58 | } |
| 59 | |
| 60 | // Verifies that the sub-channel quantization parameters are consistent with |
| 61 | // the given container type. The function checks the following: |
| 62 | // |
| 63 | // - The container type must be a ranked tensor type. |
| 64 | // - Each quantized dimension must be less than the rank of the tensor. |
| 65 | // - The size of each dimension at the quantized dimension must be divisible |
| 66 | // by the corresponding block size. |
| 67 | // - The scale dimension size at each axis index should match the tensor |
| 68 | // dimension at the index divided by the corresponding block size. |
| 69 | // |
| 70 | // The `uniformQuantizedSubChannelType` argument provides the sub-channel |
| 71 | // quantization parameters, and the `containerType` argument specifies the |
| 72 | // type of the container holding the quantized data. |
| 73 | // |
| 74 | LogicalResult verifySubChannelQuantization( |
| 75 | Operation *op, |
| 76 | UniformQuantizedSubChannelType uniformQuantizedSubChannelType, |
| 77 | Type containerType) { |
| 78 | auto tensorType = dyn_cast<TensorType>(containerType); |
| 79 | if (!tensorType) |
| 80 | return op->emitError(message: "scalar types may not use sub-channel quantization" ); |
| 81 | |
| 82 | if (!tensorType.hasRank()) |
| 83 | return op->emitError( |
| 84 | message: "tensor containing the sub-channel quantized type must be ranked" ); |
| 85 | |
| 86 | const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo = |
| 87 | uniformQuantizedSubChannelType.getBlockSizeInfo(); |
| 88 | auto shape = tensorType.getShape(); |
| 89 | |
| 90 | // The dimension size of scale for an axis which is not specified as quantized |
| 91 | // dimension should be 1. |
| 92 | SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1); |
| 93 | for (auto [quantizedDimension, blockSize] : blockSizeInfo) { |
| 94 | if (quantizedDimension >= tensorType.getRank()) |
| 95 | return op->emitError() |
| 96 | << "quantized dimension " << quantizedDimension |
| 97 | << " must be less than tensor rank " << tensorType.getRank(); |
| 98 | if (!tensorType.isDynamicDim(quantizedDimension) && |
| 99 | tensorType.getDimSize(quantizedDimension) % blockSize != 0) |
| 100 | return op->emitError() |
| 101 | << "tensor dimension size " |
| 102 | << tensorType.getDimSize(quantizedDimension) << " at axis " |
| 103 | << quantizedDimension |
| 104 | << " must be divisible by the corresponding block size " |
| 105 | << blockSize; |
| 106 | if (tensorType.isDynamicDim(quantizedDimension)) |
| 107 | expectedScaleShape[quantizedDimension] = ShapedType::kDynamic; |
| 108 | else |
| 109 | expectedScaleShape[quantizedDimension] = |
| 110 | tensorType.getDimSize(quantizedDimension) / blockSize; |
| 111 | } |
| 112 | |
| 113 | // Block sizes must be greater than 0 and divide the corresponding dimension |
| 114 | // size. While a block size b must be less than or equal to the corresponding |
| 115 | // dimension size d, this constraint is implicitly enforced by requiring that |
| 116 | // d % b == 0 when d != 0. |
| 117 | // |
| 118 | // However, a problem arises when d = 0. The divisibility constraint allows b |
| 119 | // to be any value, potentially violating the requirement that b <= d. |
| 120 | // Furthermore, if b is unspecified (implicitly equal to d), it violates the |
| 121 | // constraint that b > 0. |
| 122 | // |
| 123 | // Therefore, we explicitly disallow the case where d = 0 to maintain |
| 124 | // consistency and avoid these issues. |
| 125 | if (llvm::is_contained(tensorType.getShape(), 0)) { |
| 126 | return op->emitError() << "tensor dimension size of zero is not allowed " |
| 127 | "with sub-channel quantization" ; |
| 128 | } |
| 129 | |
| 130 | auto scaleShape = |
| 131 | uniformQuantizedSubChannelType.getScales().getType().getShape(); |
| 132 | if (scaleShape.size() != shape.size()) { |
| 133 | return op->emitError() << "Rank of scales " << scaleShape.size() |
| 134 | << " must match " |
| 135 | << "the rank of the tensor " << shape.size(); |
| 136 | } |
| 137 | |
| 138 | for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) { |
| 139 | if (expectedScaleShape[index] != ShapedType::kDynamic && |
| 140 | expectedScaleShape[index] != scaleShape[index]) |
| 141 | return op->emitError() << "dimension size " << scaleDim |
| 142 | << " of scales tensor at axis " << index |
| 143 | << " should match (tensor dimension at axis / " |
| 144 | "block sizes at axis) = " |
| 145 | << expectedScaleShape[index]; |
| 146 | } |
| 147 | |
| 148 | return success(); |
| 149 | } |
| 150 | |
| 151 | // Common verification logic for 'quant.dcast' and 'quant.qcast' ops. |
| 152 | // |
| 153 | // - quantizedType |
| 154 | // Quantized type used in the input ('quant.dcast') or result ('quant.qcast'), |
| 155 | // whether as a primitive type or in a tensor. |
| 156 | // |
| 157 | // - floatType |
| 158 | // Float type used in the input ('quant.qcast') or result ('quant.dcast'), |
| 159 | // whether as a primitive type or in a tensor. |
| 160 | // |
| 161 | // - containerType |
| 162 | // Type of original input or result. |
| 163 | // |
| 164 | LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, |
| 165 | FloatType floatType, Type containerType) { |
| 166 | if (quantizedType.getExpressedType() != floatType) |
| 167 | return op->emitError( |
| 168 | message: "expressed type in quantized type expected to match float type" ); |
| 169 | |
| 170 | // Verify integrity of per-axis quantization information, if present. |
| 171 | if (auto quantizedPerAxisType = |
| 172 | dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) { |
| 173 | return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType); |
| 174 | } |
| 175 | |
| 176 | if (auto quantizedSubChannelType = |
| 177 | dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) { |
| 178 | return verifySubChannelQuantization(op, quantizedSubChannelType, |
| 179 | containerType); |
| 180 | } |
| 181 | |
| 182 | // At this point the type is UniformQuantizedType |
| 183 | return success(); |
| 184 | } |
| 185 | |
| 186 | } // namespace |
| 187 | |
| 188 | //===----------------------------------------------------------------------===// |
| 189 | // Dialect |
| 190 | //===----------------------------------------------------------------------===// |
| 191 | |
// Registers all types and operations of the Quant dialect and attaches the
// bytecode reader/writer interface.
void QuantDialect::initialize() {
  // Quantized type implementations provided by this dialect.
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
  // Operations are generated by ODS; GET_OP_LIST expands to the op class list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  // Enable dialect-specific bytecode encoding (see QuantDialectBytecode.h).
  detail::addBytecodeInterface(this);
}
| 201 | |
| 202 | //===----------------------------------------------------------------------===// |
| 203 | // DequantizeCastOp |
| 204 | //===----------------------------------------------------------------------===// |
| 205 | |
| 206 | LogicalResult DequantizeCastOp::verify() { |
| 207 | return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), |
| 208 | getInput().getType()); |
| 209 | } |
| 210 | |
| 211 | OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) { |
| 212 | // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op |
| 213 | // with the value of x. Values x and y are guaranteed to be of the same type |
| 214 | // in this pattern. |
| 215 | auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>(); |
| 216 | if (!srcQcastOp) |
| 217 | return {}; |
| 218 | assert(srcQcastOp.getInput().getType() == getType()); |
| 219 | return srcQcastOp.getInput(); |
| 220 | } |
| 221 | |
| 222 | FloatType DequantizeCastOp::getFloatType() { |
| 223 | return cast<FloatType>(getElementTypeOrSelf(getResult().getType())); |
| 224 | } |
| 225 | |
| 226 | QuantizedType DequantizeCastOp::getQuantizedType() { |
| 227 | return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType())); |
| 228 | } |
| 229 | |
| 230 | //===----------------------------------------------------------------------===// |
| 231 | // QuantizeCastOp |
| 232 | //===----------------------------------------------------------------------===// |
| 233 | |
| 234 | LogicalResult QuantizeCastOp::verify() { |
| 235 | return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), |
| 236 | getInput().getType()); |
| 237 | } |
| 238 | |
| 239 | OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) { |
| 240 | // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op |
| 241 | // with the value of x if the casts invert each other. Contrary to the folding |
| 242 | // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values |
| 243 | // x and y are not guaranteed to be of the same type here, as they may use |
| 244 | // different quantization parameters. |
| 245 | auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>(); |
| 246 | if (!srcDcastOp || srcDcastOp.getInput().getType() != getType()) |
| 247 | return {}; |
| 248 | return srcDcastOp.getInput(); |
| 249 | } |
| 250 | |
| 251 | FloatType QuantizeCastOp::getFloatType() { |
| 252 | return cast<FloatType>(getElementTypeOrSelf(getInput().getType())); |
| 253 | } |
| 254 | |
| 255 | QuantizedType QuantizeCastOp::getQuantizedType() { |
| 256 | return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType())); |
| 257 | } |
| 258 | |
| 259 | //===----------------------------------------------------------------------===// |
| 260 | // StorageCastOp |
| 261 | //===----------------------------------------------------------------------===// |
| 262 | |
| 263 | LogicalResult StorageCastOp::verify() { |
| 264 | auto quantizedType = getQuantizedType(); |
| 265 | auto integerType = getIntegerType(); |
| 266 | if (quantizedType.getStorageType() != integerType) |
| 267 | return emitError( |
| 268 | "storage type in quantized type expected to match integer type" ); |
| 269 | |
| 270 | // Verify integrity of per-axis quantization information, if available. While |
| 271 | // the quantization type may appear in the input or the result, their tensor |
| 272 | // shapes are guaranteed to be identical at this point. |
| 273 | if (auto quantizedPerAxisType = |
| 274 | dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) { |
| 275 | return verifyPerAxisQuantization(*this, quantizedPerAxisType, |
| 276 | getInput().getType()); |
| 277 | } |
| 278 | |
| 279 | if (auto quantizedSunChannelType = |
| 280 | dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) { |
| 281 | return verifySubChannelQuantization(*this, quantizedSunChannelType, |
| 282 | getInput().getType()); |
| 283 | } |
| 284 | |
| 285 | // At this point the type is UniformQuantizedType |
| 286 | return success(); |
| 287 | } |
| 288 | |
| 289 | OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { |
| 290 | // Matches x -> quant.scast -> quant.scast -> y, replacing the second |
| 291 | // quant.scast with the value of x if the casts invert each other. |
| 292 | auto srcScastOp = getInput().getDefiningOp<StorageCastOp>(); |
| 293 | if (!srcScastOp || srcScastOp.getInput().getType() != getType()) |
| 294 | return {}; |
| 295 | return srcScastOp.getInput(); |
| 296 | } |
| 297 | |
| 298 | IntegerType StorageCastOp::getIntegerType() { |
| 299 | auto inputScalarType = getElementTypeOrSelf(getInput().getType()); |
| 300 | if (auto integerType = dyn_cast<IntegerType>(inputScalarType)) |
| 301 | return integerType; |
| 302 | |
| 303 | auto resultScalarType = getElementTypeOrSelf(getResult().getType()); |
| 304 | return cast<IntegerType>(resultScalarType); |
| 305 | } |
| 306 | |
| 307 | QuantizedType StorageCastOp::getQuantizedType() { |
| 308 | auto inputScalarType = getElementTypeOrSelf(getInput().getType()); |
| 309 | if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType)) |
| 310 | return quantizedType; |
| 311 | |
| 312 | auto resultScalarType = getElementTypeOrSelf(getResult().getType()); |
| 313 | return cast<QuantizedType>(resultScalarType); |
| 314 | } |
| 315 | |
| 316 | } // namespace quant |
| 317 | } // namespace mlir |
| 318 | |
| 319 | #define GET_OP_CLASSES |
| 320 | #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" |
| 321 | |