1//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
10#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13
14using namespace mlir;
15using namespace mlir::bufferization;
16
17namespace {
18/// The `scf.forall.in_parallel` terminator is special in a few ways:
19/// * It does not implement the BranchOpInterface or
20/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
21/// which is not supported by BufferDeallocation.
22/// * It has a graph-like region which only allows one specific tensor op
23/// * After bufferization the nested region is always empty
24/// For these reasons we provide custom deallocation logic via this external
25/// model.
26///
27/// Example:
28/// ```mlir
29/// scf.forall (%arg1) in (%arg0) {
30/// %alloc = memref.alloc() : memref<2xf32>
31/// ...
32/// <implicit in_parallel terminator here>
33/// }
34/// ```
35/// gets transformed to
36/// ```mlir
37/// scf.forall (%arg1) in (%arg0) {
38/// %alloc = memref.alloc() : memref<2xf32>
39/// ...
40/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
41/// <implicit in_parallel terminator here>
42/// }
43/// ```
44struct InParallelOpInterface
45 : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
46 scf::InParallelOp> {
47 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
48 const DeallocationOptions &options) const {
49 auto inParallelOp = cast<scf::InParallelOp>(op);
50 if (!inParallelOp.getBody()->empty())
51 return op->emitError(message: "only supported when nested region is empty");
52
53 SmallVector<Value> updatedOperandOwnership;
54 return deallocation_impl::insertDeallocOpForReturnLike(
55 state, op, operands: {}, updatedOperandOwnerships&: updatedOperandOwnership);
56 }
57};
58
59struct ReduceReturnOpInterface
60 : public BufferDeallocationOpInterface::ExternalModel<
61 ReduceReturnOpInterface, scf::ReduceReturnOp> {
62 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
63 const DeallocationOptions &options) const {
64 auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
65 if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
66 return op->emitError(message: "only supported when operand is not a MemRef");
67
68 SmallVector<Value> updatedOperandOwnership;
69 return deallocation_impl::insertDeallocOpForReturnLike(
70 state, op, operands: {}, updatedOperandOwnerships&: updatedOperandOwnership);
71 }
72};
73
74} // namespace
75
76void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
77 DialectRegistry &registry) {
78 registry.addExtension(extensionFn: +[](MLIRContext *ctx, SCFDialect *dialect) {
79 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
80 ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
81 });
82}
83

source code of mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp