//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

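/// Map the textual 'analysis-heuristic' pass option to the corresponding
/// OneShotBufferizationOptions::AnalysisHeuristic value.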
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysis-heuristic option");
}
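/// The One-Shot Bufferize pass. It either uses explicitly provided
/// bufferization options or builds them from the tablegen-defined pass
/// options. Illustrative textual invocation (option names as referenced in
/// the diagnostics in this file; the authoritative list is in Passes.td):
///
///   mlir-opt --one-shot-bufferize="bufferize-function-boundaries" input.mlir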
struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' is allowed in "
            << getArgument();
        return signalPassFailure();
      }

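      // With 'must-infer-memory-space', no default memory space is provided:
      // the nullopt return value means the memory space must be inferred from
      // the IR, and bufferization fails where it cannot be.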
      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
              return std::nullopt;
            };
      }

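      // With 'use-encoding-for-memory-space', the encoding attribute of ranked
      // tensors is reused as the memory space; unranked tensors get no default
      // memory space.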
      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
              if (auto rtt = dyn_cast<RankedTensorType>(t))
                return rtt.getEncoding();
              return std::nullopt;
            };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationState state;
    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(
              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toBufferOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toBufferOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
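    // `previous` is only set when an existing op was moved; for ops that were
    // newly created it is unset.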
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_buffer ops.
    if (isa<ToBufferOp>(op)) {
      toBufferOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_buffer ops.
  DenseSet<Operation *> &toBufferOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

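// Driver for the actual rewrite: optionally insert tensor copies
// (copy-before-write mode), collect all bufferizable ops in a worklist,
// bufferize them one by one via BufferizableOpInterface, fold away
// to_buffer(to_tensor(x)) pairs, and finally verify that no tensor ops remain
// (unless unknown ops are allowed).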
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationState &bufferizationState,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState analysisState(options);
    if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
      return failure();
  }

  // Keep track of to_buffer ops.
  DenseSet<Operation *> toBufferOps;
  op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
                                 worklist, options, statistics);
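  // Note: The worklist may grow while it is processed. The rewriter's listener
  // appends newly created bufferizable tensor ops, so iterate by index rather
  // than by iterator (the underlying vector may reallocate).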
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has "
              "multiple blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(
            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_buffer(to_tensor(x)) pairs.
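  // Schematically, %b = to_buffer(to_tensor(%m)) is rewritten to use %m
  // directly (modulo a type conversion if the two buffer types differ).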
  for (Operation *op : toBufferOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToBufferToTensorPair(
        rewriter, cast<ToBufferOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToBufferOps are allowed in the output.
    if (isa<ToTensorOp, ToBufferOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

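// Rewrite the signature of `block`: tensor block arguments are changed to
// memrefs (with to_tensor ops reconstructing tensor values for existing uses),
// and all branch-like predecessors are updated to forward buffers instead of
// tensors.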
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options, state);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
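  // Every predecessor branching to this block must now forward buffers. Each
  // forwarded tensor operand is materialized with a to_buffer op, plus a
  // memref.cast if its buffer type differs from the new block argument type.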
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options, state);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToBufferOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}