1 | //===- BufferUtils.cpp - buffer transformation utilities ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for buffer optimization passes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
19 | #include "mlir/Interfaces/LoopLikeInterface.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "llvm/ADT/SetOperations.h" |
22 | #include "llvm/ADT/SmallString.h" |
23 | #include <optional> |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::bufferization; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // BufferPlacementAllocs |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | /// Get the start operation to place the given alloc value withing the |
33 | // specified placement block. |
34 | Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, |
35 | Block *placementBlock, |
36 | const Liveness &liveness) { |
37 | // We have to ensure that we place the alloc before its first use in this |
38 | // block. |
39 | const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock); |
40 | Operation *startOperation = livenessInfo.getStartOperation(value: allocValue); |
41 | // Check whether the start operation lies in the desired placement block. |
42 | // If not, we will use the terminator as this is the last operation in |
43 | // this block. |
44 | if (startOperation->getBlock() != placementBlock) { |
45 | Operation *opInPlacementBlock = |
46 | placementBlock->findAncestorOpInBlock(op&: *startOperation); |
47 | startOperation = opInPlacementBlock ? opInPlacementBlock |
48 | : placementBlock->getTerminator(); |
49 | } |
50 | |
51 | return startOperation; |
52 | } |
53 | |
54 | /// Initializes the internal list by discovering all supported allocation |
55 | /// nodes. |
56 | BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } |
57 | |
58 | /// Searches for and registers all supported allocation entries. |
59 | void BufferPlacementAllocs::build(Operation *op) { |
60 | op->walk(callback: [&](MemoryEffectOpInterface opInterface) { |
61 | // Try to find a single allocation result. |
62 | SmallVector<MemoryEffects::EffectInstance, 2> effects; |
63 | opInterface.getEffects(effects); |
64 | |
65 | SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; |
66 | llvm::copy_if( |
67 | Range&: effects, Out: std::back_inserter(x&: allocateResultEffects), |
68 | P: [=](MemoryEffects::EffectInstance &it) { |
69 | Value value = it.getValue(); |
70 | return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value && |
71 | isa<OpResult>(Val: value) && |
72 | it.getResource() != |
73 | SideEffects::AutomaticAllocationScopeResource::get(); |
74 | }); |
75 | // If there is one result only, we will be able to move the allocation and |
76 | // (possibly existing) deallocation ops. |
77 | if (allocateResultEffects.size() != 1) |
78 | return; |
79 | // Get allocation result. |
80 | Value allocValue = allocateResultEffects[0].getValue(); |
81 | // Find the associated dealloc value and register the allocation entry. |
82 | std::optional<Operation *> dealloc = memref::findDealloc(allocValue); |
83 | // If the allocation has > 1 dealloc associated with it, skip handling it. |
84 | if (!dealloc) |
85 | return; |
86 | allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc)); |
87 | }); |
88 | } |
89 | |
90 | //===----------------------------------------------------------------------===// |
91 | // BufferPlacementTransformationBase |
92 | //===----------------------------------------------------------------------===// |
93 | |
94 | /// Constructs a new transformation base using the given root operation. |
95 | BufferPlacementTransformationBase::BufferPlacementTransformationBase( |
96 | Operation *op) |
97 | : aliases(op), allocs(op), liveness(op) {} |
98 | |
99 | //===----------------------------------------------------------------------===// |
// Constant global utilities
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | FailureOr<memref::GlobalOp> |
104 | bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, |
105 | Attribute memorySpace) { |
106 | auto type = cast<RankedTensorType>(constantOp.getType()); |
107 | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); |
108 | if (!moduleOp) |
109 | return failure(); |
110 | |
111 | // If we already have a global for this constant value, no need to do |
112 | // anything else. |
113 | for (Operation &op : moduleOp.getRegion().getOps()) { |
114 | auto globalOp = dyn_cast<memref::GlobalOp>(&op); |
115 | if (!globalOp) |
116 | continue; |
117 | if (!globalOp.getInitialValue().has_value()) |
118 | continue; |
119 | uint64_t opAlignment = globalOp.getAlignment().value_or(0); |
120 | Attribute initialValue = globalOp.getInitialValue().value(); |
121 | if (opAlignment == alignment && initialValue == constantOp.getValue()) |
122 | return globalOp; |
123 | } |
124 | |
125 | // Create a builder without an insertion point. We will insert using the |
126 | // symbol table to guarantee unique names. |
127 | OpBuilder globalBuilder(moduleOp.getContext()); |
128 | SymbolTable symbolTable(moduleOp); |
129 | |
130 | // Create a pretty name. |
131 | SmallString<64> buf; |
132 | llvm::raw_svector_ostream os(buf); |
133 | interleave(type.getShape(), os, "x" ); |
134 | os << "x" << type.getElementType(); |
135 | |
136 | // Add an optional alignment to the global memref. |
137 | IntegerAttr memrefAlignment = |
138 | alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) |
139 | : IntegerAttr(); |
140 | |
141 | BufferizeTypeConverter typeConverter; |
142 | auto memrefType = cast<MemRefType>(typeConverter.convertType(type)); |
143 | if (memorySpace) |
144 | memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); |
145 | auto global = globalBuilder.create<memref::GlobalOp>( |
146 | constantOp.getLoc(), (Twine("__constant_" ) + os.str()).str(), |
147 | /*sym_visibility=*/globalBuilder.getStringAttr("private" ), |
148 | /*type=*/memrefType, |
149 | /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), |
150 | /*constant=*/true, |
151 | /*alignment=*/memrefAlignment); |
152 | symbolTable.insert(symbol: global); |
153 | // The symbol table inserts at the end of the module, but globals are a bit |
154 | // nicer if they are at the beginning. |
155 | global->moveBefore(&moduleOp.front()); |
156 | return global; |
157 | } |
158 | |