1 | //===- BufferUtils.cpp - buffer transformation utilities ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for buffer optimization passes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" |
14 | |
15 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
16 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
19 | #include "mlir/IR/Operation.h" |
20 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
21 | #include "mlir/Interfaces/LoopLikeInterface.h" |
22 | #include "mlir/Pass/Pass.h" |
23 | #include "llvm/ADT/SetOperations.h" |
24 | #include "llvm/ADT/SmallString.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::bufferization; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // BufferPlacementAllocs |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Get the start operation to place the given alloc value withing the |
35 | // specified placement block. |
36 | Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, |
37 | Block *placementBlock, |
38 | const Liveness &liveness) { |
39 | // We have to ensure that we place the alloc before its first use in this |
40 | // block. |
41 | const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock); |
42 | Operation *startOperation = livenessInfo.getStartOperation(value: allocValue); |
43 | // Check whether the start operation lies in the desired placement block. |
44 | // If not, we will use the terminator as this is the last operation in |
45 | // this block. |
46 | if (startOperation->getBlock() != placementBlock) { |
47 | Operation *opInPlacementBlock = |
48 | placementBlock->findAncestorOpInBlock(op&: *startOperation); |
49 | startOperation = opInPlacementBlock ? opInPlacementBlock |
50 | : placementBlock->getTerminator(); |
51 | } |
52 | |
53 | return startOperation; |
54 | } |
55 | |
56 | /// Initializes the internal list by discovering all supported allocation |
57 | /// nodes. |
58 | BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } |
59 | |
60 | /// Searches for and registers all supported allocation entries. |
61 | void BufferPlacementAllocs::build(Operation *op) { |
62 | op->walk(callback: [&](MemoryEffectOpInterface opInterface) { |
63 | // Try to find a single allocation result. |
64 | SmallVector<MemoryEffects::EffectInstance, 2> effects; |
65 | opInterface.getEffects(effects); |
66 | |
67 | SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; |
68 | llvm::copy_if( |
69 | Range&: effects, Out: std::back_inserter(x&: allocateResultEffects), |
70 | P: [=](MemoryEffects::EffectInstance &it) { |
71 | Value value = it.getValue(); |
72 | return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value && |
73 | isa<OpResult>(Val: value) && |
74 | it.getResource() != |
75 | SideEffects::AutomaticAllocationScopeResource::get(); |
76 | }); |
77 | // If there is one result only, we will be able to move the allocation and |
78 | // (possibly existing) deallocation ops. |
79 | if (allocateResultEffects.size() != 1) |
80 | return; |
81 | // Get allocation result. |
82 | Value allocValue = allocateResultEffects[0].getValue(); |
83 | // Find the associated dealloc value and register the allocation entry. |
84 | std::optional<Operation *> dealloc = memref::findDealloc(allocValue); |
85 | // If the allocation has > 1 dealloc associated with it, skip handling it. |
86 | if (!dealloc) |
87 | return; |
88 | allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc)); |
89 | }); |
90 | } |
91 | |
92 | //===----------------------------------------------------------------------===// |
93 | // BufferPlacementTransformationBase |
94 | //===----------------------------------------------------------------------===// |
95 | |
96 | /// Constructs a new transformation base using the given root operation. |
97 | BufferPlacementTransformationBase::BufferPlacementTransformationBase( |
98 | Operation *op) |
99 | : aliases(op), allocs(op), liveness(op) {} |
100 | |
//===----------------------------------------------------------------------===//
// Global constant and symbol table utilities
//===----------------------------------------------------------------------===//
104 | |
105 | FailureOr<memref::GlobalOp> |
106 | bufferization::getGlobalFor(arith::ConstantOp constantOp, |
107 | SymbolTableCollection &symbolTables, |
108 | uint64_t alignment, Attribute memorySpace) { |
109 | auto type = cast<RankedTensorType>(constantOp.getType()); |
110 | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); |
111 | if (!moduleOp) |
112 | return failure(); |
113 | |
114 | // If we already have a global for this constant value, no need to do |
115 | // anything else. |
116 | for (Operation &op : moduleOp.getRegion().getOps()) { |
117 | auto globalOp = dyn_cast<memref::GlobalOp>(&op); |
118 | if (!globalOp) |
119 | continue; |
120 | if (!globalOp.getInitialValue().has_value()) |
121 | continue; |
122 | uint64_t opAlignment = globalOp.getAlignment().value_or(0); |
123 | Attribute initialValue = globalOp.getInitialValue().value(); |
124 | if (opAlignment == alignment && initialValue == constantOp.getValue()) |
125 | return globalOp; |
126 | } |
127 | |
128 | // Create a builder without an insertion point. We will insert using the |
129 | // symbol table to guarantee unique names. |
130 | OpBuilder globalBuilder(moduleOp.getContext()); |
131 | SymbolTable &symbolTable = symbolTables.getSymbolTable(op: moduleOp); |
132 | |
133 | // Create a pretty name. |
134 | SmallString<64> buf; |
135 | llvm::raw_svector_ostream os(buf); |
136 | interleave(type.getShape(), os, "x" ); |
137 | os << "x" << type.getElementType(); |
138 | |
139 | // Add an optional alignment to the global memref. |
140 | IntegerAttr memrefAlignment = |
141 | alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) |
142 | : IntegerAttr(); |
143 | |
144 | // Memref globals always have an identity layout. |
145 | auto memrefType = |
146 | cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); |
147 | if (memorySpace) |
148 | memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); |
149 | auto global = globalBuilder.create<memref::GlobalOp>( |
150 | constantOp.getLoc(), (Twine("__constant_" ) + os.str()).str(), |
151 | /*sym_visibility=*/globalBuilder.getStringAttr("private" ), |
152 | /*type=*/memrefType, |
153 | /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), |
154 | /*constant=*/true, |
155 | /*alignment=*/memrefAlignment); |
156 | symbolTable.insert(symbol: global); |
157 | // The symbol table inserts at the end of the module, but globals are a bit |
158 | // nicer if they are at the beginning. |
159 | global->moveBefore(&moduleOp.front()); |
160 | return global; |
161 | } |
162 | |
163 | namespace mlir::bufferization { |
164 | void removeSymbol(Operation *op, BufferizationState &state) { |
165 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
166 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
167 | |
168 | symbolTable.remove(op); |
169 | } |
170 | |
171 | void insertSymbol(Operation *op, BufferizationState &state) { |
172 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
173 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
174 | |
175 | symbolTable.insert(op); |
176 | } |
177 | } // namespace mlir::bufferization |
178 | |