1//===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for buffer optimization passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Interfaces/ControlFlowInterfaces.h"
19#include "mlir/Interfaces/LoopLikeInterface.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/SmallString.h"
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::bufferization;
27
28//===----------------------------------------------------------------------===//
29// BufferPlacementAllocs
30//===----------------------------------------------------------------------===//
31
32/// Get the start operation to place the given alloc value withing the
33// specified placement block.
34Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
35 Block *placementBlock,
36 const Liveness &liveness) {
37 // We have to ensure that we place the alloc before its first use in this
38 // block.
39 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock);
40 Operation *startOperation = livenessInfo.getStartOperation(value: allocValue);
41 // Check whether the start operation lies in the desired placement block.
42 // If not, we will use the terminator as this is the last operation in
43 // this block.
44 if (startOperation->getBlock() != placementBlock) {
45 Operation *opInPlacementBlock =
46 placementBlock->findAncestorOpInBlock(op&: *startOperation);
47 startOperation = opInPlacementBlock ? opInPlacementBlock
48 : placementBlock->getTerminator();
49 }
50
51 return startOperation;
52}
53
54/// Initializes the internal list by discovering all supported allocation
55/// nodes.
56BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
57
58/// Searches for and registers all supported allocation entries.
59void BufferPlacementAllocs::build(Operation *op) {
60 op->walk(callback: [&](MemoryEffectOpInterface opInterface) {
61 // Try to find a single allocation result.
62 SmallVector<MemoryEffects::EffectInstance, 2> effects;
63 opInterface.getEffects(effects);
64
65 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
66 llvm::copy_if(
67 Range&: effects, Out: std::back_inserter(x&: allocateResultEffects),
68 P: [=](MemoryEffects::EffectInstance &it) {
69 Value value = it.getValue();
70 return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value &&
71 isa<OpResult>(Val: value) &&
72 it.getResource() !=
73 SideEffects::AutomaticAllocationScopeResource::get();
74 });
75 // If there is one result only, we will be able to move the allocation and
76 // (possibly existing) deallocation ops.
77 if (allocateResultEffects.size() != 1)
78 return;
79 // Get allocation result.
80 Value allocValue = allocateResultEffects[0].getValue();
81 // Find the associated dealloc value and register the allocation entry.
82 std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
83 // If the allocation has > 1 dealloc associated with it, skip handling it.
84 if (!dealloc)
85 return;
86 allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc));
87 });
88}
89
90//===----------------------------------------------------------------------===//
91// BufferPlacementTransformationBase
92//===----------------------------------------------------------------------===//
93
94/// Constructs a new transformation base using the given root operation.
95BufferPlacementTransformationBase::BufferPlacementTransformationBase(
96 Operation *op)
97 : aliases(op), allocs(op), liveness(op) {}
98
99//===----------------------------------------------------------------------===//
// getGlobalFor
101//===----------------------------------------------------------------------===//
102
103FailureOr<memref::GlobalOp>
104bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
105 Attribute memorySpace) {
106 auto type = cast<RankedTensorType>(constantOp.getType());
107 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
108 if (!moduleOp)
109 return failure();
110
111 // If we already have a global for this constant value, no need to do
112 // anything else.
113 for (Operation &op : moduleOp.getRegion().getOps()) {
114 auto globalOp = dyn_cast<memref::GlobalOp>(&op);
115 if (!globalOp)
116 continue;
117 if (!globalOp.getInitialValue().has_value())
118 continue;
119 uint64_t opAlignment = globalOp.getAlignment().value_or(0);
120 Attribute initialValue = globalOp.getInitialValue().value();
121 if (opAlignment == alignment && initialValue == constantOp.getValue())
122 return globalOp;
123 }
124
125 // Create a builder without an insertion point. We will insert using the
126 // symbol table to guarantee unique names.
127 OpBuilder globalBuilder(moduleOp.getContext());
128 SymbolTable symbolTable(moduleOp);
129
130 // Create a pretty name.
131 SmallString<64> buf;
132 llvm::raw_svector_ostream os(buf);
133 interleave(type.getShape(), os, "x");
134 os << "x" << type.getElementType();
135
136 // Add an optional alignment to the global memref.
137 IntegerAttr memrefAlignment =
138 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
139 : IntegerAttr();
140
141 BufferizeTypeConverter typeConverter;
142 auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
143 if (memorySpace)
144 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
145 auto global = globalBuilder.create<memref::GlobalOp>(
146 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
147 /*sym_visibility=*/globalBuilder.getStringAttr("private"),
148 /*type=*/memrefType,
149 /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
150 /*constant=*/true,
151 /*alignment=*/memrefAlignment);
152 symbolTable.insert(symbol: global);
153 // The symbol table inserts at the end of the module, but globals are a bit
154 // nicer if they are at the beginning.
155 global->moveBefore(&moduleOp.front());
156 return global;
157}
158

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp