//===- FuncBufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include <optional>

namespace mlir {
/// Return all func.return ops in the given function.
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
  SmallVector<func::ReturnOp> result;
  for (Block &b : funcOp.getBody())
    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
      result.push_back(returnOp);
  return result;
}

namespace bufferization {
namespace func_ext {

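/// Mark `funcOp` as in progress and create empty entries in all maps that
/// store per-function analysis results. None of the maps may already contain
/// an entry for `funcOp`.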
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user (as per `options.functionArgTypeConverterFn`).
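///
/// Example (assuming the default type converter, which assigns a fully
/// dynamic layout map to function arguments):
///   tensor<4xf32> is bufferized to memref<4xf32, strided<[?], offset: ?>>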
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

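  // An explicit layout may be requested via a function arg attribute, e.g.:
  //   func.func @f(%arg0: tensor<4xf32>
  //                {bufferization.buffer_layout = affine_map<(d0) -> (d0)>})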
  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(rankedMemrefType.getShape(),
                         rankedMemrefType.getElementType(), layoutAttr,
                         rankedMemrefType.getMemorySpace());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
                                SymbolTableCollection &symbolTables) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      symbolTables.lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncOp called by `callOp`. Uses the symbol tables cached in the
/// analysis state, if available.
static FuncOp getCalledFunction(CallOpInterface callOp,
                                const AnalysisState &state) {
  auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);

  if (auto *funcAnalysisState =
          oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
    // Use the cached symbol tables.
    return getCalledFunction(callOp, funcAnalysisState->symbolTables);
  }
  SymbolTableCollection symbolTables;
  return getCalledFunction(callOp, symbolTables);
}

/// Return the FuncAnalysisState extension of the given AnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
  auto *result = static_cast<const OneShotAnalysisState &>(state)
                     .getExtension<FuncAnalysisState>();
  assert(result && "FuncAnalysisState does not exist");
  return *result;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  if (!isa<OneShotAnalysisState>(state))
    return FuncOpAnalysisState::NotAnalyzed;
  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
                        .getExtension<FuncAnalysisState>();
  if (!funcState)
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static std::optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return std::nullopt;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return std::nullopt;

  return retValIt->getSecond();
}
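/// Bufferization of func.call. The call is replaced with a new func.call that
/// operates on memrefs; buffer casts/copies are inserted at the call site
/// where the caller's buffer types do not match the callee's signature.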
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp, state);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that the OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp, state);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that the OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp, state);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      return detail::unknownGetAliasingValues(opOperand);

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());

    // Check if the aliasing OpResult is equivalent to the OpOperand.
    std::optional<int64_t> equivalent = {};
    if (aliasingReturnVals.size() == 1) {
      equivalent = getEquivalentFuncArgIdx(funcOp, funcState,
                                           aliasingReturnVals.front());
      assert((!equivalent.has_value() ||
              *equivalent == opOperand.getOperandNumber()) &&
             "inconsistent analysis state");
    }
    AliasingValueList result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.addAlias({callOp->getOpResult(resultIdx),
                       equivalent.has_value() ? BufferRelation::Equivalent
                                              : BufferRelation::Unknown,
                       /*isDefinite=*/equivalent.has_value()});
    return result;
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto callOp = cast<func::CallOp>(op);

    // TODO: Avoid recomputing the symbol tables every time.
    SymbolTableCollection symbolTable;

    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
    assert(funcOp && "expected CallOp to a FuncOp");

    // If the callee was already bufferized, we can directly take the type from
    // its signature.
    FunctionType funcType = funcOp.getFunctionType();
    Type resultType =
        funcType.getResult(cast<OpResult>(value).getResultNumber());
    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
      return bufferizedType;

    // Otherwise, call the type converter to compute the bufferized type.
    auto tensorType = cast<TensorType>(resultType);
    return options.functionArgTypeConverterFn(
        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
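  ///
  /// Example, assuming the default (fully dynamic layout) function boundary
  /// type conversion and an already bufferized callee:
  ///
  ///   %r = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  ///
  /// is rewritten so that the tensor operand is replaced by its buffer:
  ///
  ///   %r = func.call @callee(%m)
  ///       : (memref<?xf32, strided<[?], offset: ?>>)
  ///       -> memref<?xf32, strided<[?], offset: ?>>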
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);

    // 1. Compute the result types of the new CallOp.
    SmallVector<Type> resultTypes;
    for (Value result : callOp.getResults()) {
      Type returnType = result.getType();
      if (!isa<TensorType>(returnType)) {
        // Non-tensor values are returned as-is.
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      FailureOr<BaseMemRefType> resultType =
          bufferization::getBufferType(result, options, state);
      if (failed(resultType))
        return failure();
      resultTypes.push_back(*resultType);
    }

    // 2. Rewrite tensor operands as memrefs based on the type of the already
    //    bufferized callee.
    SmallVector<Value> newOperands;

    FuncOp funcOp = getCalledFunction(callOp, state.getSymbolTables());
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    for (OpOperand &opOperand : callOp->getOpOperands()) {
      // Non-tensor operands are just copied.
      if (!isa<TensorType>(opOperand.get().getType())) {
        newOperands.push_back(opOperand.get());
        continue;
      }

      // Retrieve buffers for tensor operands.
      FailureOr<Value> maybeBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      // Caller / callee type mismatches are handled with
      // castOrReallocMemRefValue.
      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
      if (!isa<BaseMemRefType>(memRefType)) {
        // The called function was not bufferized yet. This can happen when
        // there are cycles in the function call graph. Compute the bufferized
        // result type.
        FailureOr<BaseMemRefType> maybeMemRefType =
            bufferization::getBufferType(
                funcOp.getArgument(opOperand.getOperandNumber()), options,
                state);
        if (failed(maybeMemRefType))
          return failure();
        memRefType = *maybeMemRefType;
      }

      // Since we don't yet have a clear layout story, to_buffer may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the memref type of the callee does not match, introduce an extra
      // memref.cast that will either canonicalize away or fail compilation
      // until we can do something better. Insert a reallocation + copy if it
      // cannot be statically guaranteed that a direct cast would be valid.
      if (buffer.getType() != memRefType) {
        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
        assert(memrefDstType &&
               "buffer layout not supported on unranked tensors");
        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
            rewriter, buffer, memrefDstType, options);
        if (failed(replacement))
          return failure();
        buffer = *replacement;
      }
      newOperands.push_back(buffer);
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());

    return success();
  }
};
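/// Bufferization of func.return. Return values are rewritten as part of the
/// enclosing FuncOp's bufferization, so this model only describes the
/// read/write and aliasing behavior of the op.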
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

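/// Bufferization of func.func: rewrites the function signature (tensor
/// arguments and results become memrefs) and bufferizes the block signatures
/// and return values of the body.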
struct FuncOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          FuncOpInterface, FuncOp> {

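  // A function body may contain multiple blocks, i.e., unstructured control
  // flow.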
  static bool supportsUnstructuredControlFlow() { return true; }

  bool hasTensorSemantics(Operation *op) const {
    auto isaTensor = llvm::IsaPred<TensorType>;

    // A function has tensor semantics if it has tensor arguments/results.
    auto funcOp = cast<FuncOp>(op);
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    if (hasTensorArg || hasTensorResult)
      return true;

    // It also has tensor semantics if it has tensor block arguments.
    // TODO: Decouple bufferization of unstructured control flow from
    // BufferizableOpInterface implementations. We should only care about
    // region entry block arguments here (which are already covered by the
    // argument types of the function).
    for (Block &block : funcOp.getBody())
      if (any_of(block.getArgumentTypes(), isaTensor))
        return true;

    return false;
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto funcOp = cast<FuncOp>(op);
    auto bbArg = cast<BlockArgument>(value);

    // Function arguments are special.
    if (bbArg.getOwner() == &funcOp.getBody().front())
      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
                                          options);

    return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
        getBufferType(op, value, options, state, invocationStack);
  }

  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
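  ///
  /// Example, assuming the default (fully dynamic layout) function boundary
  /// type conversion:
  ///
  ///   func.func @f(%t: tensor<?xf32>) -> tensor<?xf32>
  ///
  /// is rewritten to
  ///
  ///   func.func @f(%m: memref<?xf32, strided<[?], offset: ?>>)
  ///       -> memref<?xf32, strided<[?], offset: ?>>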
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Compute the argument types.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (isa<TensorType>(argType)) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Compute the result types.
    SmallVector<Type> retTypes;
    for (Type resultType : funcType.getResults()) {
      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
        BaseMemRefType bufferizedType = options.functionArgTypeConverterFn(
            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
            options);
        retTypes.push_back(bufferizedType);
        continue;
      }
      retTypes.push_back(resultType);
    }

    // Compute the new function type.
    auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

    // If the function has no body, set the new function type and we are done.
    if (funcOp.isExternal()) {
      funcOp.setType(newFuncType);
      return success();
    }

    // 1. Bufferize every block.
    for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // 2. Bufferize the operands of all return ops.
    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
      assert(returnOp->getNumOperands() == retTypes.size() &&
             "incorrect number of return values");
      SmallVector<Value> returnValues;
      for (auto [returnVal, bufferizedType] :
           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
        rewriter.setInsertionPoint(returnOp);

        // If not a tensor type, just forward it.
        if (!tensorType) {
          returnValues.push_back(returnVal);
          continue;
        }

        // Note: If `inferFunctionResultLayout = true`, casts are later folded
        // away.
        Value toBufferOp = rewriter.create<bufferization::ToBufferOp>(
            returnOp.getLoc(), bufferizedType, returnVal);
        returnValues.push_back(toBufferOp);
      }

      returnOp.getOperandsMutable().assign(returnValues);
    }

    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = dyn_cast<BlockArgument>(value);
    assert(bbArg && "expected BlockArgument");

    // Non-entry block arguments are always writable. (They may alias with
    // values that are not writable, which effectively makes them read-only.)
    if (bbArg.getOwner() != &funcOp.getBody().front())
      return true;

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
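    // E.g.:
    //   func.func @f(%arg0: tensor<?xf32> {bufferization.writable = false})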
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}
