| 1 | //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" |
| 10 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
| 11 | |
| 12 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 14 | #include "mlir/IR/Operation.h" |
| 15 | #include "mlir/Pass/Pass.h" |
| 16 | |
| 17 | namespace mlir { |
| 18 | namespace bufferization { |
| 19 | #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS |
| 20 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
| 21 | } // namespace bufferization |
| 22 | } // namespace mlir |
| 23 | |
| 24 | using namespace mlir; |
| 25 | using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; |
| 26 | using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; |
| 27 | |
| 28 | /// Return `true` if the given MemRef type has a fully dynamic layout. |
| 29 | static bool hasFullyDynamicLayoutMap(MemRefType type) { |
| 30 | int64_t offset; |
| 31 | SmallVector<int64_t, 4> strides; |
| 32 | if (failed(type.getStridesAndOffset(strides, offset))) |
| 33 | return false; |
| 34 | if (!llvm::all_of(strides, ShapedType::isDynamic)) |
| 35 | return false; |
| 36 | if (!ShapedType::isDynamic(offset)) |
| 37 | return false; |
| 38 | return true; |
| 39 | } |
| 40 | |
| 41 | /// Return `true` if the given MemRef type has a static identity layout (i.e., |
| 42 | /// no layout). |
| 43 | static bool hasStaticIdentityLayout(MemRefType type) { |
| 44 | return type.getLayout().isIdentity(); |
| 45 | } |
| 46 | |
| 47 | // Updates the func op and entry block. |
| 48 | // |
| 49 | // Any args appended to the entry block are added to `appendedEntryArgs`. |
| 50 | // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` |
| 51 | // to each newly created function argument. |
| 52 | static LogicalResult |
| 53 | updateFuncOp(func::FuncOp func, |
| 54 | SmallVectorImpl<BlockArgument> &appendedEntryArgs, |
| 55 | bool addResultAttribute) { |
| 56 | auto functionType = func.getFunctionType(); |
| 57 | |
| 58 | // Collect information about the results will become appended arguments. |
| 59 | SmallVector<Type, 6> erasedResultTypes; |
| 60 | BitVector erasedResultIndices(functionType.getNumResults()); |
| 61 | for (const auto &resultType : llvm::enumerate(functionType.getResults())) { |
| 62 | if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { |
| 63 | if (!hasStaticIdentityLayout(memrefType) && |
| 64 | !hasFullyDynamicLayoutMap(memrefType)) { |
| 65 | // Only buffers with static identity layout can be allocated. These can |
| 66 | // be casted to memrefs with fully dynamic layout map. Other layout maps |
| 67 | // are not supported. |
| 68 | return func->emitError() |
| 69 | << "cannot create out param for result with unsupported layout" ; |
| 70 | } |
| 71 | erasedResultIndices.set(resultType.index()); |
| 72 | erasedResultTypes.push_back(memrefType); |
| 73 | } |
| 74 | } |
| 75 | |
| 76 | // Add the new arguments to the function type. |
| 77 | auto newArgTypes = llvm::to_vector<6>( |
| 78 | llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); |
| 79 | auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, |
| 80 | functionType.getResults()); |
| 81 | func.setType(newFunctionType); |
| 82 | |
| 83 | // Transfer the result attributes to arg attributes. |
| 84 | auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); |
| 85 | for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { |
| 86 | func.setArgAttrs(functionType.getNumInputs() + i, |
| 87 | func.getResultAttrs(*erasedIndicesIt)); |
| 88 | if (addResultAttribute) |
| 89 | func.setArgAttr(functionType.getNumInputs() + i, |
| 90 | StringAttr::get(func.getContext(), "bufferize.result" ), |
| 91 | UnitAttr::get(func.getContext())); |
| 92 | } |
| 93 | |
| 94 | // Erase the results. |
| 95 | if (failed(func.eraseResults(erasedResultIndices))) |
| 96 | return failure(); |
| 97 | |
| 98 | // Add the new arguments to the entry block if the function is not external. |
| 99 | if (func.isExternal()) |
| 100 | return success(); |
| 101 | Location loc = func.getLoc(); |
| 102 | for (Type type : erasedResultTypes) |
| 103 | appendedEntryArgs.push_back(Elt: func.front().addArgument(type, loc)); |
| 104 | |
| 105 | return success(); |
| 106 | } |
| 107 | |
| 108 | // Updates all ReturnOps in the scope of the given func::FuncOp by either |
| 109 | // keeping them as return values or copying the associated buffer contents into |
| 110 | // the given out-params. |
| 111 | static LogicalResult |
| 112 | updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, |
| 113 | const bufferization::BufferResultsToOutParamsOpts &options) { |
| 114 | auto res = func.walk([&](func::ReturnOp op) { |
| 115 | SmallVector<Value, 6> copyIntoOutParams; |
| 116 | SmallVector<Value, 6> keepAsReturnOperands; |
| 117 | for (Value operand : op.getOperands()) { |
| 118 | if (isa<MemRefType>(operand.getType())) |
| 119 | copyIntoOutParams.push_back(operand); |
| 120 | else |
| 121 | keepAsReturnOperands.push_back(operand); |
| 122 | } |
| 123 | OpBuilder builder(op); |
| 124 | for (auto [orig, arg] : llvm::zip(t&: copyIntoOutParams, u&: appendedEntryArgs)) { |
| 125 | if (options.hoistStaticAllocs && |
| 126 | isa_and_nonnull<bufferization::AllocationOpInterface>( |
| 127 | orig.getDefiningOp()) && |
| 128 | mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { |
| 129 | orig.replaceAllUsesWith(newValue: arg); |
| 130 | orig.getDefiningOp()->erase(); |
| 131 | } else { |
| 132 | if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) |
| 133 | return WalkResult::interrupt(); |
| 134 | } |
| 135 | } |
| 136 | builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); |
| 137 | op.erase(); |
| 138 | return WalkResult::advance(); |
| 139 | }); |
| 140 | return failure(res.wasInterrupted()); |
| 141 | } |
| 142 | |
| 143 | // Updates all CallOps in the scope of the given ModuleOp by allocating |
| 144 | // temporary buffers for newly introduced out params. |
| 145 | static LogicalResult |
| 146 | updateCalls(ModuleOp module, |
| 147 | const bufferization::BufferResultsToOutParamsOpts &options) { |
| 148 | bool didFail = false; |
| 149 | SymbolTable symtab(module); |
| 150 | module.walk([&](func::CallOp op) { |
| 151 | auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); |
| 152 | if (!callee) { |
| 153 | op.emitError() << "cannot find callee '" << op.getCallee() << "' in " |
| 154 | << "symbol table" ; |
| 155 | didFail = true; |
| 156 | return; |
| 157 | } |
| 158 | if (!options.filterFn(&callee)) |
| 159 | return; |
| 160 | SmallVector<Value, 6> replaceWithNewCallResults; |
| 161 | SmallVector<Value, 6> replaceWithOutParams; |
| 162 | for (OpResult result : op.getResults()) { |
| 163 | if (isa<MemRefType>(result.getType())) |
| 164 | replaceWithOutParams.push_back(result); |
| 165 | else |
| 166 | replaceWithNewCallResults.push_back(result); |
| 167 | } |
| 168 | SmallVector<Value, 6> outParams; |
| 169 | OpBuilder builder(op); |
| 170 | for (Value memref : replaceWithOutParams) { |
| 171 | if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { |
| 172 | op.emitError() |
| 173 | << "cannot create out param for dynamically shaped result" ; |
| 174 | didFail = true; |
| 175 | return; |
| 176 | } |
| 177 | auto memrefType = cast<MemRefType>(memref.getType()); |
| 178 | auto allocType = |
| 179 | MemRefType::get(memrefType.getShape(), memrefType.getElementType(), |
| 180 | AffineMap(), memrefType.getMemorySpace()); |
| 181 | auto maybeOutParam = |
| 182 | options.allocationFn(builder, op.getLoc(), allocType); |
| 183 | if (failed(maybeOutParam)) { |
| 184 | op.emitError() << "failed to create allocation op" ; |
| 185 | didFail = true; |
| 186 | return; |
| 187 | } |
| 188 | Value outParam = maybeOutParam.value(); |
| 189 | if (!hasStaticIdentityLayout(memrefType)) { |
| 190 | // Layout maps are already checked in `updateFuncOp`. |
| 191 | assert(hasFullyDynamicLayoutMap(memrefType) && |
| 192 | "layout map not supported" ); |
| 193 | outParam = |
| 194 | builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); |
| 195 | } |
| 196 | memref.replaceAllUsesWith(newValue: outParam); |
| 197 | outParams.push_back(Elt: outParam); |
| 198 | } |
| 199 | |
| 200 | auto newOperands = llvm::to_vector<6>(op.getOperands()); |
| 201 | newOperands.append(outParams.begin(), outParams.end()); |
| 202 | auto newResultTypes = llvm::to_vector<6>(Range: llvm::map_range( |
| 203 | C&: replaceWithNewCallResults, F: [](Value v) { return v.getType(); })); |
| 204 | auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), |
| 205 | newResultTypes, newOperands); |
| 206 | for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) |
| 207 | std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); |
| 208 | op.erase(); |
| 209 | }); |
| 210 | |
| 211 | return failure(IsFailure: didFail); |
| 212 | } |
| 213 | |
| 214 | LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( |
| 215 | ModuleOp module, |
| 216 | const bufferization::BufferResultsToOutParamsOpts &options) { |
| 217 | for (auto func : module.getOps<func::FuncOp>()) { |
| 218 | if (!options.filterFn(&func)) |
| 219 | continue; |
| 220 | SmallVector<BlockArgument, 6> appendedEntryArgs; |
| 221 | if (failed( |
| 222 | updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) |
| 223 | return failure(); |
| 224 | if (func.isExternal()) |
| 225 | continue; |
| 226 | if (failed(updateReturnOps(func, appendedEntryArgs, options))) { |
| 227 | return failure(); |
| 228 | } |
| 229 | } |
| 230 | if (failed(updateCalls(module, options))) |
| 231 | return failure(); |
| 232 | return success(); |
| 233 | } |
| 234 | |
| 235 | namespace { |
| 236 | struct BufferResultsToOutParamsPass |
| 237 | : bufferization::impl::BufferResultsToOutParamsPassBase< |
| 238 | BufferResultsToOutParamsPass> { |
| 239 | using Base::Base; |
| 240 | |
| 241 | void runOnOperation() override { |
| 242 | // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. |
| 243 | if (addResultAttribute) |
| 244 | options.addResultAttribute = true; |
| 245 | if (hoistStaticAllocs) |
| 246 | options.hoistStaticAllocs = true; |
| 247 | |
| 248 | if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), |
| 249 | options))) |
| 250 | return signalPassFailure(); |
| 251 | } |
| 252 | |
| 253 | private: |
| 254 | bufferization::BufferResultsToOutParamsOpts options; |
| 255 | }; |
| 256 | } // namespace |
| 257 | |