1//===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for buffer optimization passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14
15#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
16#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/IR/Operation.h"
20#include <optional>
21
22using namespace mlir;
23using namespace mlir::bufferization;
24
25//===----------------------------------------------------------------------===//
26// BufferPlacementAllocs
27//===----------------------------------------------------------------------===//
28
29/// Get the start operation to place the given alloc value withing the
30// specified placement block.
31Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
32 Block *placementBlock,
33 const Liveness &liveness) {
34 // We have to ensure that we place the alloc before its first use in this
35 // block.
36 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock);
37 Operation *startOperation = livenessInfo.getStartOperation(value: allocValue);
38 // Check whether the start operation lies in the desired placement block.
39 // If not, we will use the terminator as this is the last operation in
40 // this block.
41 if (startOperation->getBlock() != placementBlock) {
42 Operation *opInPlacementBlock =
43 placementBlock->findAncestorOpInBlock(op&: *startOperation);
44 startOperation = opInPlacementBlock ? opInPlacementBlock
45 : placementBlock->getTerminator();
46 }
47
48 return startOperation;
49}
50
51/// Initializes the internal list by discovering all supported allocation
52/// nodes.
53BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
54
55/// Searches for and registers all supported allocation entries.
56void BufferPlacementAllocs::build(Operation *op) {
57 op->walk(callback: [&](MemoryEffectOpInterface opInterface) {
58 // Try to find a single allocation result.
59 SmallVector<MemoryEffects::EffectInstance, 2> effects;
60 opInterface.getEffects(effects);
61
62 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
63 llvm::copy_if(
64 Range&: effects, Out: std::back_inserter(x&: allocateResultEffects),
65 P: [=](MemoryEffects::EffectInstance &it) {
66 Value value = it.getValue();
67 return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value &&
68 isa<OpResult>(Val: value) &&
69 it.getResource() !=
70 SideEffects::AutomaticAllocationScopeResource::get();
71 });
72 // If there is one result only, we will be able to move the allocation and
73 // (possibly existing) deallocation ops.
74 if (allocateResultEffects.size() != 1)
75 return;
76 // Get allocation result.
77 Value allocValue = allocateResultEffects[0].getValue();
78 // Find the associated dealloc value and register the allocation entry.
79 std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
80 // If the allocation has > 1 dealloc associated with it, skip handling it.
81 if (!dealloc)
82 return;
83 allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc));
84 });
85}
86
87//===----------------------------------------------------------------------===//
88// BufferPlacementTransformationBase
89//===----------------------------------------------------------------------===//
90
91/// Constructs a new transformation base using the given root operation.
92BufferPlacementTransformationBase::BufferPlacementTransformationBase(
93 Operation *op)
94 : aliases(op), allocs(op), liveness(op) {}
95
96//===----------------------------------------------------------------------===//
// Global memref and symbol table utilities
98//===----------------------------------------------------------------------===//
99
100FailureOr<memref::GlobalOp>
101bufferization::getGlobalFor(arith::ConstantOp constantOp,
102 SymbolTableCollection &symbolTables,
103 uint64_t alignment, Attribute memorySpace) {
104 auto type = cast<RankedTensorType>(Val: constantOp.getType());
105 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
106 if (!moduleOp)
107 return failure();
108
109 // If we already have a global for this constant value, no need to do
110 // anything else.
111 for (Operation &op : moduleOp.getRegion().getOps()) {
112 auto globalOp = dyn_cast<memref::GlobalOp>(Val: &op);
113 if (!globalOp)
114 continue;
115 if (!globalOp.getInitialValue().has_value())
116 continue;
117 uint64_t opAlignment = globalOp.getAlignment().value_or(u: 0);
118 Attribute initialValue = globalOp.getInitialValue().value();
119 if (opAlignment == alignment && initialValue == constantOp.getValue())
120 return globalOp;
121 }
122
123 // Create a builder without an insertion point. We will insert using the
124 // symbol table to guarantee unique names.
125 OpBuilder globalBuilder(moduleOp.getContext());
126 SymbolTable &symbolTable = symbolTables.getSymbolTable(op: moduleOp);
127
128 // Create a pretty name.
129 SmallString<64> buf;
130 llvm::raw_svector_ostream os(buf);
131 interleave(c: type.getShape(), os, separator: "x");
132 os << "x" << type.getElementType();
133
134 // Add an optional alignment to the global memref.
135 IntegerAttr memrefAlignment =
136 alignment > 0 ? IntegerAttr::get(type: globalBuilder.getI64Type(), value: alignment)
137 : IntegerAttr();
138
139 // Memref globals always have an identity layout.
140 auto memrefType =
141 cast<MemRefType>(Val: getMemRefTypeWithStaticIdentityLayout(tensorType: type));
142 if (memorySpace)
143 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
144 auto global = globalBuilder.create<memref::GlobalOp>(
145 location: constantOp.getLoc(), args: (Twine("__constant_") + os.str()).str(),
146 /*sym_visibility=*/args: globalBuilder.getStringAttr(bytes: "private"),
147 /*type=*/args&: memrefType,
148 /*initial_value=*/args: cast<ElementsAttr>(Val: constantOp.getValue()),
149 /*constant=*/args: true,
150 /*alignment=*/args&: memrefAlignment);
151 symbolTable.insert(symbol: global);
152 // The symbol table inserts at the end of the module, but globals are a bit
153 // nicer if they are at the beginning.
154 global->moveBefore(existingOp: &moduleOp.front());
155 return global;
156}
157
158namespace mlir::bufferization {
159void removeSymbol(Operation *op, BufferizationState &state) {
160 SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
161 op: op->getParentWithTrait<OpTrait::SymbolTable>());
162
163 symbolTable.remove(op);
164}
165
166void insertSymbol(Operation *op, BufferizationState &state) {
167 SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
168 op: op->getParentWithTrait<OpTrait::SymbolTable>());
169
170 symbolTable.insert(symbol: op);
171}
172} // namespace mlir::bufferization
173

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp