1//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
10#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
11#include "mlir/Dialect/SCF/IR/SCF.h"
12
13using namespace mlir;
14using namespace mlir::bufferization;
15
16namespace {
17/// The `scf.forall.in_parallel` terminator is special in a few ways:
18/// * It does not implement the BranchOpInterface or
19/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
20/// which is not supported by BufferDeallocation.
21/// * It has a graph-like region which only allows one specific tensor op
22/// * After bufferization the nested region is always empty
23/// For these reasons we provide custom deallocation logic via this external
24/// model.
25///
26/// Example:
27/// ```mlir
28/// scf.forall (%arg1) in (%arg0) {
29/// %alloc = memref.alloc() : memref<2xf32>
30/// ...
31/// <implicit in_parallel terminator here>
32/// }
33/// ```
34/// gets transformed to
35/// ```mlir
36/// scf.forall (%arg1) in (%arg0) {
37/// %alloc = memref.alloc() : memref<2xf32>
38/// ...
39/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
40/// <implicit in_parallel terminator here>
41/// }
42/// ```
43struct InParallelOpInterface
44 : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
45 scf::InParallelOp> {
46 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
47 const DeallocationOptions &options) const {
48 auto inParallelOp = cast<scf::InParallelOp>(Val: op);
49 if (!inParallelOp.getBody()->empty())
50 return op->emitError(message: "only supported when nested region is empty");
51
52 SmallVector<Value> updatedOperandOwnership;
53 return deallocation_impl::insertDeallocOpForReturnLike(
54 state, op, operands: {}, updatedOperandOwnerships&: updatedOperandOwnership);
55 }
56};
57
58struct ReduceReturnOpInterface
59 : public BufferDeallocationOpInterface::ExternalModel<
60 ReduceReturnOpInterface, scf::ReduceReturnOp> {
61 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
62 const DeallocationOptions &options) const {
63 auto reduceReturnOp = cast<scf::ReduceReturnOp>(Val: op);
64 if (isa<BaseMemRefType>(Val: reduceReturnOp.getOperand().getType()))
65 return op->emitError(message: "only supported when operand is not a MemRef");
66
67 SmallVector<Value> updatedOperandOwnership;
68 return deallocation_impl::insertDeallocOpForReturnLike(
69 state, op, operands: {}, updatedOperandOwnerships&: updatedOperandOwnership);
70 }
71};
72
73} // namespace
74
75void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
76 DialectRegistry &registry) {
77 registry.addExtension(extensionFn: +[](MLIRContext *ctx, SCFDialect *dialect) {
78 InParallelOp::attachInterface<InParallelOpInterface>(context&: *ctx);
79 ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(context&: *ctx);
80 });
81}
82

source code of mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp