//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
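
// For example, a ReturnOp whose first operand is equivalent to bbArg #0 and
// whose second operand has no equivalent bbArg ends up tagged with
// `__equivalent_func_args__ = [0, -1]` (illustrative; -1 is the "no
// equivalent bbArg" placeholder used in the initialization above).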

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Only a single return-terminated block is supported in the function.
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (isa<RankedTensorType>(returnVal.get().getType()))
      for (BlockArgument bbArg : funcOp.getArguments())
        if (isa<RankedTensorType>(bbArg.getType())) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
        }

  return success();
}
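
// Example: In the function below, return value #0 is equivalent to (and thus
// also aliases) bbArg #0, so `equivalentFuncArgs` maps 0 -> 0 and
// `aliasingReturnVals` records the alias (illustrative sketch):
// ```
// func.func @f(%t : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```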

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}
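
// The annotation surfaces as an argument attribute in the IR, e.g. (a sketch,
// assuming `kBufferAccessAttrName` is spelled `bufferization.access`):
// ```
// func.func @f(%t : tensor<?xf32> {bufferization.access = "read-write"})
// ```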

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                OneShotAnalysisState &state,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}
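
// Example: If the analyzed callee returns its tensor bbArg #0 as return value
// #0 and the corresponding operand of the CallOp bufferizes in-place, then in
// ```
// %r = call @callee(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
// ```
// `%r` and `%t` end up in the same equivalence class.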

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of func::CallOp it contains.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove function operations that do not call any of the
  // functions remaining in the callCounter map and add them to the worklist.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto callee : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[callee]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}
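
// Example (sketch): for a module
// ```
// func.func @caller() { ... call @callee(...) ... }
// func.func @callee(%t : tensor<?xf32>) -> tensor<?xf32> { ... }
// ```
// `orderedFuncOps` is [@callee, @caller] and `callerMap[@callee]` contains the
// CallOp nested inside @caller.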

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more precise memref type can potentially be used
/// for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.getSource());
      resultTypes.push_back(castOp.getSource().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
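
// Example (sketch): if the bufferized function body ends in
// ```
// %c = memref.cast %alloc
//     : memref<10xf32> to memref<?xf32, strided<[?], offset: ?>>
// return %c : memref<?xf32, strided<[?], offset: ?>>
// ```
// the cast is folded away, the ReturnOp returns %alloc directly, and the
// function result type is tightened to memref<10xf32>.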

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op))
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}
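
// Typical usage (a sketch; assumes this is called from a pass on ModuleOp, so
// `getOperation()` and `signalPassFailure()` are the usual pass hooks):
// ```
// OneShotBufferizationOptions options;
// options.bufferizeFunctionBoundaries = true;
// if (failed(runOneShotModuleBufferize(getOperation(), options)))
//   signalPassFailure();
// ```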
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops in these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
    return failure();
  return success();
}