1//===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for buffer optimization passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14
15#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
16#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/IR/Operation.h"
20#include "mlir/Interfaces/ControlFlowInterfaces.h"
21#include "mlir/Interfaces/LoopLikeInterface.h"
22#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallString.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::bufferization;
29
30//===----------------------------------------------------------------------===//
31// BufferPlacementAllocs
32//===----------------------------------------------------------------------===//
33
34/// Get the start operation to place the given alloc value withing the
35// specified placement block.
36Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
37 Block *placementBlock,
38 const Liveness &liveness) {
39 // We have to ensure that we place the alloc before its first use in this
40 // block.
41 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock);
42 Operation *startOperation = livenessInfo.getStartOperation(value: allocValue);
43 // Check whether the start operation lies in the desired placement block.
44 // If not, we will use the terminator as this is the last operation in
45 // this block.
46 if (startOperation->getBlock() != placementBlock) {
47 Operation *opInPlacementBlock =
48 placementBlock->findAncestorOpInBlock(op&: *startOperation);
49 startOperation = opInPlacementBlock ? opInPlacementBlock
50 : placementBlock->getTerminator();
51 }
52
53 return startOperation;
54}
55
56/// Initializes the internal list by discovering all supported allocation
57/// nodes.
58BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
59
60/// Searches for and registers all supported allocation entries.
61void BufferPlacementAllocs::build(Operation *op) {
62 op->walk(callback: [&](MemoryEffectOpInterface opInterface) {
63 // Try to find a single allocation result.
64 SmallVector<MemoryEffects::EffectInstance, 2> effects;
65 opInterface.getEffects(effects);
66
67 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
68 llvm::copy_if(
69 Range&: effects, Out: std::back_inserter(x&: allocateResultEffects),
70 P: [=](MemoryEffects::EffectInstance &it) {
71 Value value = it.getValue();
72 return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value &&
73 isa<OpResult>(Val: value) &&
74 it.getResource() !=
75 SideEffects::AutomaticAllocationScopeResource::get();
76 });
77 // If there is one result only, we will be able to move the allocation and
78 // (possibly existing) deallocation ops.
79 if (allocateResultEffects.size() != 1)
80 return;
81 // Get allocation result.
82 Value allocValue = allocateResultEffects[0].getValue();
83 // Find the associated dealloc value and register the allocation entry.
84 std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
85 // If the allocation has > 1 dealloc associated with it, skip handling it.
86 if (!dealloc)
87 return;
88 allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc));
89 });
90}
91
92//===----------------------------------------------------------------------===//
93// BufferPlacementTransformationBase
94//===----------------------------------------------------------------------===//
95
96/// Constructs a new transformation base using the given root operation.
97BufferPlacementTransformationBase::BufferPlacementTransformationBase(
98 Operation *op)
99 : aliases(op), allocs(op), liveness(op) {}
100
101//===----------------------------------------------------------------------===//
// Global memref and symbol table utilities
103//===----------------------------------------------------------------------===//
104
105FailureOr<memref::GlobalOp>
106bufferization::getGlobalFor(arith::ConstantOp constantOp,
107 SymbolTableCollection &symbolTables,
108 uint64_t alignment, Attribute memorySpace) {
109 auto type = cast<RankedTensorType>(constantOp.getType());
110 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
111 if (!moduleOp)
112 return failure();
113
114 // If we already have a global for this constant value, no need to do
115 // anything else.
116 for (Operation &op : moduleOp.getRegion().getOps()) {
117 auto globalOp = dyn_cast<memref::GlobalOp>(&op);
118 if (!globalOp)
119 continue;
120 if (!globalOp.getInitialValue().has_value())
121 continue;
122 uint64_t opAlignment = globalOp.getAlignment().value_or(0);
123 Attribute initialValue = globalOp.getInitialValue().value();
124 if (opAlignment == alignment && initialValue == constantOp.getValue())
125 return globalOp;
126 }
127
128 // Create a builder without an insertion point. We will insert using the
129 // symbol table to guarantee unique names.
130 OpBuilder globalBuilder(moduleOp.getContext());
131 SymbolTable &symbolTable = symbolTables.getSymbolTable(op: moduleOp);
132
133 // Create a pretty name.
134 SmallString<64> buf;
135 llvm::raw_svector_ostream os(buf);
136 interleave(type.getShape(), os, "x");
137 os << "x" << type.getElementType();
138
139 // Add an optional alignment to the global memref.
140 IntegerAttr memrefAlignment =
141 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
142 : IntegerAttr();
143
144 // Memref globals always have an identity layout.
145 auto memrefType =
146 cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type));
147 if (memorySpace)
148 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
149 auto global = globalBuilder.create<memref::GlobalOp>(
150 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
151 /*sym_visibility=*/globalBuilder.getStringAttr("private"),
152 /*type=*/memrefType,
153 /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
154 /*constant=*/true,
155 /*alignment=*/memrefAlignment);
156 symbolTable.insert(symbol: global);
157 // The symbol table inserts at the end of the module, but globals are a bit
158 // nicer if they are at the beginning.
159 global->moveBefore(&moduleOp.front());
160 return global;
161}
162
163namespace mlir::bufferization {
164void removeSymbol(Operation *op, BufferizationState &state) {
165 SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
166 op: op->getParentWithTrait<OpTrait::SymbolTable>());
167
168 symbolTable.remove(op);
169}
170
171void insertSymbol(Operation *op, BufferizationState &state) {
172 SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
173 op: op->getParentWithTrait<OpTrait::SymbolTable>());
174
175 symbolTable.insert(op);
176}
177} // namespace mlir::bufferization
178

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp