//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include <optional>

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user (as per `options.functionArgTypeConverterFn`).
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
}

/// Return the FuncOp called by `callOp`. Return nullptr if the callee cannot
/// be resolved to a FuncOp.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
  auto *result = static_cast<const OneShotAnalysisState &>(state)
                     .getExtension<FuncAnalysisState>();
  assert(result && "FuncAnalysisState does not exist");
  return *result;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  if (!isa<OneShotAnalysisState>(state))
    return FuncOpAnalysisState::NotAnalyzed;
  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
                        .getExtension<FuncAnalysisState>();
  if (!funcState)
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static std::optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for `funcOp`.
    return std::nullopt;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return std::nullopt;

  return retValIt->getSecond();
}

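/// Bufferization model for func.call. Tensor operands and results are
/// rewritten to the buffer types of the (already bufferized) callee signature.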
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      return detail::unknownGetAliasingValues(opOperand);

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());

    // Check if the aliasing OpResult is equivalent to the OpOperand.
    std::optional<int64_t> equivalent = {};
    if (aliasingReturnVals.size() == 1) {
      equivalent = getEquivalentFuncArgIdx(funcOp, funcState,
                                           aliasingReturnVals.front());
      assert((!equivalent.has_value() ||
              *equivalent == opOperand.getOperandNumber()) &&
             "inconsistent analysis state");
    }
    AliasingValueList result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.addAlias({callOp->getOpResult(resultIdx),
                       equivalent.has_value() ? BufferRelation::Equivalent
                                              : BufferRelation::Unknown,
                       /*isDefinite=*/equivalent.has_value()});
    return result;
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    // The callee was already bufferized, so we can directly take the type from
    // its signature.
    FunctionType funcType = funcOp.getFunctionType();
    return cast<BaseMemRefType>(
        funcType.getResult(cast<OpResult>(value).getResultNumber()));
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
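  ///
  /// For illustration (the exact memref layouts depend on the bufferization
  /// options and on the bufferized callee signature), a call such as
  ///   %0 = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  /// is roughly rewritten to
  ///   %1 = func.call @callee(%m) : (memref<?xf32, strided<[?], offset: ?>>)
  ///            -> memref<?xf32, strided<[?], offset: ?>>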
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);

    // 1. Compute the result types of the new CallOp.
    SmallVector<Type> resultTypes;
    for (Value result : callOp.getResults()) {
      Type returnType = result.getType();
      if (!isa<TensorType>(returnType)) {
        // Non-tensor values are returned.
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      FailureOr<BaseMemRefType> resultType =
          bufferization::getBufferType(result, options);
      if (failed(resultType))
        return failure();
      resultTypes.push_back(*resultType);
    }

    // 2. Rewrite tensor operands as memrefs based on the type of the already
    // bufferized callee.
    SmallVector<Value> newOperands;
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    for (OpOperand &opOperand : callOp->getOpOperands()) {
      // Non-tensor operands are just copied.
      if (!isa<TensorType>(opOperand.get().getType())) {
        newOperands.push_back(opOperand.get());
        continue;
      }

      // Retrieve buffers for tensor operands.
      FailureOr<Value> maybeBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the memref type of the callee does not match, introduce an extra
      // memref.cast that will either canonicalize away or fail compilation
      // until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands.push_back(buffer);
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());

    return success();
  }
};

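/// Bufferization model for func.return. The op itself is left alone here;
/// return values are rewritten as part of the enclosing FuncOp.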
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

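/// Bufferization model for func.func. Rewrites the function signature (tensor
/// arguments and results become memrefs) and bufferizes the body, including
/// blocks reachable through unstructured control flow.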
struct FuncOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          FuncOpInterface, FuncOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool hasTensorSemantics(Operation *op) const {
    auto isaTensor = llvm::IsaPred<TensorType>;

    // A function has tensor semantics if it has tensor arguments/results.
    auto funcOp = cast<FuncOp>(op);
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    if (hasTensorArg || hasTensorResult)
      return true;

    // It also has tensor semantics if it has tensor block arguments.
    // TODO: Decouple bufferization of unstructured control flow from
    // BufferizableOpInterface implementations. We should only care about
    // region entry block arguments here (which are already covered by the
    // argument types of the function).
    for (Block &block : funcOp.getBody())
      if (any_of(block.getArgumentTypes(), isaTensor))
        return true;

    return false;
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto funcOp = cast<FuncOp>(op);
    auto bbArg = cast<BlockArgument>(value);

    // Function arguments are special.
    if (bbArg.getOwner() == &funcOp.getBody().front())
      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
                                          options);

    return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
        getBufferType(op, value, options, invocationStack);
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto funcOp = cast<func::FuncOp>(op);
    // TODO: func.func with multiple returns are not supported.
    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
      return op->emitOpError("op without unique func.return is not supported");
    return success();
  }

  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
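  ///
  /// For illustration (the exact result layouts depend on the bufferization
  /// options), a signature such as
  ///   func.func @foo(%t: tensor<?xf32>) -> tensor<?xf32>
  /// is roughly rewritten to
  ///   func.func @foo(%m: memref<?xf32, strided<[?], offset: ?>>)
  ///       -> memref<?xf32, strided<[?], offset: ?>>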
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (isa<TensorType>(argType)) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors atm.
    if (funcOp.isExternal()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (isa<TensorType>(resultType))
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Bufferize every block.
    for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      // Note: If `inferFunctionResultLayout = true`, casts are later folded
      // away.
      BaseMemRefType resultType = options.functionArgTypeConverterFn(
          tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
          options);
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.getOperandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = dyn_cast<BlockArgument>(value);
    assert(bbArg && "expected BlockArgument");

    // Non-entry block arguments are always writable. (They may alias with
    // values that are not writable, which will turn them into read-only.)
    if (bbArg.getOwner() != &funcOp.getBody().front())
      return true;

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

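/// Attach the external models above to the func dialect ops. A client would
/// typically call this during dialect registration, e.g. (hypothetical setup):
///   DialectRegistry registry;
///   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
///       registry);
///   context.appendDialectRegistry(registry);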
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}