| 1 | //===- Utils.cpp - Utilities to support the Func dialect ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utilities for the Func dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Func/Utils/Utils.h" |
| 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 15 | #include "mlir/IR/IRMapping.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "llvm/ADT/SmallVector.h" |
| 18 | |
| 19 | using namespace mlir; |
| 20 | |
| 21 | FailureOr<func::FuncOp> |
| 22 | func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp, |
| 23 | ArrayRef<unsigned> newArgsOrder, |
| 24 | ArrayRef<unsigned> newResultsOrder) { |
| 25 | // Generate an empty new function operation with the same name as the |
| 26 | // original. |
| 27 | assert(funcOp.getNumArguments() == newArgsOrder.size() && |
| 28 | "newArgsOrder must match the number of arguments in the function" ); |
| 29 | assert(funcOp.getNumResults() == newResultsOrder.size() && |
| 30 | "newResultsOrder must match the number of results in the function" ); |
| 31 | |
| 32 | if (!funcOp.getBody().hasOneBlock()) |
| 33 | return rewriter.notifyMatchFailure( |
| 34 | arg&: funcOp, msg: "expected function to have exactly one block" ); |
| 35 | |
| 36 | ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs(); |
| 37 | ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults(); |
| 38 | SmallVector<Type> newInputTypes, newOutputTypes; |
| 39 | SmallVector<Location> locs; |
| 40 | for (unsigned int idx : newArgsOrder) { |
| 41 | newInputTypes.push_back(Elt: origInputTypes[idx]); |
| 42 | locs.push_back(Elt: funcOp.getArgument(idx: newArgsOrder[idx]).getLoc()); |
| 43 | } |
| 44 | for (unsigned int idx : newResultsOrder) |
| 45 | newOutputTypes.push_back(Elt: origOutputTypes[idx]); |
| 46 | rewriter.setInsertionPoint(funcOp); |
| 47 | auto newFuncOp = rewriter.create<func::FuncOp>( |
| 48 | location: funcOp.getLoc(), args: funcOp.getName(), |
| 49 | args: rewriter.getFunctionType(inputs: newInputTypes, results: newOutputTypes)); |
| 50 | |
| 51 | Region &newRegion = newFuncOp.getBody(); |
| 52 | rewriter.createBlock(parent: &newRegion, insertPt: newRegion.begin(), argTypes: newInputTypes, locs); |
| 53 | newFuncOp.setVisibility(funcOp.getVisibility()); |
| 54 | newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary()); |
| 55 | |
| 56 | // Map the arguments of the original function to the new function in |
| 57 | // the new order and adjust the attributes accordingly. |
| 58 | IRMapping operandMapper; |
| 59 | SmallVector<DictionaryAttr> argAttrs, resultAttrs; |
| 60 | funcOp.getAllArgAttrs(result&: argAttrs); |
| 61 | for (unsigned int i = 0; i < newArgsOrder.size(); ++i) { |
| 62 | operandMapper.map(from: funcOp.getArgument(idx: newArgsOrder[i]), |
| 63 | to: newFuncOp.getArgument(idx: i)); |
| 64 | newFuncOp.setArgAttrs(index: i, attributes: argAttrs[newArgsOrder[i]]); |
| 65 | } |
| 66 | funcOp.getAllResultAttrs(result&: resultAttrs); |
| 67 | for (unsigned int i = 0; i < newResultsOrder.size(); ++i) |
| 68 | newFuncOp.setResultAttrs(index: i, attributes: resultAttrs[newResultsOrder[i]]); |
| 69 | |
| 70 | // Clone the operations from the original function to the new function. |
| 71 | rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); |
| 72 | for (Operation &op : funcOp.getOps()) |
| 73 | rewriter.clone(op, mapper&: operandMapper); |
| 74 | |
| 75 | // Handle the return operation. |
| 76 | auto returnOp = cast<func::ReturnOp>( |
| 77 | Val: newFuncOp.getFunctionBody().begin()->getTerminator()); |
| 78 | SmallVector<Value> newReturnValues; |
| 79 | for (unsigned int idx : newResultsOrder) |
| 80 | newReturnValues.push_back(Elt: returnOp.getOperand(i: idx)); |
| 81 | rewriter.setInsertionPoint(returnOp); |
| 82 | auto newReturnOp = |
| 83 | rewriter.create<func::ReturnOp>(location: newFuncOp.getLoc(), args&: newReturnValues); |
| 84 | newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary()); |
| 85 | rewriter.eraseOp(op: returnOp); |
| 86 | |
| 87 | rewriter.eraseOp(op: funcOp); |
| 88 | |
| 89 | return newFuncOp; |
| 90 | } |
| 91 | |
| 92 | func::CallOp |
| 93 | func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp, |
| 94 | ArrayRef<unsigned> newArgsOrder, |
| 95 | ArrayRef<unsigned> newResultsOrder) { |
| 96 | assert( |
| 97 | callOp.getNumOperands() == newArgsOrder.size() && |
| 98 | "newArgsOrder must match the number of operands in the call operation" ); |
| 99 | assert( |
| 100 | callOp.getNumResults() == newResultsOrder.size() && |
| 101 | "newResultsOrder must match the number of results in the call operation" ); |
| 102 | SmallVector<Value> newArgsOrderValues; |
| 103 | for (unsigned int argIdx : newArgsOrder) |
| 104 | newArgsOrderValues.push_back(Elt: callOp.getOperand(i: argIdx)); |
| 105 | SmallVector<Type> newResultTypes; |
| 106 | for (unsigned int resIdx : newResultsOrder) |
| 107 | newResultTypes.push_back(Elt: callOp.getResult(i: resIdx).getType()); |
| 108 | |
| 109 | // Replace the kernel call operation with a new one that has the |
| 110 | // reordered arguments. |
| 111 | rewriter.setInsertionPoint(callOp); |
| 112 | auto newCallOp = rewriter.create<func::CallOp>( |
| 113 | location: callOp.getLoc(), args: callOp.getCallee(), args&: newResultTypes, args&: newArgsOrderValues); |
| 114 | newCallOp.setNoInlineAttr(callOp.getNoInlineAttr()); |
| 115 | for (auto &&[newIndex, origIndex] : llvm::enumerate(First&: newResultsOrder)) |
| 116 | rewriter.replaceAllUsesWith(from: callOp.getResult(i: origIndex), |
| 117 | to: newCallOp.getResult(i: newIndex)); |
| 118 | rewriter.eraseOp(op: callOp); |
| 119 | |
| 120 | return newCallOp; |
| 121 | } |
| 122 | |