1 | #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H |
2 | #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H |
3 | |
4 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
5 | #include "mlir/Pass/Pass.h" |
6 | |
7 | namespace mlir { |
8 | class FunctionOpInterface; |
9 | class ModuleOp; |
10 | class RewritePatternSet; |
11 | class OpBuilder; |
12 | class SymbolTable; |
13 | |
14 | namespace func { |
15 | class FuncOp; |
16 | } // namespace func |
17 | |
18 | namespace bufferization { |
19 | struct OneShotBufferizationOptions; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Passes |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | #define GEN_PASS_DECL |
26 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
27 | |
28 | /// Creates an instance of the BufferDeallocation pass to free all allocated |
29 | /// buffers. |
30 | std::unique_ptr<Pass> createBufferDeallocationPass(); |
31 | |
32 | /// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all |
33 | /// allocated buffers. |
34 | std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass( |
35 | DeallocationOptions options = DeallocationOptions()); |
36 | |
37 | /// Creates a pass that optimizes `bufferization.dealloc` operations. For |
38 | /// example, it reduces the number of alias checks needed at runtime using |
39 | /// static alias analysis. |
40 | std::unique_ptr<Pass> createBufferDeallocationSimplificationPass(); |
41 | |
42 | /// Creates an instance of the LowerDeallocations pass to lower |
43 | /// `bufferization.dealloc` operations to the `memref` dialect. |
44 | std::unique_ptr<Pass> createLowerDeallocationsPass(); |
45 | |
46 | /// Adds the conversion pattern of the `bufferization.dealloc` operation to the |
47 | /// given pattern set for use in other transformation passes. |
48 | void populateBufferizationDeallocLoweringPattern( |
49 | RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc); |
50 | |
51 | /// Construct the library function needed for the fully generic |
52 | /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass. |
53 | /// The function can then be called at bufferization dealloc sites to determine |
54 | /// aliasing and ownership. |
55 | /// |
56 | /// The generated function takes two memrefs of indices and three memrefs of |
57 | /// booleans as arguments: |
58 | /// * The first argument A should contain the result of the |
59 | /// extract_aligned_pointer_as_index operation applied to the memrefs to be |
60 | /// deallocated |
61 | /// * The second argument B should contain the result of the |
62 | /// extract_aligned_pointer_as_index operation applied to the memrefs to be |
63 | /// retained |
64 | /// * The third argument C should contain the conditions as passed directly |
65 | /// to the deallocation operation. |
66 | /// * The fourth argument D is used to pass results to the caller. Those |
67 | /// represent the condition under which the memref at the corresponding |
68 | /// position in A should be deallocated. |
69 | /// * The fifth argument E is used to pass results to the caller. It |
/// provides the ownership value corresponding to the memref at the same
71 | /// position in B |
72 | /// |
73 | /// This helper function is supposed to be called once for each |
74 | /// `bufferization.dealloc` operation to determine the deallocation need and new |
75 | /// ownership indicator for the retained values, but does not perform the |
76 | /// deallocation itself. |
77 | /// |
78 | /// Generated code: |
79 | /// ``` |
80 | /// func.func @dealloc_helper( |
81 | /// %dyn_dealloc_base_pointer_list: memref<?xindex>, |
82 | /// %dyn_retain_base_pointer_list: memref<?xindex>, |
83 | /// %dyn_cond_list: memref<?xi1>, |
84 | /// %dyn_dealloc_cond_out: memref<?xi1>, |
85 | /// %dyn_ownership_out: memref<?xi1>) { |
86 | /// %c0 = arith.constant 0 : index |
87 | /// %c1 = arith.constant 1 : index |
88 | /// %true = arith.constant true |
89 | /// %false = arith.constant false |
90 | /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0 |
91 | /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0 |
92 | /// // Zero initialize result buffer. |
93 | /// scf.for %i = %c0 to %num_retain_memrefs step %c1 { |
94 | /// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1> |
95 | /// } |
96 | /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 { |
97 | /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i] |
98 | /// %cond = memref.load %dyn_cond_list[%i] |
99 | /// // Check for aliasing with retained memrefs. |
100 | /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs |
101 | /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) { |
102 | /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j] |
103 | /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index |
104 | /// scf.if %does_alias { |
105 | /// %curr_ownership = memref.load %dyn_ownership_out[%j] |
106 | /// %updated_ownership = arith.ori %curr_ownership, %cond : i1 |
107 | /// memref.store %updated_ownership, %dyn_ownership_out[%j] |
108 | /// } |
109 | /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index |
110 | /// %updated_aggregate = arith.andi %does_not_alias_aggregated, |
111 | /// %does_not_alias : i1 |
112 | /// scf.yield %updated_aggregate : i1 |
113 | /// } |
114 | /// // Check for aliasing with dealloc memrefs in the list before the |
115 | /// // current one, i.e., |
116 | /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j], |
117 | /// // %dyn_dealloc_base_pointer[i])` |
118 | /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1 |
119 | /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) { |
120 | /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j] |
121 | /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp |
122 | /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias |
123 | /// scf.yield %updated_alias_agg : i1 |
124 | /// } |
125 | /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1 |
126 | /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1> |
127 | /// } |
128 | /// return |
129 | /// } |
130 | /// ``` |
131 | func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, |
132 | SymbolTable &symbolTable); |
133 | |
134 | /// Run buffer deallocation. |
135 | LogicalResult deallocateBuffers(Operation *op); |
136 | |
137 | /// Run the ownership-based buffer deallocation. |
138 | LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op, |
139 | DeallocationOptions options); |
140 | |
141 | /// Creates a pass that moves allocations upwards to reduce the number of |
142 | /// required copies that are inserted during the BufferDeallocation pass. |
143 | std::unique_ptr<Pass> createBufferHoistingPass(); |
144 | |
145 | /// Creates a pass that moves allocations upwards out of loops. This avoids |
146 | /// reallocations inside of loops. |
147 | std::unique_ptr<Pass> createBufferLoopHoistingPass(); |
148 | |
149 | // Options struct for BufferResultsToOutParams pass. |
150 | // Note: defined only here, not in tablegen. |
151 | struct BufferResultsToOutParamsOpts { |
152 | /// Memcpy function: Generate a memcpy between two memrefs. |
153 | using MemCpyFn = |
154 | std::function<LogicalResult(OpBuilder &, Location, Value, Value)>; |
155 | |
156 | // Filter function; returns true if the function should be converted. |
157 | // Defaults to true, i.e. all functions are converted. |
158 | llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) { |
159 | return true; |
160 | }; |
161 | |
162 | /// Memcpy function; used to create a copy between two memrefs. |
163 | /// If this is empty, memref.copy is used. |
164 | std::optional<MemCpyFn> memCpyFn; |
165 | |
166 | /// If true, the pass adds a "bufferize.result" attribute to each output |
167 | /// parameter. |
168 | bool addResultAttribute = false; |
169 | }; |
170 | |
171 | /// Creates a pass that converts memref function results to out-params. |
172 | std::unique_ptr<Pass> createBufferResultsToOutParamsPass( |
173 | const BufferResultsToOutParamsOpts &options = {}); |
174 | |
175 | /// Replace buffers that are returned from a function with an out parameter. |
176 | /// Also update all call sites. |
177 | LogicalResult |
178 | promoteBufferResultsToOutParams(ModuleOp module, |
179 | const BufferResultsToOutParamsOpts &options); |
180 | |
181 | /// Creates a pass that drops memref function results that are equivalent to a |
182 | /// function argument. |
183 | std::unique_ptr<Pass> createDropEquivalentBufferResultsPass(); |
184 | |
185 | /// Create a pass that rewrites tensor.empty to bufferization.alloc_tensor. |
186 | std::unique_ptr<Pass> createEmptyTensorToAllocTensorPass(); |
187 | |
188 | /// Drop all memref function results that are equivalent to a function argument. |
189 | LogicalResult dropEquivalentBufferResults(ModuleOp module); |
190 | |
191 | /// Creates a pass that finalizes a partial bufferization by removing remaining |
192 | /// bufferization.to_tensor and bufferization.to_memref operations. |
193 | std::unique_ptr<OperationPass<func::FuncOp>> createFinalizingBufferizePass(); |
194 | |
195 | /// Create a pass that bufferizes all ops that implement BufferizableOpInterface |
196 | /// with One-Shot Bufferize. |
197 | std::unique_ptr<Pass> createOneShotBufferizePass(); |
198 | |
199 | /// Create a pass that bufferizes all ops that implement BufferizableOpInterface |
200 | /// with One-Shot Bufferize and the specified bufferization options. |
201 | std::unique_ptr<Pass> |
202 | createOneShotBufferizePass(const OneShotBufferizationOptions &options); |
203 | |
204 | /// Creates a pass that promotes heap-based allocations to stack-based ones. |
205 | /// Only buffers smaller than the provided size are promoted. |
206 | /// Dynamic shaped buffers are promoted up to the given rank. |
207 | std::unique_ptr<Pass> |
208 | createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024, |
209 | unsigned maxRankOfAllocatedMemRef = 1); |
210 | |
211 | /// Creates a pass that promotes heap-based allocations to stack-based ones. |
/// Only buffers for which `isSmallAlloc(alloc)` returns true are promoted.
213 | std::unique_ptr<Pass> |
214 | createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc); |
215 | |
216 | /// Create a pass that tries to eliminate tensor.empty ops that are anchored on |
217 | /// insert_slice ops. |
218 | std::unique_ptr<Pass> createEmptyTensorEliminationPass(); |
219 | |
220 | /// Create a pass that bufferizes ops from the bufferization dialect. |
221 | std::unique_ptr<Pass> createBufferizationBufferizePass(); |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // Registration |
225 | //===----------------------------------------------------------------------===// |
226 | |
227 | /// Generate the code for registering passes. |
228 | #define GEN_PASS_REGISTRATION |
229 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
230 | |
231 | } // namespace bufferization |
232 | } // namespace mlir |
233 | |
234 | #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H |
235 | |