//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_FINALIZINGBUFFERIZE
#define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
#define GEN_PASS_DEF_ONESHOTBUFFERIZE
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

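/// Materialize a tensor value from a single memref input by inserting a
/// bufferization.to_tensor op. Used as the argument/source materialization of
/// the BufferizeTypeConverter below.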
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(isa<BaseMemRefType>(inputs[0].getType()));
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = dyn_cast<MemRefType>(type);
      if (!rankedDestType)
        return nullptr;
      BufferizationOptions options;
      options.bufferAlignment = 0;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (isa<TensorType>(inputs[0].getType())) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

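/// Mark bufferization.to_tensor and bufferization.to_memref as legal so that
/// partial bufferization conversions can insert them freely.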
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getMemref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(
      typeConverter, patterns.getContext());
}

namespace {
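/// Finalize a partial bufferization: fold away all remaining
/// to_tensor/to_memref materializations and fail if any tensor-typed values
/// are left in the function.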
struct FinalizingBufferizePass
    : public bufferization::impl::FinalizingBufferizeBase<
          FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by func conversion above), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

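/// Translate the string value of the layout-map pass options into the
/// corresponding LayoutMapOption enum value.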
static LayoutMapOption parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

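/// Translate the string value of the "analysis-heuristic" pass option into the
/// corresponding AnalysisHeuristic enum value.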
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysisheuristic option");
}

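/// The One-Shot Bufferize pass: bufferizes the IR in a single analysis-driven
/// sweep instead of dialect-by-dialect partial bufferization. Options can be
/// supplied programmatically (OneShotBufferizationOptions) or as pass options,
/// e.g. (a sketch of a typical invocation; see Passes.td for the exact option
/// spellings):
///   mlir-opt foo.mlir -one-shot-bufferize="bufferize-function-boundaries"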
struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() = default;

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(
          parseLayoutMapOption(functionBoundaryTypeConversion));
      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

namespace {
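/// A partial bufferization pass that bufferizes only ops from the
/// bufferization dialect itself, using the default partial bufferization
/// options (unknown ops allowed, copies inserted before writes).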
struct BufferizationBufferizePass
    : public bufferization::impl::BufferizationBufferizeBase<
          BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

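/// Construct a One-Shot Bufferize pass instance. Typical programmatic usage
/// (a sketch; the pass is normally added to a module-level pass manager):
///   OneShotBufferizationOptions options;
///   options.bufferizeFunctionBoundaries = true;
///   pm.addPass(bufferization::createOneShotBufferizePass(options));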
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

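/// Bufferize all ops with tensor semantics within the given op. Ops are
/// bufferized top to bottom so that the exact memref type of each operand is
/// known when an op is rewritten. Newly created ops with tensor semantics are
/// appended to the worklist by the BufferizationRewriter and bufferized as
/// well. Afterwards, to_memref(to_tensor(x)) pairs are folded and dead
/// to_tensor ops are removed; unless allowUnknownOps is set, any remaining
/// unbufferized tensor op is reported as an error.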
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has multiple "
              "blocks");

    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(
        rewriter, cast<ToMemrefOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

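/// Bufferize the signature of the given block: convert all tensor block
/// arguments to their buffer (memref) types, reconnect existing uses through
/// bufferization.to_tensor ops, and rewrite all branch ops that forward
/// operands into this block to pass memrefs (inserting casts where needed).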
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp =
          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}

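/// Return BufferizationOptions configured for partial bufferization passes:
/// unknown ops are allowed to remain, buffer copies are inserted before every
/// write instead of relying on the One-Shot analysis, and unknown tensor types
/// are converted to memrefs with a static identity layout.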
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.copyBeforeWrite = true;
  options.enforceAliasingInvariants = false;
  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        cast<TensorType>(value.getType()), memorySpace);
  };
  options.opFilter.allowDialect<BufferizationDialect>();
  return options;
}