1//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Strips quantized types from function headers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15#include "mlir/Dialect/Quant/IR/Quant.h"
16#include "mlir/Dialect/Quant/IR/QuantTypes.h"
17#include "mlir/Dialect/Quant/Transforms/Passes.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace mlir {
22namespace quant {
23
24#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
25#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
26
27namespace {
28
29class QuantizedTypeConverter : public TypeConverter {
30
31 static Type convertQuantizedType(QuantizedType quantizedType) {
32 return quantizedType.getStorageType();
33 }
34
35 static Type convertTensorType(TensorType tensorType) {
36 if (auto quantizedType =
37 dyn_cast<QuantizedType>(Val: tensorType.getElementType()))
38 return tensorType.clone(elementType: convertQuantizedType(quantizedType));
39 return tensorType;
40 }
41
42 static Value materializeConversion(OpBuilder &builder, Type type,
43 ValueRange inputs, Location loc) {
44 return builder.create<quant::StorageCastOp>(location: loc, args&: type,
45 args: llvm::getSingleElement(C&: inputs));
46 }
47
48public:
49 explicit QuantizedTypeConverter() {
50 addConversion(callback: [](Type type) { return type; });
51 addConversion(callback&: convertQuantizedType);
52 addConversion(callback&: convertTensorType);
53
54 addSourceMaterialization(callback&: materializeConversion);
55 addTargetMaterialization(callback&: materializeConversion);
56 }
57};
58
59// Conversion pass
60class StripFuncQuantTypes
61 : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
62
63public:
64 void runOnOperation() override {
65
66 auto moduleOp = cast<ModuleOp>(Val: getOperation());
67 auto *context = &getContext();
68
69 QuantizedTypeConverter typeConverter;
70 ConversionTarget target(*context);
71 RewritePatternSet patterns(context);
72
73 // Mark func.func, func.return, and func.call illegal if they contain any
74 // quantized types.
75 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
76 return typeConverter.isSignatureLegal(ty: op.getFunctionType()) &&
77 typeConverter.isLegal(region: &op.getBody());
78 });
79 target.addDynamicallyLegalOp<func::ReturnOp>(
80 callback: [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
81 target.addDynamicallyLegalOp<func::CallOp>(
82 callback: [&](func::CallOp op) { return typeConverter.isLegal(op); });
83
84 // Register conversion patterns
85 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
86 patterns, converter: typeConverter);
87 populateReturnOpTypeConversionPattern(patterns, converter: typeConverter);
88 populateCallOpTypeConversionPattern(patterns, converter: typeConverter);
89
90 // Apply conversion
91 if (failed(Result: applyPartialConversion(op: moduleOp, target, patterns: std::move(patterns))))
92 signalPassFailure();
93 }
94};
95
96} // namespace
97
98} // namespace quant
99} // namespace mlir
100

source code of mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp