//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}

// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an
// empty list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
                                                 Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}

// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create tensor splat
  auto tensorConstant =
      builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
  return tensorConstant;
}

// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
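//
// For a 'tensor<*xf32>' input, for instance, the emitted IR looks roughly
// like this (a sketch; SSA names are illustrative):
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>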
//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
  Value inputSize = builder.create<shape::NumElementsOp>(
      loc, builder.getIndexType(), inputShape);

  // Turn input size into 1D tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape =
      builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);

  // Reshape input tensor into 1D
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
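//
// For example, flattening a 'tensor<*xf32>' around axis 1 with axis size 3
// produces a 'tensor<?x3x?xf32>' whose leading (trailing) dynamic dimension
// collects all input dimensions to the left (right) of the axis.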
//
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Get shape and sizes on left and right of axis
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
  auto shapeRight =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor
  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to 3D tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);

  return std::make_pair(flatInput, inputShape);
}

// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
                                           inputShape);
}

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}

// Create a tensor constant containing all scales in a sub-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
//
Value materializeSubChannelScales(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(
      scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
        return builder.getFloatAttr(expressedType, scale);
      });
  auto tensorType =
      RankedTensorType::get(scales.getType().getShape(), expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a sub-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
//
Value materializeSubChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs = llvm::map_to_vector(
      zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}

// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-axis or per-channel quantized type.
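//
// For a signed storage type 'i8<-8:7>', for instance, this emits roughly
// (SSA names are illustrative):
//
//   %min = arith.constant -8 : i8
//   %max = arith.constant 7 : i8
//   %0 = arith.maxsi %input, %min : i8
//   %1 = arith.minsi %0, %max : i8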
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If quantized type does not narrow down the storage type range, there is
  // nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMin(), storageType);
  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMax(), storageType);
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
  } else {
    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
  }
  return input;
}

// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::FPToSIOp>(loc, resultType, input);
  return builder.create<arith::FPToUIOp>(loc, resultType, input);
}

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::SIToFPOp>(loc, resultType, input);
  return builder.create<arith::UIToFPOp>(loc, resultType, input);
}

// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
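//
// Elementwise, this computes
//
//   stored = clamp(convert(input / scale + zeroPoint))
//
// where 'convert' is a float-to-integer conversion, the zero-point addition is
// skipped when the zero point is statically zero, and the clamp is skipped
// when the quantized type imposes no storage bounds.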
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);

  // Skip unnecessary computations if no zero point is given
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp stored value if the storage type range is restricted
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}

// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
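//
// Elementwise, this computes
//
//   result = (convert(input) - zeroPoint) * scale
//
// where 'convert' is an integer-to-float conversion, the zero point is also
// converted to the expressed type, and the subtraction is skipped when the
// zero point is statically zero.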
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
  }

  // Multiply by scale
  result = builder.create<arith::MulFOp>(loc, result, scale);
  return result;
}

// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create scale and zero point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}

// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Flatten input if unranked
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Process ranked tensor
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore original shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
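//
// The conversion is emitted as a 'linalg.generic' op that broadcasts the 1D
// scale and zero-point tensors along 'channelAxis'. For a rank-2 input
// quantized along axis 1, the indexing maps look roughly like:
//
//   #input = affine_map<(d0, d1) -> (d0, d1)>
//   #param = affine_map<(d0, d1) -> (d1)>
//
// used as [#input, #param, #param, #input] for the input, scales, zero
// points, and init operands, respectively.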
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
  auto result = builder
                    .create<linalg::GenericOp>(
                        loc,
                        init.getType(),                        // resultType
                        ValueRange{input, scales, zeroPoints}, // inputs
                        ValueRange{init},                      // outputs
                        indexingMaps, iteratorTypes,
                        [&](OpBuilder &builder, Location loc, ValueRange args) {
                          assert(args.size() == 4);
                          auto input = args[0];
                          auto scale = args[1];
                          auto zeroPoint = args[2];

                          auto result =
                              convertRanked(builder, loc, op, input, {}, scale,
                                            zeroPoint, quantizedType);

                          builder.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);

  return result;
}

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Flatten unranked tensor into a 3D ranked tensor if necessary
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }

  // Work on a ranked tensor
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);

  // Restore original tensor shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using sub-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Sub-channel quantized type.
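//
// Scales and zero points are indexed through an affine map that divides each
// blocked dimension by its block size. For block sizes {0:1, 1:2} on a rank-2
// input, for instance, the map is roughly:
//
//   affine_map<(d0, d1) -> (d0 floordiv 1, d1 floordiv 2)>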
//
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      quantizedType.getBlockSizeInfo();
  SmallVector<AffineExpr> affineExprs(inputRank,
                                      builder.getAffineConstantExpr(0));
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    affineExprs[quantizedDimension] =
        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
  }
  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
      builder.getMultiDimIdentityMap(inputRank)};
  auto result = builder
                    .create<linalg::GenericOp>(
                        loc,
                        init.getType(),                        // resultType
                        ValueRange{input, scales, zeroPoints}, // inputs
                        ValueRange{init},                      // outputs
                        indexingMaps, iteratorTypes,
                        [&](OpBuilder &builder, Location loc, ValueRange args) {
                          assert(args.size() == 4);
                          auto input = args[0];
                          auto scale = args[1];
                          auto zeroPoint = args[2];

                          auto result =
                              convertRanked(builder, loc, op, input, {}, scale,
                                            zeroPoint, quantizedType);

                          builder.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);

  return result;
}

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer, per-channel, or sub-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  if (auto uniformQuantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
    return convertSubChannel(builder, loc, op, input,
                             uniformQuantizedSubChannelType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
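//
// For example, the op
//
//   %1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0:10> to f32
//
// is rewritten to a 'quant.scast' exposing the stored i8 value, followed by
// the 'arith' ops emitted by 'convertQuantized()' (a sketch; the exact ops
// depend on the quantized type).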
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowering pattern for 'quant.qcast'
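//
// For example, the op
//
//   %1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0:10>
//
// is rewritten to the 'arith' ops emitted by 'convertQuantized()', followed
// by a 'quant.scast' from the stored i8 value to the quantized result type
// (a sketch; the exact ops depend on the quantized type).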
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    // Convert expressed input to its stored representation
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}

} // namespace quant
} // namespace mlir