1//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Func dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Func/Utils/Utils.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/PatternMatch.h"
17#include "llvm/ADT/SmallVector.h"
18
19using namespace mlir;
20
21FailureOr<func::FuncOp>
22func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
23 ArrayRef<unsigned> newArgsOrder,
24 ArrayRef<unsigned> newResultsOrder) {
25 // Generate an empty new function operation with the same name as the
26 // original.
27 assert(funcOp.getNumArguments() == newArgsOrder.size() &&
28 "newArgsOrder must match the number of arguments in the function");
29 assert(funcOp.getNumResults() == newResultsOrder.size() &&
30 "newResultsOrder must match the number of results in the function");
31
32 if (!funcOp.getBody().hasOneBlock())
33 return rewriter.notifyMatchFailure(
34 arg&: funcOp, msg: "expected function to have exactly one block");
35
36 ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
37 ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
38 SmallVector<Type> newInputTypes, newOutputTypes;
39 SmallVector<Location> locs;
40 for (unsigned int idx : newArgsOrder) {
41 newInputTypes.push_back(Elt: origInputTypes[idx]);
42 locs.push_back(Elt: funcOp.getArgument(idx: newArgsOrder[idx]).getLoc());
43 }
44 for (unsigned int idx : newResultsOrder)
45 newOutputTypes.push_back(Elt: origOutputTypes[idx]);
46 rewriter.setInsertionPoint(funcOp);
47 auto newFuncOp = rewriter.create<func::FuncOp>(
48 location: funcOp.getLoc(), args: funcOp.getName(),
49 args: rewriter.getFunctionType(inputs: newInputTypes, results: newOutputTypes));
50
51 Region &newRegion = newFuncOp.getBody();
52 rewriter.createBlock(parent: &newRegion, insertPt: newRegion.begin(), argTypes: newInputTypes, locs);
53 newFuncOp.setVisibility(funcOp.getVisibility());
54 newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
55
56 // Map the arguments of the original function to the new function in
57 // the new order and adjust the attributes accordingly.
58 IRMapping operandMapper;
59 SmallVector<DictionaryAttr> argAttrs, resultAttrs;
60 funcOp.getAllArgAttrs(result&: argAttrs);
61 for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
62 operandMapper.map(from: funcOp.getArgument(idx: newArgsOrder[i]),
63 to: newFuncOp.getArgument(idx: i));
64 newFuncOp.setArgAttrs(index: i, attributes: argAttrs[newArgsOrder[i]]);
65 }
66 funcOp.getAllResultAttrs(result&: resultAttrs);
67 for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
68 newFuncOp.setResultAttrs(index: i, attributes: resultAttrs[newResultsOrder[i]]);
69
70 // Clone the operations from the original function to the new function.
71 rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
72 for (Operation &op : funcOp.getOps())
73 rewriter.clone(op, mapper&: operandMapper);
74
75 // Handle the return operation.
76 auto returnOp = cast<func::ReturnOp>(
77 Val: newFuncOp.getFunctionBody().begin()->getTerminator());
78 SmallVector<Value> newReturnValues;
79 for (unsigned int idx : newResultsOrder)
80 newReturnValues.push_back(Elt: returnOp.getOperand(i: idx));
81 rewriter.setInsertionPoint(returnOp);
82 auto newReturnOp =
83 rewriter.create<func::ReturnOp>(location: newFuncOp.getLoc(), args&: newReturnValues);
84 newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
85 rewriter.eraseOp(op: returnOp);
86
87 rewriter.eraseOp(op: funcOp);
88
89 return newFuncOp;
90}
91
92func::CallOp
93func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
94 ArrayRef<unsigned> newArgsOrder,
95 ArrayRef<unsigned> newResultsOrder) {
96 assert(
97 callOp.getNumOperands() == newArgsOrder.size() &&
98 "newArgsOrder must match the number of operands in the call operation");
99 assert(
100 callOp.getNumResults() == newResultsOrder.size() &&
101 "newResultsOrder must match the number of results in the call operation");
102 SmallVector<Value> newArgsOrderValues;
103 for (unsigned int argIdx : newArgsOrder)
104 newArgsOrderValues.push_back(Elt: callOp.getOperand(i: argIdx));
105 SmallVector<Type> newResultTypes;
106 for (unsigned int resIdx : newResultsOrder)
107 newResultTypes.push_back(Elt: callOp.getResult(i: resIdx).getType());
108
109 // Replace the kernel call operation with a new one that has the
110 // reordered arguments.
111 rewriter.setInsertionPoint(callOp);
112 auto newCallOp = rewriter.create<func::CallOp>(
113 location: callOp.getLoc(), args: callOp.getCallee(), args&: newResultTypes, args&: newArgsOrderValues);
114 newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
115 for (auto &&[newIndex, origIndex] : llvm::enumerate(First&: newResultsOrder))
116 rewriter.replaceAllUsesWith(from: callOp.getResult(i: origIndex),
117 to: newCallOp.getResult(i: newIndex));
118 rewriter.eraseOp(op: callOp);
119
120 return newCallOp;
121}
122

source code of mlir/lib/Dialect/Func/Utils/Utils.cpp