1//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Strips quantized types from function headers.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/Quant/IR/Quant.h"
18#include "mlir/Dialect/Quant/IR/QuantTypes.h"
19#include "mlir/Dialect/Quant/Transforms/Passes.h"
20#include "mlir/Dialect/Shape/IR/Shape.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Transforms/DialectConversion.h"
25
26namespace mlir {
27namespace quant {
28
29#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
31
32namespace {
33
34class QuantizedTypeConverter : public TypeConverter {
35
36 static Type convertQuantizedType(QuantizedType quantizedType) {
37 return quantizedType.getStorageType();
38 }
39
40 static Type convertTensorType(TensorType tensorType) {
41 if (auto quantizedType =
42 dyn_cast<QuantizedType>(tensorType.getElementType()))
43 return tensorType.clone(convertQuantizedType(quantizedType: quantizedType));
44 return tensorType;
45 }
46
47 static Value materializeConversion(OpBuilder &builder, Type type,
48 ValueRange inputs, Location loc) {
49 return builder.create<quant::StorageCastOp>(loc, type,
50 llvm::getSingleElement(inputs));
51 }
52
53public:
54 explicit QuantizedTypeConverter() {
55 addConversion([](Type type) { return type; });
56 addConversion(convertQuantizedType);
57 addConversion(convertTensorType);
58
59 addSourceMaterialization(materializeConversion);
60 addTargetMaterialization(materializeConversion);
61 }
62};
63
64// Conversion pass
65class StripFuncQuantTypes
66 : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
67
68public:
69 void runOnOperation() override {
70
71 auto moduleOp = cast<ModuleOp>(getOperation());
72 auto *context = &getContext();
73
74 QuantizedTypeConverter typeConverter;
75 ConversionTarget target(*context);
76 RewritePatternSet patterns(context);
77
78 // Mark func.func, func.return, and func.call illegal if they contain any
79 // quantized types.
80 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
81 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
82 typeConverter.isLegal(&op.getBody());
83 });
84 target.addDynamicallyLegalOp<func::ReturnOp>(
85 [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
86 target.addDynamicallyLegalOp<func::CallOp>(
87 [&](func::CallOp op) { return typeConverter.isLegal(op); });
88
89 // Register conversion patterns
90 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
91 patterns, typeConverter);
92 populateReturnOpTypeConversionPattern(patterns, converter: typeConverter);
93 populateCallOpTypeConversionPattern(patterns, converter: typeConverter);
94
95 // Apply conversion
96 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
97 signalPassFailure();
98 }
99};
100
101} // namespace
102
103} // namespace quant
104} // namespace mlir
105

// Source: mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp