1 | //===-- LLVMInsertChainFolder.cpp -----------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
11 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "llvm/Support/Debug.h" |
14 | |
15 | #define DEBUG_TYPE "flang-insert-folder" |
16 | |
17 | #include <deque> |
18 | |
19 | namespace { |
20 | // Helper class to construct the attribute elements of an aggregate value being |
21 | // folded without creating a full mlir::Attribute representation for each step |
22 | // of the insert value chain, which would both be expensive in terms of |
23 | // compilation time and memory (since the intermediate Attribute would survive, |
24 | // unused, inside the mlir context). |
25 | class InsertChainBackwardFolder { |
26 | // Type for the current value of an element of the aggregate value being |
27 | // constructed by the insert chain. |
28 | // At any point of the insert chain, the value of an element is either: |
29 | // - nullptr: not yet known, the insert has not yet been seen. |
30 | // - an mlir::Attribute: the element is fully defined. |
31 | // - a nested InsertChainBackwardFolder: the element is itself an aggregate |
32 | // and its sub-elements have been partially defined (insert with mutliple |
33 | // indices have been seen). |
34 | |
35 | // The insertion folder assumes backward walk of the insert chain. Once an |
36 | // element or sub-element has been defined, it is not overriden by new |
37 | // insertions (last insert wins). |
38 | using InFlightValue = |
39 | llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>; |
40 | |
41 | public: |
42 | InsertChainBackwardFolder( |
43 | mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage) |
44 | : values(getNumElements(type), mlir::Attribute{}), |
45 | folderStorage{folderStorage}, type{type} {} |
46 | |
47 | /// Push |
48 | bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at); |
49 | |
50 | mlir::Attribute finalize(mlir::Attribute defaultFieldValue); |
51 | |
52 | private: |
53 | static int64_t getNumElements(mlir::Type type) { |
54 | if (auto structTy = |
55 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
56 | return structTy.getBody().size(); |
57 | if (auto arrayTy = |
58 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
59 | return arrayTy.getNumElements(); |
60 | return 0; |
61 | } |
62 | |
63 | static mlir::Type getSubElementType(mlir::Type type, int64_t field) { |
64 | if (auto arrayTy = |
65 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type)) |
66 | return arrayTy.getElementType(); |
67 | if (auto structTy = |
68 | llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type)) |
69 | return structTy.getBody()[field]; |
70 | return nullptr; |
71 | } |
72 | |
73 | // Current element value of the aggregate value being built. |
74 | llvm::SmallVector<InFlightValue> values; |
75 | // std::deque is used to allocate storage for nested list and guarantee the |
76 | // stability of the InsertChainBackwardFolder* used as element value. |
77 | std::deque<InsertChainBackwardFolder> *folderStorage; |
78 | // Type of the aggregate value being built. |
79 | mlir::Type type; |
80 | }; |
81 | } // namespace |
82 | |
83 | // Helper to fold the value being inserted by an llvm.insert_value. |
84 | // This may call tryFoldingLLVMInsertChain if the value is an aggregate and |
85 | // was itself constructed by a different insert chain. |
86 | // Returns a nullptr Attribute if the value could not be folded. |
87 | static mlir::Attribute getAttrIfConstant(mlir::Value val, |
88 | mlir::OpBuilder &rewriter) { |
89 | if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
90 | return cst.getValue(); |
91 | if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
92 | llvm::FailureOr<mlir::Attribute> attr = |
93 | fir::tryFoldingLLVMInsertChain(val, rewriter); |
94 | if (succeeded(attr)) |
95 | return *attr; |
96 | return nullptr; |
97 | } |
98 | if (val.getDefiningOp<mlir::LLVM::ZeroOp>()) |
99 | return mlir::LLVM::ZeroAttr::get(val.getContext()); |
100 | if (val.getDefiningOp<mlir::LLVM::UndefOp>()) |
101 | return mlir::LLVM::UndefAttr::get(val.getContext()); |
102 | if (mlir::Operation *op = val.getDefiningOp()) { |
103 | unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber(); |
104 | llvm::SmallVector<mlir::Value> results; |
105 | if (mlir::succeeded(rewriter.tryFold(op, results)) && |
106 | results.size() > resNum) { |
107 | if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>()) |
108 | return cst.getValue(); |
109 | } |
110 | } |
111 | if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>()) |
112 | if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) |
113 | if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) |
114 | return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); |
115 | LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val |
116 | << "\n" ); |
117 | return nullptr; |
118 | } |
119 | |
120 | mlir::Attribute |
121 | InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { |
122 | llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector( |
123 | values, [&](InFlightValue inFlight) -> mlir::Attribute { |
124 | if (!inFlight) |
125 | return defaultFieldValue; |
126 | if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) |
127 | return attr; |
128 | return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize( |
129 | defaultFieldValue); |
130 | }); |
131 | return mlir::ArrayAttr::get(type.getContext(), attrs); |
132 | } |
133 | |
134 | bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, |
135 | llvm::ArrayRef<int64_t> at) { |
136 | if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size())) |
137 | return false; |
138 | InFlightValue &inFlight = values[at[0]]; |
139 | if (!inFlight) { |
140 | if (at.size() == 1) { |
141 | inFlight = val; |
142 | return true; |
143 | } |
144 | // This is the first insert to a nested field. Create a |
145 | // InsertChainBackwardFolder for the current element value. |
146 | mlir::Type subType = getSubElementType(type, at[0]); |
147 | if (!subType) |
148 | return false; |
149 | InsertChainBackwardFolder &inFlightList = |
150 | folderStorage->emplace_back(subType, folderStorage); |
151 | inFlight = &inFlightList; |
152 | return inFlightList.pushValue(val, at.drop_front()); |
153 | } |
154 | // Keep last inserted value if already set. |
155 | if (llvm::isa<mlir::Attribute>(inFlight)) |
156 | return true; |
157 | auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight); |
158 | if (at.size() == 1) { |
159 | if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) { |
160 | LLVM_DEBUG(llvm::dbgs() |
161 | << "insert chain sub-element partially overwritten initial " |
162 | "value is not zero or undef\n" ); |
163 | return false; |
164 | } |
165 | inFlight = inFlightList->finalize(val); |
166 | return true; |
167 | } |
168 | return inFlightList->pushValue(val, at.drop_front()); |
169 | } |
170 | |
171 | llvm::FailureOr<mlir::Attribute> |
172 | fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) { |
173 | if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>()) |
174 | return cst.getValue(); |
175 | if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) { |
176 | LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n" ); |
177 | if (auto structTy = |
178 | llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) { |
179 | mlir::LLVM::InsertValueOp currentInsert = insert; |
180 | mlir::LLVM::InsertValueOp lastInsert; |
181 | std::deque<InsertChainBackwardFolder> folderStorage; |
182 | InsertChainBackwardFolder inFlightList(structTy, &folderStorage); |
183 | while (currentInsert) { |
184 | mlir::Attribute attr = |
185 | getAttrIfConstant(currentInsert.getValue(), rewriter); |
186 | if (!attr) |
187 | return llvm::failure(); |
188 | if (!inFlightList.pushValue(attr, currentInsert.getPosition())) |
189 | return llvm::failure(); |
190 | lastInsert = currentInsert; |
191 | currentInsert = currentInsert.getContainer() |
192 | .getDefiningOp<mlir::LLVM::InsertValueOp>(); |
193 | } |
194 | mlir::Attribute defaultVal; |
195 | if (lastInsert) { |
196 | if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>()) |
197 | defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); |
198 | else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>()) |
199 | defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); |
200 | } |
201 | if (!defaultVal) { |
202 | LLVM_DEBUG(llvm::dbgs() |
203 | << "insert chain initial value is not Zero or Undef\n" ); |
204 | return llvm::failure(); |
205 | } |
206 | return inFlightList.finalize(defaultVal); |
207 | } |
208 | } |
209 | return llvm::failure(); |
210 | } |
211 | |