1#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
2#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
3
4#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
5#include "mlir/Pass/Pass.h"
6
7namespace mlir {
// Forward declarations of the mlir types that the declarations in this header
// refer to.
class FunctionOpInterface;
class ModuleOp;
class OpBuilder;
class RewritePatternSet;
class SymbolTable;
13
// Forward declaration of func::FuncOp.
namespace func {
class FuncOp;
} // namespace func
17
18namespace bufferization {
// Forward declaration; taken by reference by createOneShotBufferizePass.
struct OneShotBufferizationOptions;
20
21//===----------------------------------------------------------------------===//
22// Passes
23//===----------------------------------------------------------------------===//
24
25#define GEN_PASS_DECL
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27
28/// Creates an instance of the BufferDeallocation pass to free all allocated
29/// buffers.
30std::unique_ptr<Pass> createBufferDeallocationPass();
31
32/// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all
33/// allocated buffers.
34std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
35 DeallocationOptions options = DeallocationOptions());
36
37/// Creates a pass that optimizes `bufferization.dealloc` operations. For
38/// example, it reduces the number of alias checks needed at runtime using
39/// static alias analysis.
40std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
41
42/// Creates an instance of the LowerDeallocations pass to lower
43/// `bufferization.dealloc` operations to the `memref` dialect.
44std::unique_ptr<Pass> createLowerDeallocationsPass();
45
46/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
47/// given pattern set for use in other transformation passes.
48void populateBufferizationDeallocLoweringPattern(
49 RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
50
51/// Construct the library function needed for the fully generic
52/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
53/// The function can then be called at bufferization dealloc sites to determine
54/// aliasing and ownership.
55///
56/// The generated function takes two memrefs of indices and three memrefs of
57/// booleans as arguments:
58/// * The first argument A should contain the result of the
59/// extract_aligned_pointer_as_index operation applied to the memrefs to be
60/// deallocated
61/// * The second argument B should contain the result of the
62/// extract_aligned_pointer_as_index operation applied to the memrefs to be
63/// retained
64/// * The third argument C should contain the conditions as passed directly
65/// to the deallocation operation.
66/// * The fourth argument D is used to pass results to the caller. Those
67/// represent the condition under which the memref at the corresponding
68/// position in A should be deallocated.
69/// * The fifth argument E is used to pass results to the caller. It
70/// provides the ownership value corresponding the the memref at the same
71/// position in B
72///
73/// This helper function is supposed to be called once for each
74/// `bufferization.dealloc` operation to determine the deallocation need and new
75/// ownership indicator for the retained values, but does not perform the
76/// deallocation itself.
77///
78/// Generated code:
79/// ```
80/// func.func @dealloc_helper(
81/// %dyn_dealloc_base_pointer_list: memref<?xindex>,
82/// %dyn_retain_base_pointer_list: memref<?xindex>,
83/// %dyn_cond_list: memref<?xi1>,
84/// %dyn_dealloc_cond_out: memref<?xi1>,
85/// %dyn_ownership_out: memref<?xi1>) {
86/// %c0 = arith.constant 0 : index
87/// %c1 = arith.constant 1 : index
88/// %true = arith.constant true
89/// %false = arith.constant false
90/// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
91/// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
92/// // Zero initialize result buffer.
93/// scf.for %i = %c0 to %num_retain_memrefs step %c1 {
94/// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
95/// }
96/// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
97/// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
98/// %cond = memref.load %dyn_cond_list[%i]
99/// // Check for aliasing with retained memrefs.
100/// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
101/// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
102/// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
103/// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
104/// scf.if %does_alias {
105/// %curr_ownership = memref.load %dyn_ownership_out[%j]
106/// %updated_ownership = arith.ori %curr_ownership, %cond : i1
107/// memref.store %updated_ownership, %dyn_ownership_out[%j]
108/// }
109/// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
110/// %updated_aggregate = arith.andi %does_not_alias_aggregated,
111/// %does_not_alias : i1
112/// scf.yield %updated_aggregate : i1
113/// }
114/// // Check for aliasing with dealloc memrefs in the list before the
115/// // current one, i.e.,
116/// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
117/// // %dyn_dealloc_base_pointer[i])`
118/// %does_not_alias_any = scf.for %j = %c0 to %i step %c1
119/// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
120/// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
121/// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
122/// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
123/// scf.yield %updated_alias_agg : i1
124/// }
125/// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
126/// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
127/// }
128/// return
129/// }
130/// ```
131func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
132 SymbolTable &symbolTable);
133
134/// Run buffer deallocation.
135LogicalResult deallocateBuffers(Operation *op);
136
137/// Run the ownership-based buffer deallocation.
138LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op,
139 DeallocationOptions options);
140
141/// Creates a pass that moves allocations upwards to reduce the number of
142/// required copies that are inserted during the BufferDeallocation pass.
143std::unique_ptr<Pass> createBufferHoistingPass();
144
145/// Creates a pass that moves allocations upwards out of loops. This avoids
146/// reallocations inside of loops.
147std::unique_ptr<Pass> createBufferLoopHoistingPass();
148
149// Options struct for BufferResultsToOutParams pass.
150// Note: defined only here, not in tablegen.
151struct BufferResultsToOutParamsOpts {
152 /// Memcpy function: Generate a memcpy between two memrefs.
153 using MemCpyFn =
154 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
155
156 // Filter function; returns true if the function should be converted.
157 // Defaults to true, i.e. all functions are converted.
158 llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
159 return true;
160 };
161
162 /// Memcpy function; used to create a copy between two memrefs.
163 /// If this is empty, memref.copy is used.
164 std::optional<MemCpyFn> memCpyFn;
165
166 /// If true, the pass adds a "bufferize.result" attribute to each output
167 /// parameter.
168 bool addResultAttribute = false;
169};
170
171/// Creates a pass that converts memref function results to out-params.
172std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
173 const BufferResultsToOutParamsOpts &options = {});
174
175/// Replace buffers that are returned from a function with an out parameter.
176/// Also update all call sites.
177LogicalResult
178promoteBufferResultsToOutParams(ModuleOp module,
179 const BufferResultsToOutParamsOpts &options);
180
181/// Creates a pass that drops memref function results that are equivalent to a
182/// function argument.
183std::unique_ptr<Pass> createDropEquivalentBufferResultsPass();
184
185/// Create a pass that rewrites tensor.empty to bufferization.alloc_tensor.
186std::unique_ptr<Pass> createEmptyTensorToAllocTensorPass();
187
188/// Drop all memref function results that are equivalent to a function argument.
189LogicalResult dropEquivalentBufferResults(ModuleOp module);
190
191/// Creates a pass that finalizes a partial bufferization by removing remaining
192/// bufferization.to_tensor and bufferization.to_memref operations.
193std::unique_ptr<OperationPass<func::FuncOp>> createFinalizingBufferizePass();
194
195/// Create a pass that bufferizes all ops that implement BufferizableOpInterface
196/// with One-Shot Bufferize.
197std::unique_ptr<Pass> createOneShotBufferizePass();
198
199/// Create a pass that bufferizes all ops that implement BufferizableOpInterface
200/// with One-Shot Bufferize and the specified bufferization options.
201std::unique_ptr<Pass>
202createOneShotBufferizePass(const OneShotBufferizationOptions &options);
203
204/// Creates a pass that promotes heap-based allocations to stack-based ones.
205/// Only buffers smaller than the provided size are promoted.
206/// Dynamic shaped buffers are promoted up to the given rank.
207std::unique_ptr<Pass>
208createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
209 unsigned maxRankOfAllocatedMemRef = 1);
210
211/// Creates a pass that promotes heap-based allocations to stack-based ones.
212/// Only buffers smaller with `isSmallAlloc(alloc) == true` are promoted.
213std::unique_ptr<Pass>
214createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
215
216/// Create a pass that tries to eliminate tensor.empty ops that are anchored on
217/// insert_slice ops.
218std::unique_ptr<Pass> createEmptyTensorEliminationPass();
219
220/// Create a pass that bufferizes ops from the bufferization dialect.
221std::unique_ptr<Pass> createBufferizationBufferizePass();
222
223//===----------------------------------------------------------------------===//
224// Registration
225//===----------------------------------------------------------------------===//
226
227/// Generate the code for registering passes.
228#define GEN_PASS_REGISTRATION
229#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
230
231} // namespace bufferization
232} // namespace mlir
233
234#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
235

// Source code of mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h