1 | //===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// |
2 | // |
3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// The patterns in this file are heavily inspired by (and copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but are adapted
// to work for 1:N type conversions.
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" |
17 | |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Transforms/OneToNTypeConversion.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace mlir::func; |
23 | |
24 | namespace { |
25 | |
26 | class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> { |
27 | public: |
28 | using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern; |
29 | |
30 | LogicalResult |
31 | matchAndRewrite(CallOp op, OpAdaptor adaptor, |
32 | OneToNPatternRewriter &rewriter) const override { |
33 | Location loc = op->getLoc(); |
34 | const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); |
35 | |
36 | // Nothing to do if the op doesn't have any non-identity conversions for its |
37 | // operands or results. |
38 | if (!adaptor.getOperandMapping().hasNonIdentityConversion() && |
39 | !resultMapping.hasNonIdentityConversion()) |
40 | return failure(); |
41 | |
42 | // Create new CallOp. |
43 | auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(), |
44 | adaptor.getFlatOperands()); |
45 | newOp->setAttrs(op->getAttrs()); |
46 | |
47 | rewriter.replaceOp(op, newOp->getResults(), resultMapping); |
48 | return success(); |
49 | } |
50 | }; |
51 | |
52 | class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> { |
53 | public: |
54 | using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern; |
55 | |
56 | LogicalResult |
57 | matchAndRewrite(FuncOp op, OpAdaptor adaptor, |
58 | OneToNPatternRewriter &rewriter) const override { |
59 | auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); |
60 | |
61 | // Construct mapping for function arguments. |
62 | OneToNTypeMapping argumentMapping(op.getArgumentTypes()); |
63 | if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), |
64 | argumentMapping))) |
65 | return failure(); |
66 | |
67 | // Construct mapping for function results. |
68 | OneToNTypeMapping funcResultMapping(op.getResultTypes()); |
69 | if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), |
70 | funcResultMapping))) |
71 | return failure(); |
72 | |
73 | // Nothing to do if the op doesn't have any non-identity conversions for its |
74 | // operands or results. |
75 | if (!argumentMapping.hasNonIdentityConversion() && |
76 | !funcResultMapping.hasNonIdentityConversion()) |
77 | return failure(); |
78 | |
79 | // Update the function signature in-place. |
80 | auto newType = FunctionType::get(rewriter.getContext(), |
81 | argumentMapping.getConvertedTypes(), |
82 | funcResultMapping.getConvertedTypes()); |
83 | rewriter.modifyOpInPlace(op, [&] { op.setType(newType); }); |
84 | |
85 | // Update block signatures. |
86 | if (!op.isExternal()) { |
87 | Region *region = &op.getBody(); |
88 | Block *block = ®ion->front(); |
89 | rewriter.applySignatureConversion(block, argumentConversion&: argumentMapping); |
90 | } |
91 | |
92 | return success(); |
93 | } |
94 | }; |
95 | |
96 | class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> { |
97 | public: |
98 | using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern; |
99 | |
100 | LogicalResult |
101 | matchAndRewrite(ReturnOp op, OpAdaptor adaptor, |
102 | OneToNPatternRewriter &rewriter) const override { |
103 | // Nothing to do if there is no non-identity conversion. |
104 | if (!adaptor.getOperandMapping().hasNonIdentityConversion()) |
105 | return failure(); |
106 | |
107 | // Convert operands. |
108 | rewriter.modifyOpInPlace( |
109 | op, [&] { op->setOperands(adaptor.getFlatOperands()); }); |
110 | |
111 | return success(); |
112 | } |
113 | }; |
114 | |
115 | } // namespace |
116 | |
117 | namespace mlir { |
118 | |
119 | void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, |
120 | RewritePatternSet &patterns) { |
121 | patterns.add< |
122 | // clang-format off |
123 | ConvertTypesInFuncCallOp, |
124 | ConvertTypesInFuncFuncOp, |
125 | ConvertTypesInFuncReturnOp |
126 | // clang-format on |
127 | >(arg&: typeConverter, args: patterns.getContext()); |
128 | } |
129 | |
130 | } // namespace mlir |
131 | |