1 | //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Strips quantized types from function headers. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
17 | #include "mlir/Dialect/Quant/IR/Quant.h" |
18 | #include "mlir/Dialect/Quant/IR/QuantTypes.h" |
19 | #include "mlir/Dialect/Quant/Transforms/Passes.h" |
20 | #include "mlir/Dialect/Shape/IR/Shape.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/IR/Matchers.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | |
26 | namespace mlir { |
27 | namespace quant { |
28 | |
29 | #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES |
30 | #include "mlir/Dialect/Quant/Transforms/Passes.h.inc" |
31 | |
32 | namespace { |
33 | |
34 | class QuantizedTypeConverter : public TypeConverter { |
35 | |
36 | static Type convertQuantizedType(QuantizedType quantizedType) { |
37 | return quantizedType.getStorageType(); |
38 | } |
39 | |
40 | static Type convertTensorType(TensorType tensorType) { |
41 | if (auto quantizedType = |
42 | dyn_cast<QuantizedType>(tensorType.getElementType())) |
43 | return tensorType.clone(convertQuantizedType(quantizedType: quantizedType)); |
44 | return tensorType; |
45 | } |
46 | |
47 | static Value materializeConversion(OpBuilder &builder, Type type, |
48 | ValueRange inputs, Location loc) { |
49 | return builder.create<quant::StorageCastOp>(loc, type, |
50 | llvm::getSingleElement(inputs)); |
51 | } |
52 | |
53 | public: |
54 | explicit QuantizedTypeConverter() { |
55 | addConversion([](Type type) { return type; }); |
56 | addConversion(convertQuantizedType); |
57 | addConversion(convertTensorType); |
58 | |
59 | addSourceMaterialization(materializeConversion); |
60 | addTargetMaterialization(materializeConversion); |
61 | } |
62 | }; |
63 | |
64 | // Conversion pass |
65 | class StripFuncQuantTypes |
66 | : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> { |
67 | |
68 | public: |
69 | void runOnOperation() override { |
70 | |
71 | auto moduleOp = cast<ModuleOp>(getOperation()); |
72 | auto *context = &getContext(); |
73 | |
74 | QuantizedTypeConverter typeConverter; |
75 | ConversionTarget target(*context); |
76 | RewritePatternSet patterns(context); |
77 | |
78 | // Mark func.func, func.return, and func.call illegal if they contain any |
79 | // quantized types. |
80 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
81 | return typeConverter.isSignatureLegal(op.getFunctionType()) && |
82 | typeConverter.isLegal(&op.getBody()); |
83 | }); |
84 | target.addDynamicallyLegalOp<func::ReturnOp>( |
85 | [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); |
86 | target.addDynamicallyLegalOp<func::CallOp>( |
87 | [&](func::CallOp op) { return typeConverter.isLegal(op); }); |
88 | |
89 | // Register conversion patterns |
90 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( |
91 | patterns, typeConverter); |
92 | populateReturnOpTypeConversionPattern(patterns, converter: typeConverter); |
93 | populateCallOpTypeConversionPattern(patterns, converter: typeConverter); |
94 | |
95 | // Apply conversion |
96 | if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) |
97 | signalPassFailure(); |
98 | } |
99 | }; |
100 | |
101 | } // namespace |
102 | |
103 | } // namespace quant |
104 | } // namespace mlir |
105 | |