1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
10#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
13#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/Operation.h"
18#include <optional>
19
20namespace mlir {
21/// Return all func.return ops in the given function.
22SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
23 SmallVector<func::ReturnOp> result;
24 for (Block &b : funcOp.getBody())
25 if (auto returnOp = dyn_cast<func::ReturnOp>(Val: b.getTerminator()))
26 result.push_back(Elt: returnOp);
27 return result;
28}
29
30namespace bufferization {
31namespace func_ext {
32
33void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
34 analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
35 auto createdEquiv = equivalentFuncArgs.try_emplace(Key: funcOp, Args: IndexMapping());
36 auto createdAliasingResults =
37 aliasingReturnVals.try_emplace(Key: funcOp, Args: IndexToIndexListMapping());
38 auto createdRead = readBbArgs.try_emplace(Key: funcOp, Args: BbArgIndexSet());
39 auto createdWritten = writtenBbArgs.try_emplace(Key: funcOp, Args: BbArgIndexSet());
40 (void)createdEquiv;
41 (void)createdAliasingResults;
42 (void)createdRead;
43 (void)createdWritten;
44#ifndef NDEBUG
45 assert(createdEquiv.second && "equivalence info exists already");
46 assert(createdAliasingResults.second && "aliasing info exists already");
47 assert(createdRead.second && "bbarg access info exists already");
48 assert(createdWritten.second && "bbarg access info exists already");
49#endif // NDEBUG
50}
51
52/// Return the index-th bufferized function argument type. This assumes that the
53/// specified argument is a tensor. If the tensor is ranked, a layout map may be
54/// specified by the user (as per `options.functionArgTypeConverterFn`).
55static BaseMemRefType
56getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
57 const BufferizationOptions &options) {
58 auto tensorType =
59 dyn_cast<TensorType>(Val: funcOp.getFunctionType().getInput(i: index));
60 assert(tensorType && "expected TensorType");
61
62 BaseMemRefType memrefType = options.functionArgTypeConverterFn(
63 tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
64
65 auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
66 index, name: BufferizationDialect::kBufferLayoutAttrName);
67 if (!layoutAttr)
68 return memrefType;
69
70 auto rankedMemrefType = dyn_cast<MemRefType>(Val&: memrefType);
71 assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
72 return MemRefType::get(shape: rankedMemrefType.getShape(),
73 elementType: rankedMemrefType.getElementType(), layout: layoutAttr,
74 memorySpace: rankedMemrefType.getMemorySpace());
75}
76
77/// Return the FuncOp called by `callOp`.
78static FuncOp getCalledFunction(CallOpInterface callOp,
79 SymbolTableCollection &symbolTables) {
80 SymbolRefAttr sym =
81 llvm::dyn_cast_if_present<SymbolRefAttr>(Val: callOp.getCallableForCallee());
82 if (!sym)
83 return nullptr;
84 return dyn_cast_or_null<FuncOp>(
85 Val: symbolTables.lookupNearestSymbolFrom(from: callOp, symbol: sym));
86}
87
88/// Return the FuncOp called by `callOp`.
89static FuncOp getCalledFunction(CallOpInterface callOp,
90 const AnalysisState &state) {
91 auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
92
93 if (auto *funcAnalysisState =
94 oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
95 // Use the cached symbol tables.
96 return getCalledFunction(callOp, symbolTables&: funcAnalysisState->symbolTables);
97 }
98
99 SymbolTableCollection symbolTables;
100 return getCalledFunction(callOp, symbolTables);
101}
102
103/// Get FuncAnalysisState.
104static const FuncAnalysisState &
105getFuncAnalysisState(const AnalysisState &state) {
106 assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
107 auto *result = static_cast<const OneShotAnalysisState &>(state)
108 .getExtension<FuncAnalysisState>();
109 assert(result && "FuncAnalysisState does not exist");
110 return *result;
111}
112
113/// Return the state (phase) of analysis of the FuncOp.
114static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
115 FuncOp funcOp) {
116 if (!isa<OneShotAnalysisState>(Val: state))
117 return FuncOpAnalysisState::NotAnalyzed;
118 auto *funcState = static_cast<const OneShotAnalysisState &>(state)
119 .getExtension<FuncAnalysisState>();
120 if (!funcState)
121 return FuncOpAnalysisState::NotAnalyzed;
122 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
123 auto it = analyzedFuncOps.find(Val: funcOp);
124 if (it == analyzedFuncOps.end())
125 return FuncOpAnalysisState::NotAnalyzed;
126 return it->second;
127}
128
129/// Return the index of the bbArg in the given FuncOp that is equivalent to the
130/// specified return value (if any).
131static std::optional<int64_t>
132getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
133 int64_t returnValIdx) {
134 auto funcOpIt = state.equivalentFuncArgs.find(Val: funcOp);
135 if (funcOpIt == state.equivalentFuncArgs.end())
136 // No equivalence info stores for funcOp.
137 return std::nullopt;
138
139 auto retValIt = funcOpIt->getSecond().find(Val: returnValIdx);
140 if (retValIt == funcOpIt->getSecond().end())
141 // Return value has no equivalent bbArg.
142 return std::nullopt;
143
144 return retValIt->getSecond();
145}
146
147struct CallOpInterface
148 : public BufferizableOpInterface::ExternalModel<CallOpInterface,
149 func::CallOp> {
150 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
151 const AnalysisState &state) const {
152 func::CallOp callOp = cast<func::CallOp>(Val: op);
153 FuncOp funcOp = getCalledFunction(callOp, state);
154 assert(funcOp && "expected CallOp to a FuncOp");
155
156 if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
157 // FuncOp not analyzed yet. Assume that OpOperand is read.
158 return true;
159
160 const FuncAnalysisState &funcState = getFuncAnalysisState(state);
161 return funcState.readBbArgs.lookup(Val: funcOp).contains(
162 V: opOperand.getOperandNumber());
163 }
164
165 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
166 const AnalysisState &state) const {
167 func::CallOp callOp = cast<func::CallOp>(Val: op);
168 FuncOp funcOp = getCalledFunction(callOp, state);
169 assert(funcOp && "expected CallOp to a FuncOp");
170
171 if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
172 // FuncOp not analyzed yet. Assume that OpOperand is written.
173 return true;
174
175 const FuncAnalysisState &funcState = getFuncAnalysisState(state);
176 return funcState.writtenBbArgs.lookup(Val: funcOp).contains(
177 V: opOperand.getOperandNumber());
178 }
179
180 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
181 const AnalysisState &state) const {
182 func::CallOp callOp = cast<func::CallOp>(Val: op);
183 FuncOp funcOp = getCalledFunction(callOp, state);
184 assert(funcOp && "expected CallOp to a FuncOp");
185 if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
186 // FuncOp not analyzed yet. Any OpResult may be aliasing.
187 return detail::unknownGetAliasingValues(opOperand);
188
189 // Get aliasing results from state.
190 const FuncAnalysisState &funcState = getFuncAnalysisState(state);
191 auto aliasingReturnVals =
192 funcState.aliasingReturnVals.lookup(Val: funcOp).lookup(
193 Val: opOperand.getOperandNumber());
194
195 // Check if the aliasing OpResult is equivalent to the OpOperand.
196 std::optional<int64_t> equivalent = {};
197 if (aliasingReturnVals.size() == 1) {
198 equivalent = getEquivalentFuncArgIdx(funcOp, state: funcState,
199 returnValIdx: aliasingReturnVals.front());
200 assert((!equivalent.has_value() ||
201 *equivalent == opOperand.getOperandNumber()) &&
202 "inconsistent analysis state");
203 }
204 AliasingValueList result;
205 for (int64_t resultIdx : aliasingReturnVals)
206 result.addAlias(alias: {callOp->getOpResult(idx: resultIdx),
207 equivalent.has_value() ? BufferRelation::Equivalent
208 : BufferRelation::Unknown,
209 /*isDefinite=*/equivalent.has_value()});
210 return result;
211 }
212
213 FailureOr<BufferLikeType>
214 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
215 const BufferizationState &state,
216 SmallVector<Value> &invocationStack) const {
217 auto callOp = cast<func::CallOp>(Val: op);
218
219 // TODO Avoid recomputing the symbol tables every time.
220 SymbolTableCollection symbolTable;
221
222 FuncOp funcOp = getCalledFunction(callOp, symbolTables&: symbolTable);
223 assert(funcOp && "expected CallOp to a FuncOp");
224
225 // If the callee was already bufferized, we can directly take the type from
226 // its signature.
227 FunctionType funcType = funcOp.getFunctionType();
228 Type resultType =
229 funcType.getResult(i: cast<OpResult>(Val&: value).getResultNumber());
230 if (auto bufferizedType = dyn_cast<BaseMemRefType>(Val&: resultType))
231 return cast<BufferLikeType>(Val&: bufferizedType);
232
233 // Otherwise, call the type converter to compute the bufferized type.
234 auto tensorType = cast<TensorType>(Val&: resultType);
235 return cast<BufferLikeType>(Val: options.functionArgTypeConverterFn(
236 tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
237 options));
238 }
239
240 /// All function arguments are writable. It is the responsibility of the
241 /// CallOp to insert buffer copies where necessary.
242 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
243 const BufferizationOptions &options,
244 BufferizationState &state) const {
245 func::CallOp callOp = cast<func::CallOp>(Val: op);
246
247 // 1. Compute the result types of the new CallOp.
248 SmallVector<Type> resultTypes;
249 for (Value result : callOp.getResults()) {
250 Type returnType = result.getType();
251 if (!isa<TensorType>(Val: returnType)) {
252 // Non-tensor values are returned.
253 resultTypes.push_back(Elt: returnType);
254 continue;
255 }
256
257 // Returning a memref.
258 FailureOr<BufferLikeType> resultType =
259 bufferization::getBufferType(value: result, options, state);
260 if (failed(Result: resultType))
261 return failure();
262 resultTypes.push_back(Elt: *resultType);
263 }
264
265 // 2. Rewrite tensor operands as memrefs based on type of the already
266 // bufferized callee.
267 SmallVector<Value> newOperands;
268
269 FuncOp funcOp = getCalledFunction(callOp, symbolTables&: state.getSymbolTables());
270 assert(funcOp && "expected CallOp to a FuncOp");
271 FunctionType funcType = funcOp.getFunctionType();
272
273 for (OpOperand &opOperand : callOp->getOpOperands()) {
274 // Non-tensor operands are just copied.
275 if (!isa<TensorType>(Val: opOperand.get().getType())) {
276 newOperands.push_back(Elt: opOperand.get());
277 continue;
278 }
279
280 // Retrieve buffers for tensor operands.
281 FailureOr<Value> maybeBuffer =
282 getBuffer(rewriter, value: opOperand.get(), options, state);
283 if (failed(Result: maybeBuffer))
284 return failure();
285 Value buffer = *maybeBuffer;
286
287 // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
288 auto memRefType = funcType.getInput(i: opOperand.getOperandNumber());
289 if (!isa<BaseMemRefType>(Val: memRefType)) {
290 // The called function was not bufferized yet. This can happen when
291 // there cycles in the function call graph. Compute the bufferized
292 // result type.
293 FailureOr<BufferLikeType> maybeBufferType =
294 bufferization::getBufferType(
295 value: funcOp.getArgument(idx: opOperand.getOperandNumber()), options,
296 state);
297 if (failed(Result: maybeBufferType))
298 return failure();
299 memRefType = *maybeBufferType;
300 }
301
302 // Since we don't yet have a clear layout story, to_buffer may
303 // conservatively turn tensors into more dynamic memref than necessary.
304 // If the memref type of the callee fails, introduce an extra memref.cast
305 // that will either canonicalize away or fail compilation until we can do
306 // something better. Insert a reallocation + copy if it cannot be
307 // statically guaranteed that a direct cast would be valid.
308 if (buffer.getType() != memRefType) {
309 auto memrefDstType = dyn_cast<MemRefType>(Val&: memRefType);
310 assert(memrefDstType &&
311 "buffer layout not supported on unranked tensors");
312 FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
313 b&: rewriter, value: buffer, type: memrefDstType, options);
314 if (failed(Result: replacement))
315 return failure();
316 buffer = *replacement;
317 }
318 newOperands.push_back(Elt: buffer);
319 }
320
321 // 3. Create the new CallOp.
322 Operation *newCallOp = rewriter.create<func::CallOp>(
323 location: callOp.getLoc(), args: funcOp.getSymName(), args&: resultTypes, args&: newOperands);
324 newCallOp->setAttrs(callOp->getAttrs());
325
326 // 4. Replace the old op with the new op.
327 replaceOpWithBufferizedValues(rewriter, op: callOp, values: newCallOp->getResults());
328
329 return success();
330 }
331};
332
333struct ReturnOpInterface
334 : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
335 func::ReturnOp> {
336 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
337 const AnalysisState &state) const {
338 return true;
339 }
340
341 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
342 const AnalysisState &state) const {
343 return false;
344 }
345
346 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
347 const AnalysisState &state) const {
348 return {};
349 }
350
351 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
352 const BufferizationOptions &options,
353 BufferizationState &state) const {
354#ifndef NDEBUG
355 auto returnOp = cast<func::ReturnOp>(op);
356 assert(isa<FuncOp>(returnOp->getParentOp()) &&
357 "only support FuncOp parent for ReturnOp");
358#endif // NDEBUG
359
360 // ReturnOps are bufferized as part of FuncOps.
361 return success();
362 }
363};
364
365struct FuncOpInterface
366 : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
367 FuncOpInterface, FuncOp> {
368
369 static bool supportsUnstructuredControlFlow() { return true; }
370
371 bool hasTensorSemantics(Operation *op) const {
372 auto isaTensor = llvm::IsaPred<TensorType>;
373
374 // A function has tensor semantics if it has tensor arguments/results.
375 auto funcOp = cast<FuncOp>(Val: op);
376 bool hasTensorArg = any_of(Range: funcOp.getArgumentTypes(), P: isaTensor);
377 bool hasTensorResult = any_of(Range: funcOp.getResultTypes(), P: isaTensor);
378 if (hasTensorArg || hasTensorResult)
379 return true;
380
381 // It also has tensor semantics if it has tensor block arguments.
382 // TODO: Decouple bufferization of unstructured control flow from
383 // BufferizableOpInterface implementations. We should only care about
384 // region entry block arguments here (which are already covered by the
385 // argument types of the function).
386 for (Block &block : funcOp.getBody())
387 if (any_of(Range: block.getArgumentTypes(), P: isaTensor))
388 return true;
389
390 return false;
391 }
392
393 AliasingOpOperandList
394 getAliasingOpOperands(Operation *op, Value value,
395 const AnalysisState &state) const {
396 return getAliasingBranchOpOperands(op, bbArg: cast<BlockArgument>(Val&: value), state);
397 }
398
399 FailureOr<BufferLikeType>
400 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
401 const BufferizationState &state,
402 SmallVector<Value> &invocationStack) const {
403 auto funcOp = cast<FuncOp>(Val: op);
404 auto bbArg = cast<BlockArgument>(Val&: value);
405
406 // Function arguments are special.
407 if (bbArg.getOwner() == &funcOp.getBody().front())
408 return cast<BufferLikeType>(
409 Val: getBufferizedFunctionArgType(funcOp, index: bbArg.getArgNumber(), options));
410
411 return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
412 getBufferType(op, value, options, state, invocationStack);
413 }
414
415 /// Rewrite function bbArgs and return values into buffer form. This function
416 /// bufferizes the function signature and the ReturnOp. When the entire
417 /// function body has been bufferized, function return types can be switched
418 /// to more concise memref types as part of `foldMemRefCasts`.
419 ///
420 /// All function bbArgs are writable unless they are explicitly marked as
421 /// read-only. Callers must insert copies when needed.
422 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
423 const BufferizationOptions &options,
424 BufferizationState &state) const {
425 auto funcOp = cast<FuncOp>(Val: op);
426 FunctionType funcType = funcOp.getFunctionType();
427
428 // Compute the argument types.
429 SmallVector<Type> argTypes;
430 for (const auto &it : llvm::enumerate(First: funcType.getInputs())) {
431 Type argType = it.value();
432 if (isa<TensorType>(Val: argType)) {
433 argTypes.push_back(
434 Elt: getBufferizedFunctionArgType(funcOp, index: it.index(), options));
435 continue;
436 }
437 argTypes.push_back(Elt: argType);
438 }
439
440 // Compute the result types.
441 SmallVector<Type> retTypes;
442 for (Type resultType : funcType.getResults()) {
443 if (auto tensorType = dyn_cast<TensorType>(Val&: resultType)) {
444 BaseMemRefType resultType = options.functionArgTypeConverterFn(
445 tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
446 options);
447 retTypes.push_back(Elt: resultType);
448 continue;
449 }
450 retTypes.push_back(Elt: resultType);
451 }
452
453 // Compute the new function type.
454 auto newFuncType = FunctionType::get(context: op->getContext(), inputs: argTypes, results: retTypes);
455
456 // If the function has no body, set the new function type and we are done.
457 if (funcOp.isExternal()) {
458 funcOp.setType(newFuncType);
459 return success();
460 }
461
462 // 1. Bufferize every block.
463 for (Block &block : funcOp.getBody())
464 if (failed(Result: bufferization::bufferizeBlockSignature(block: &block, rewriter,
465 options, state)))
466 return failure();
467
468 // 2. Bufferize the operands of the all return op.
469 for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
470 assert(returnOp->getNumOperands() == retTypes.size() &&
471 "incorrect number of return values");
472 SmallVector<Value> returnValues;
473 for (auto [returnVal, bufferizedType] :
474 llvm::zip_equal(t: returnOp->getOperands(), u&: retTypes)) {
475 auto tensorType = dyn_cast<TensorType>(Val: returnVal.getType());
476 rewriter.setInsertionPoint(returnOp);
477
478 // If not a tensor type just forward it.
479 if (!tensorType) {
480 returnValues.push_back(Elt: returnVal);
481 continue;
482 }
483
484 // Note: If `inferFunctionResultLayout = true`, casts are later folded
485 // away.
486 Value toBufferOp = rewriter.create<bufferization::ToBufferOp>(
487 location: returnOp.getLoc(), args&: bufferizedType, args&: returnVal);
488 returnValues.push_back(Elt: toBufferOp);
489 }
490
491 returnOp.getOperandsMutable().assign(values: returnValues);
492 }
493
494 // 3. Set the new function type.
495 funcOp.setType(newFuncType);
496 return success();
497 }
498
499 /// Return `true` if the given function argument is writable.
500 bool isWritable(Operation *op, Value value,
501 const AnalysisState &state) const {
502 auto funcOp = cast<FuncOp>(Val: op);
503 BlockArgument bbArg = dyn_cast<BlockArgument>(Val&: value);
504 assert(bbArg && "expected BlockArgument");
505
506 // Non-entry block arguments are always writable. (They may alias with
507 // values that are not writable, which will turn them into read-only.)
508 if (bbArg.getOwner() != &funcOp.getBody().front())
509 return true;
510
511 // "bufferization.writable" overrides other writability decisions. This is
512 // currently used for testing only.
513 if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
514 index: bbArg.getArgNumber(), name: BufferizationDialect::kWritableAttrName))
515 return writable.getValue();
516
517 // All function arguments are writable by default.
518 return true;
519 }
520};
521
522} // namespace func_ext
523} // namespace bufferization
524} // namespace mlir
525
526void mlir::bufferization::func_ext::
527 registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
528 registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) {
529 func::CallOp::attachInterface<func_ext::CallOpInterface>(context&: *ctx);
530 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(context&: *ctx);
531 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(context&: *ctx);
532 });
533}
534

// Source: mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp