//===-- OneToNFuncConversions.cpp - Func 1:N type conversion ----*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The patterns in this file are heavily inspired by (and copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but work for
// 1:N type conversions.
//
//===----------------------------------------------------------------------===//
15
16#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
17
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Transforms/OneToNTypeConversion.h"
20
21using namespace mlir;
22using namespace mlir::func;
23
24namespace {
25
26class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
27public:
28 using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
29
30 LogicalResult
31 matchAndRewrite(CallOp op, OpAdaptor adaptor,
32 OneToNPatternRewriter &rewriter) const override {
33 Location loc = op->getLoc();
34 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
35
36 // Nothing to do if the op doesn't have any non-identity conversions for its
37 // operands or results.
38 if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
39 !resultMapping.hasNonIdentityConversion())
40 return failure();
41
42 // Create new CallOp.
43 auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
44 adaptor.getFlatOperands());
45 newOp->setAttrs(op->getAttrs());
46
47 rewriter.replaceOp(op, newOp->getResults(), resultMapping);
48 return success();
49 }
50};
51
52class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
53public:
54 using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
55
56 LogicalResult
57 matchAndRewrite(FuncOp op, OpAdaptor adaptor,
58 OneToNPatternRewriter &rewriter) const override {
59 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
60
61 // Construct mapping for function arguments.
62 OneToNTypeMapping argumentMapping(op.getArgumentTypes());
63 if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
64 argumentMapping)))
65 return failure();
66
67 // Construct mapping for function results.
68 OneToNTypeMapping funcResultMapping(op.getResultTypes());
69 if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
70 funcResultMapping)))
71 return failure();
72
73 // Nothing to do if the op doesn't have any non-identity conversions for its
74 // operands or results.
75 if (!argumentMapping.hasNonIdentityConversion() &&
76 !funcResultMapping.hasNonIdentityConversion())
77 return failure();
78
79 // Update the function signature in-place.
80 auto newType = FunctionType::get(rewriter.getContext(),
81 argumentMapping.getConvertedTypes(),
82 funcResultMapping.getConvertedTypes());
83 rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
84
85 // Update block signatures.
86 if (!op.isExternal()) {
87 Region *region = &op.getBody();
88 Block *block = &region->front();
89 rewriter.applySignatureConversion(block, argumentConversion&: argumentMapping);
90 }
91
92 return success();
93 }
94};
95
96class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
97public:
98 using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
99
100 LogicalResult
101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
102 OneToNPatternRewriter &rewriter) const override {
103 // Nothing to do if there is no non-identity conversion.
104 if (!adaptor.getOperandMapping().hasNonIdentityConversion())
105 return failure();
106
107 // Convert operands.
108 rewriter.modifyOpInPlace(
109 op, [&] { op->setOperands(adaptor.getFlatOperands()); });
110
111 return success();
112 }
113};
114
115} // namespace
116
117namespace mlir {
118
119void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
120 RewritePatternSet &patterns) {
121 patterns.add<
122 // clang-format off
123 ConvertTypesInFuncCallOp,
124 ConvertTypesInFuncFuncOp,
125 ConvertTypesInFuncReturnOp
126 // clang-format on
127 >(arg&: typeConverter, args: patterns.getContext());
128}
129
130} // namespace mlir
131

// source code of mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp