| 1 | //===- BufferUtils.cpp - buffer transformation utilities ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utilities for buffer optimization passes. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 16 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
| 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 18 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
| 19 | #include "mlir/IR/Operation.h" |
| 20 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 21 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 22 | #include "mlir/Pass/Pass.h" |
| 23 | #include "llvm/ADT/SetOperations.h" |
| 24 | #include "llvm/ADT/SmallString.h" |
| 25 | #include <optional> |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::bufferization; |
| 29 | |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | // BufferPlacementAllocs |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | |
| 34 | /// Get the start operation to place the given alloc value withing the |
| 35 | // specified placement block. |
| 36 | Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, |
| 37 | Block *placementBlock, |
| 38 | const Liveness &liveness) { |
| 39 | // We have to ensure that we place the alloc before its first use in this |
| 40 | // block. |
| 41 | const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock); |
| 42 | Operation *startOperation = livenessInfo.getStartOperation(value: allocValue); |
| 43 | // Check whether the start operation lies in the desired placement block. |
| 44 | // If not, we will use the terminator as this is the last operation in |
| 45 | // this block. |
| 46 | if (startOperation->getBlock() != placementBlock) { |
| 47 | Operation *opInPlacementBlock = |
| 48 | placementBlock->findAncestorOpInBlock(op&: *startOperation); |
| 49 | startOperation = opInPlacementBlock ? opInPlacementBlock |
| 50 | : placementBlock->getTerminator(); |
| 51 | } |
| 52 | |
| 53 | return startOperation; |
| 54 | } |
| 55 | |
| 56 | /// Initializes the internal list by discovering all supported allocation |
| 57 | /// nodes. |
| 58 | BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } |
| 59 | |
| 60 | /// Searches for and registers all supported allocation entries. |
| 61 | void BufferPlacementAllocs::build(Operation *op) { |
| 62 | op->walk(callback: [&](MemoryEffectOpInterface opInterface) { |
| 63 | // Try to find a single allocation result. |
| 64 | SmallVector<MemoryEffects::EffectInstance, 2> effects; |
| 65 | opInterface.getEffects(effects); |
| 66 | |
| 67 | SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; |
| 68 | llvm::copy_if( |
| 69 | Range&: effects, Out: std::back_inserter(x&: allocateResultEffects), |
| 70 | P: [=](MemoryEffects::EffectInstance &it) { |
| 71 | Value value = it.getValue(); |
| 72 | return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value && |
| 73 | isa<OpResult>(Val: value) && |
| 74 | it.getResource() != |
| 75 | SideEffects::AutomaticAllocationScopeResource::get(); |
| 76 | }); |
| 77 | // If there is one result only, we will be able to move the allocation and |
| 78 | // (possibly existing) deallocation ops. |
| 79 | if (allocateResultEffects.size() != 1) |
| 80 | return; |
| 81 | // Get allocation result. |
| 82 | Value allocValue = allocateResultEffects[0].getValue(); |
| 83 | // Find the associated dealloc value and register the allocation entry. |
| 84 | std::optional<Operation *> dealloc = memref::findDealloc(allocValue); |
| 85 | // If the allocation has > 1 dealloc associated with it, skip handling it. |
| 86 | if (!dealloc) |
| 87 | return; |
| 88 | allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc)); |
| 89 | }); |
| 90 | } |
| 91 | |
| 92 | //===----------------------------------------------------------------------===// |
| 93 | // BufferPlacementTransformationBase |
| 94 | //===----------------------------------------------------------------------===// |
| 95 | |
| 96 | /// Constructs a new transformation base using the given root operation. |
| 97 | BufferPlacementTransformationBase::BufferPlacementTransformationBase( |
| 98 | Operation *op) |
| 99 | : aliases(op), allocs(op), liveness(op) {} |
| 100 | |
//===----------------------------------------------------------------------===//
// Global constant memref helpers
//===----------------------------------------------------------------------===//
| 104 | |
| 105 | FailureOr<memref::GlobalOp> |
| 106 | bufferization::getGlobalFor(arith::ConstantOp constantOp, |
| 107 | SymbolTableCollection &symbolTables, |
| 108 | uint64_t alignment, Attribute memorySpace) { |
| 109 | auto type = cast<RankedTensorType>(constantOp.getType()); |
| 110 | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); |
| 111 | if (!moduleOp) |
| 112 | return failure(); |
| 113 | |
| 114 | // If we already have a global for this constant value, no need to do |
| 115 | // anything else. |
| 116 | for (Operation &op : moduleOp.getRegion().getOps()) { |
| 117 | auto globalOp = dyn_cast<memref::GlobalOp>(&op); |
| 118 | if (!globalOp) |
| 119 | continue; |
| 120 | if (!globalOp.getInitialValue().has_value()) |
| 121 | continue; |
| 122 | uint64_t opAlignment = globalOp.getAlignment().value_or(0); |
| 123 | Attribute initialValue = globalOp.getInitialValue().value(); |
| 124 | if (opAlignment == alignment && initialValue == constantOp.getValue()) |
| 125 | return globalOp; |
| 126 | } |
| 127 | |
| 128 | // Create a builder without an insertion point. We will insert using the |
| 129 | // symbol table to guarantee unique names. |
| 130 | OpBuilder globalBuilder(moduleOp.getContext()); |
| 131 | SymbolTable &symbolTable = symbolTables.getSymbolTable(op: moduleOp); |
| 132 | |
| 133 | // Create a pretty name. |
| 134 | SmallString<64> buf; |
| 135 | llvm::raw_svector_ostream os(buf); |
| 136 | interleave(type.getShape(), os, "x" ); |
| 137 | os << "x" << type.getElementType(); |
| 138 | |
| 139 | // Add an optional alignment to the global memref. |
| 140 | IntegerAttr memrefAlignment = |
| 141 | alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) |
| 142 | : IntegerAttr(); |
| 143 | |
| 144 | // Memref globals always have an identity layout. |
| 145 | auto memrefType = |
| 146 | cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); |
| 147 | if (memorySpace) |
| 148 | memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); |
| 149 | auto global = globalBuilder.create<memref::GlobalOp>( |
| 150 | constantOp.getLoc(), (Twine("__constant_" ) + os.str()).str(), |
| 151 | /*sym_visibility=*/globalBuilder.getStringAttr("private" ), |
| 152 | /*type=*/memrefType, |
| 153 | /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), |
| 154 | /*constant=*/true, |
| 155 | /*alignment=*/memrefAlignment); |
| 156 | symbolTable.insert(symbol: global); |
| 157 | // The symbol table inserts at the end of the module, but globals are a bit |
| 158 | // nicer if they are at the beginning. |
| 159 | global->moveBefore(&moduleOp.front()); |
| 160 | return global; |
| 161 | } |
| 162 | |
| 163 | namespace mlir::bufferization { |
| 164 | void removeSymbol(Operation *op, BufferizationState &state) { |
| 165 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
| 166 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
| 167 | |
| 168 | symbolTable.remove(op); |
| 169 | } |
| 170 | |
| 171 | void insertSymbol(Operation *op, BufferizationState &state) { |
| 172 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
| 173 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
| 174 | |
| 175 | symbolTable.insert(op); |
| 176 | } |
| 177 | } // namespace mlir::bufferization |
| 178 | |