| 1 | //===- BufferUtils.cpp - buffer transformation utilities ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utilities for buffer optimization passes. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Operation.h"
#include <cassert>
#include <optional>
| 21 | |
| 22 | using namespace mlir; |
| 23 | using namespace mlir::bufferization; |
| 24 | |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | // BufferPlacementAllocs |
| 27 | //===----------------------------------------------------------------------===// |
| 28 | |
| 29 | /// Get the start operation to place the given alloc value withing the |
| 30 | // specified placement block. |
| 31 | Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, |
| 32 | Block *placementBlock, |
| 33 | const Liveness &liveness) { |
| 34 | // We have to ensure that we place the alloc before its first use in this |
| 35 | // block. |
| 36 | const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(block: placementBlock); |
| 37 | Operation *startOperation = livenessInfo.getStartOperation(value: allocValue); |
| 38 | // Check whether the start operation lies in the desired placement block. |
| 39 | // If not, we will use the terminator as this is the last operation in |
| 40 | // this block. |
| 41 | if (startOperation->getBlock() != placementBlock) { |
| 42 | Operation *opInPlacementBlock = |
| 43 | placementBlock->findAncestorOpInBlock(op&: *startOperation); |
| 44 | startOperation = opInPlacementBlock ? opInPlacementBlock |
| 45 | : placementBlock->getTerminator(); |
| 46 | } |
| 47 | |
| 48 | return startOperation; |
| 49 | } |
| 50 | |
| 51 | /// Initializes the internal list by discovering all supported allocation |
| 52 | /// nodes. |
| 53 | BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } |
| 54 | |
| 55 | /// Searches for and registers all supported allocation entries. |
| 56 | void BufferPlacementAllocs::build(Operation *op) { |
| 57 | op->walk(callback: [&](MemoryEffectOpInterface opInterface) { |
| 58 | // Try to find a single allocation result. |
| 59 | SmallVector<MemoryEffects::EffectInstance, 2> effects; |
| 60 | opInterface.getEffects(effects); |
| 61 | |
| 62 | SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; |
| 63 | llvm::copy_if( |
| 64 | Range&: effects, Out: std::back_inserter(x&: allocateResultEffects), |
| 65 | P: [=](MemoryEffects::EffectInstance &it) { |
| 66 | Value value = it.getValue(); |
| 67 | return isa<MemoryEffects::Allocate>(Val: it.getEffect()) && value && |
| 68 | isa<OpResult>(Val: value) && |
| 69 | it.getResource() != |
| 70 | SideEffects::AutomaticAllocationScopeResource::get(); |
| 71 | }); |
| 72 | // If there is one result only, we will be able to move the allocation and |
| 73 | // (possibly existing) deallocation ops. |
| 74 | if (allocateResultEffects.size() != 1) |
| 75 | return; |
| 76 | // Get allocation result. |
| 77 | Value allocValue = allocateResultEffects[0].getValue(); |
| 78 | // Find the associated dealloc value and register the allocation entry. |
| 79 | std::optional<Operation *> dealloc = memref::findDealloc(allocValue); |
| 80 | // If the allocation has > 1 dealloc associated with it, skip handling it. |
| 81 | if (!dealloc) |
| 82 | return; |
| 83 | allocs.push_back(Elt: std::make_tuple(args&: allocValue, args&: *dealloc)); |
| 84 | }); |
| 85 | } |
| 86 | |
| 87 | //===----------------------------------------------------------------------===// |
| 88 | // BufferPlacementTransformationBase |
| 89 | //===----------------------------------------------------------------------===// |
| 90 | |
| 91 | /// Constructs a new transformation base using the given root operation. |
| 92 | BufferPlacementTransformationBase::BufferPlacementTransformationBase( |
| 93 | Operation *op) |
| 94 | : aliases(op), allocs(op), liveness(op) {} |
| 95 | |
//===----------------------------------------------------------------------===//
// Global and symbol table utilities
//===----------------------------------------------------------------------===//
| 99 | |
| 100 | FailureOr<memref::GlobalOp> |
| 101 | bufferization::getGlobalFor(arith::ConstantOp constantOp, |
| 102 | SymbolTableCollection &symbolTables, |
| 103 | uint64_t alignment, Attribute memorySpace) { |
| 104 | auto type = cast<RankedTensorType>(Val: constantOp.getType()); |
| 105 | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); |
| 106 | if (!moduleOp) |
| 107 | return failure(); |
| 108 | |
| 109 | // If we already have a global for this constant value, no need to do |
| 110 | // anything else. |
| 111 | for (Operation &op : moduleOp.getRegion().getOps()) { |
| 112 | auto globalOp = dyn_cast<memref::GlobalOp>(Val: &op); |
| 113 | if (!globalOp) |
| 114 | continue; |
| 115 | if (!globalOp.getInitialValue().has_value()) |
| 116 | continue; |
| 117 | uint64_t opAlignment = globalOp.getAlignment().value_or(u: 0); |
| 118 | Attribute initialValue = globalOp.getInitialValue().value(); |
| 119 | if (opAlignment == alignment && initialValue == constantOp.getValue()) |
| 120 | return globalOp; |
| 121 | } |
| 122 | |
| 123 | // Create a builder without an insertion point. We will insert using the |
| 124 | // symbol table to guarantee unique names. |
| 125 | OpBuilder globalBuilder(moduleOp.getContext()); |
| 126 | SymbolTable &symbolTable = symbolTables.getSymbolTable(op: moduleOp); |
| 127 | |
| 128 | // Create a pretty name. |
| 129 | SmallString<64> buf; |
| 130 | llvm::raw_svector_ostream os(buf); |
| 131 | interleave(c: type.getShape(), os, separator: "x" ); |
| 132 | os << "x" << type.getElementType(); |
| 133 | |
| 134 | // Add an optional alignment to the global memref. |
| 135 | IntegerAttr memrefAlignment = |
| 136 | alignment > 0 ? IntegerAttr::get(type: globalBuilder.getI64Type(), value: alignment) |
| 137 | : IntegerAttr(); |
| 138 | |
| 139 | // Memref globals always have an identity layout. |
| 140 | auto memrefType = |
| 141 | cast<MemRefType>(Val: getMemRefTypeWithStaticIdentityLayout(tensorType: type)); |
| 142 | if (memorySpace) |
| 143 | memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); |
| 144 | auto global = globalBuilder.create<memref::GlobalOp>( |
| 145 | location: constantOp.getLoc(), args: (Twine("__constant_" ) + os.str()).str(), |
| 146 | /*sym_visibility=*/args: globalBuilder.getStringAttr(bytes: "private" ), |
| 147 | /*type=*/args&: memrefType, |
| 148 | /*initial_value=*/args: cast<ElementsAttr>(Val: constantOp.getValue()), |
| 149 | /*constant=*/args: true, |
| 150 | /*alignment=*/args&: memrefAlignment); |
| 151 | symbolTable.insert(symbol: global); |
| 152 | // The symbol table inserts at the end of the module, but globals are a bit |
| 153 | // nicer if they are at the beginning. |
| 154 | global->moveBefore(existingOp: &moduleOp.front()); |
| 155 | return global; |
| 156 | } |
| 157 | |
| 158 | namespace mlir::bufferization { |
| 159 | void removeSymbol(Operation *op, BufferizationState &state) { |
| 160 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
| 161 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
| 162 | |
| 163 | symbolTable.remove(op); |
| 164 | } |
| 165 | |
| 166 | void insertSymbol(Operation *op, BufferizationState &state) { |
| 167 | SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( |
| 168 | op: op->getParentWithTrait<OpTrait::SymbolTable>()); |
| 169 | |
| 170 | symbolTable.insert(symbol: op); |
| 171 | } |
| 172 | } // namespace mlir::bufferization |
| 173 | |