| 1 | //===-- LLVMInsertChainFolder.cpp -----------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" |
| 10 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| 11 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "llvm/Support/Debug.h" |
| 14 | |
| 15 | #define DEBUG_TYPE "flang-insert-folder" |
| 16 | |
| 17 | #include <deque> |
| 18 | |
| 19 | namespace { |
| 20 | // Helper class to construct the attribute elements of an aggregate value being |
| 21 | // folded without creating a full mlir::Attribute representation for each step |
| 22 | // of the insert value chain, which would both be expensive in terms of |
| 23 | // compilation time and memory (since the intermediate Attribute would survive, |
| 24 | // unused, inside the mlir context). |
| 25 | class InsertChainBackwardFolder { |
| 26 | // Type for the current value of an element of the aggregate value being |
| 27 | // constructed by the insert chain. |
| 28 | // At any point of the insert chain, the value of an element is either: |
| 29 | // - nullptr: not yet known, the insert has not yet been seen. |
| 30 | // - an mlir::Attribute: the element is fully defined. |
| 31 | // - a nested InsertChainBackwardFolder: the element is itself an aggregate |
| 32 | // and its sub-elements have been partially defined (insert with mutliple |
| 33 | // indices have been seen). |
| 34 | |
| 35 | // The insertion folder assumes backward walk of the insert chain. Once an |
| 36 | // element or sub-element has been defined, it is not overriden by new |
| 37 | // insertions (last insert wins). |
| 38 | using InFlightValue = |
| 39 | llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>; |
| 40 | |
| 41 | public: |
| 42 | InsertChainBackwardFolder( |
| 43 | mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage) |
| 44 | : values(getNumElements(type), mlir::Attribute{}), |
| 45 | folderStorage{folderStorage}, type{type} {} |
| 46 | |
| 47 | /// Push |
| 48 | bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at); |
| 49 | |
| 50 | mlir::Attribute finalize(mlir::Attribute defaultFieldValue); |
| 51 | |
| 52 | private: |
| 53 | static int64_t getNumElements(mlir::Type type) { |
| 54 | if (auto structTy = |
| 55 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
| 56 | return structTy.getBody().size(); |
| 57 | if (auto arrayTy = |
| 58 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
| 59 | return arrayTy.getNumElements(); |
| 60 | return 0; |
| 61 | } |
| 62 | |
| 63 | static mlir::Type getSubElementType(mlir::Type type, int64_t field) { |
| 64 | if (auto arrayTy = |
| 65 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
| 66 | return arrayTy.getElementType(); |
| 67 | if (auto structTy = |
| 68 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
| 69 | return structTy.getBody()[field]; |
| 70 | return nullptr; |
| 71 | } |
| 72 | |
| 73 | // Current element value of the aggregate value being built. |
| 74 | llvm::SmallVector<InFlightValue> values; |
| 75 | // std::deque is used to allocate storage for nested list and guarantee the |
| 76 | // stability of the InsertChainBackwardFolder* used as element value. |
| 77 | std::deque<InsertChainBackwardFolder> *folderStorage; |
| 78 | // Type of the aggregate value being built. |
| 79 | mlir::Type type; |
| 80 | }; |
| 81 | } // namespace |
| 82 | |
| 83 | // Helper to fold the value being inserted by an llvm.insert_value. |
| 84 | // This may call tryFoldingLLVMInsertChain if the value is an aggregate and |
| 85 | // was itself constructed by a different insert chain. |
| 86 | // Returns a nullptr Attribute if the value could not be folded. |
| 87 | static mlir::Attribute getAttrIfConstant(mlir::Value val, |
| 88 | mlir::OpBuilder &rewriter) { |
| 89 | if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 90 | return cst.getValue(); |
| 91 | if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
| 92 | llvm::FailureOr<mlir::Attribute> attr = |
| 93 | fir::tryFoldingLLVMInsertChain(val, rewriter); |
| 94 | if (succeeded(attr)) |
| 95 | return *attr; |
| 96 | return nullptr; |
| 97 | } |
| 98 | if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) |
| 99 | return mlir::LLVM::ZeroAttr::get(val.getContext()); |
| 100 | if (val.getDefiningOp<mlir::LLVM::UndefOp>()) |
| 101 | return mlir::LLVM::UndefAttr::get(val.getContext()); |
| 102 | if (mlir::Operation *op = val.getDefiningOp()) { |
| 103 | unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber(); |
| 104 | llvm::SmallVector<mlir::Value> results; |
| 105 | if (mlir::succeeded(rewriter.tryFold(op, results)) && |
| 106 | results.size() > resNum) { |
| 107 | if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 108 | return cst.getValue(); |
| 109 | } |
| 110 | } |
| 111 | if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>()) |
| 112 | if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) |
| 113 | if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) |
| 114 | return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); |
| 115 | LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val |
| 116 | << "\n" ); |
| 117 | return nullptr; |
| 118 | } |
| 119 | |
| 120 | mlir::Attribute |
| 121 | InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { |
| 122 | llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector( |
| 123 | values, [&](InFlightValue inFlight) -> mlir::Attribute { |
| 124 | if (!inFlight) |
| 125 | return defaultFieldValue; |
| 126 | if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) |
| 127 | return attr; |
| 128 | return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize( |
| 129 | defaultFieldValue); |
| 130 | }); |
| 131 | return mlir::ArrayAttr::get(type.getContext(), attrs); |
| 132 | } |
| 133 | |
| 134 | bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, |
| 135 | llvm::ArrayRef<int64_t> at) { |
| 136 | if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size())) |
| 137 | return false; |
| 138 | InFlightValue &inFlight = values[at[0]]; |
| 139 | if (!inFlight) { |
| 140 | if (at.size() == 1) { |
| 141 | inFlight = val; |
| 142 | return true; |
| 143 | } |
| 144 | // This is the first insert to a nested field. Create a |
| 145 | // InsertChainBackwardFolder for the current element value. |
| 146 | mlir::Type subType = getSubElementType(type, at[0]); |
| 147 | if (!subType) |
| 148 | return false; |
| 149 | InsertChainBackwardFolder &inFlightList = |
| 150 | folderStorage->emplace_back(subType, folderStorage); |
| 151 | inFlight = &inFlightList; |
| 152 | return inFlightList.pushValue(val, at.drop_front()); |
| 153 | } |
| 154 | // Keep last inserted value if already set. |
| 155 | if (llvm::isa<mlir::Attribute>(inFlight)) |
| 156 | return true; |
| 157 | auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); |
| 158 | if (at.size() == 1) { |
| 159 | if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) { |
| 160 | LLVM_DEBUG(llvm::dbgs() |
| 161 | << "insert chain sub-element partially overwritten initial " |
| 162 | "value is not zero or undef\n" ); |
| 163 | return false; |
| 164 | } |
| 165 | inFlight = inFlightList->finalize(val); |
| 166 | return true; |
| 167 | } |
| 168 | return inFlightList->pushValue(val, at.drop_front()); |
| 169 | } |
| 170 | |
| 171 | llvm::FailureOr<mlir::Attribute> |
| 172 | fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) { |
| 173 | if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
| 174 | return cst.getValue(); |
| 175 | if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
| 176 | LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n" ); |
| 177 | if (auto structTy = |
| 178 | llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) { |
| 179 | mlir::LLVM::InsertValueOp currentInsert = insert; |
| 180 | mlir::LLVM::InsertValueOp lastInsert; |
| 181 | std::deque<InsertChainBackwardFolder> folderStorage; |
| 182 | InsertChainBackwardFolder inFlightList(structTy, &folderStorage); |
| 183 | while (currentInsert) { |
| 184 | mlir::Attribute attr = |
| 185 | getAttrIfConstant(currentInsert.getValue(), rewriter); |
| 186 | if (!attr) |
| 187 | return llvm::failure(); |
| 188 | if (!inFlightList.pushValue(attr, currentInsert.getPosition())) |
| 189 | return llvm::failure(); |
| 190 | lastInsert = currentInsert; |
| 191 | currentInsert = currentInsert.getContainer() |
| 192 | .getDefiningOp<mlir::LLVM::InsertValueOp>(); |
| 193 | } |
| 194 | mlir::Attribute defaultVal; |
| 195 | if (lastInsert) { |
| 196 | if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>()) |
| 197 | defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); |
| 198 | else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>()) |
| 199 | defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); |
| 200 | } |
| 201 | if (!defaultVal) { |
| 202 | LLVM_DEBUG(llvm::dbgs() |
| 203 | << "insert chain initial value is not Zero or Undef\n" ); |
| 204 | return llvm::failure(); |
| 205 | } |
| 206 | return inFlightList.finalize(defaultVal); |
| 207 | } |
| 208 | } |
| 209 | return llvm::failure(); |
| 210 | } |
| 211 | |