1//===-- LLVMInsertChainFolder.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
10#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
11#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12#include "mlir/IR/Builders.h"
13#include "llvm/Support/Debug.h"
14
15#define DEBUG_TYPE "flang-insert-folder"
16
17#include <deque>
18
19namespace {
20// Helper class to construct the attribute elements of an aggregate value being
21// folded without creating a full mlir::Attribute representation for each step
22// of the insert value chain, which would both be expensive in terms of
23// compilation time and memory (since the intermediate Attribute would survive,
24// unused, inside the mlir context).
25class InsertChainBackwardFolder {
26 // Type for the current value of an element of the aggregate value being
27 // constructed by the insert chain.
28 // At any point of the insert chain, the value of an element is either:
29 // - nullptr: not yet known, the insert has not yet been seen.
30 // - an mlir::Attribute: the element is fully defined.
31 // - a nested InsertChainBackwardFolder: the element is itself an aggregate
32 // and its sub-elements have been partially defined (insert with mutliple
33 // indices have been seen).
34
35 // The insertion folder assumes backward walk of the insert chain. Once an
36 // element or sub-element has been defined, it is not overriden by new
37 // insertions (last insert wins).
38 using InFlightValue =
39 llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>;
40
41public:
42 InsertChainBackwardFolder(
43 mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage)
44 : values(getNumElements(type), mlir::Attribute{}),
45 folderStorage{folderStorage}, type{type} {}
46
47 /// Push
48 bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at);
49
50 mlir::Attribute finalize(mlir::Attribute defaultFieldValue);
51
52private:
53 static int64_t getNumElements(mlir::Type type) {
54 if (auto structTy =
55 llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
56 return structTy.getBody().size();
57 if (auto arrayTy =
58 llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
59 return arrayTy.getNumElements();
60 return 0;
61 }
62
63 static mlir::Type getSubElementType(mlir::Type type, int64_t field) {
64 if (auto arrayTy =
65 llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
66 return arrayTy.getElementType();
67 if (auto structTy =
68 llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
69 return structTy.getBody()[field];
70 return nullptr;
71 }
72
73 // Current element value of the aggregate value being built.
74 llvm::SmallVector<InFlightValue> values;
75 // std::deque is used to allocate storage for nested list and guarantee the
76 // stability of the InsertChainBackwardFolder* used as element value.
77 std::deque<InsertChainBackwardFolder> *folderStorage;
78 // Type of the aggregate value being built.
79 mlir::Type type;
80};
81} // namespace
82
83// Helper to fold the value being inserted by an llvm.insert_value.
84// This may call tryFoldingLLVMInsertChain if the value is an aggregate and
85// was itself constructed by a different insert chain.
86// Returns a nullptr Attribute if the value could not be folded.
87static mlir::Attribute getAttrIfConstant(mlir::Value val,
88 mlir::OpBuilder &rewriter) {
89 if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
90 return cst.getValue();
91 if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
92 llvm::FailureOr<mlir::Attribute> attr =
93 fir::tryFoldingLLVMInsertChain(val, rewriter);
94 if (succeeded(attr))
95 return *attr;
96 return nullptr;
97 }
98 if (val.getDefiningOp<mlir::LLVM::ZeroOp>())
99 return mlir::LLVM::ZeroAttr::get(val.getContext());
100 if (val.getDefiningOp<mlir::LLVM::UndefOp>())
101 return mlir::LLVM::UndefAttr::get(val.getContext());
102 if (mlir::Operation *op = val.getDefiningOp()) {
103 unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber();
104 llvm::SmallVector<mlir::Value> results;
105 if (mlir::succeeded(rewriter.tryFold(op, results)) &&
106 results.size() > resNum) {
107 if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>())
108 return cst.getValue();
109 }
110 }
111 if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>())
112 if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter))
113 if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr))
114 return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt());
115 LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val
116 << "\n");
117 return nullptr;
118}
119
120mlir::Attribute
121InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) {
122 llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector(
123 values, [&](InFlightValue inFlight) -> mlir::Attribute {
124 if (!inFlight)
125 return defaultFieldValue;
126 if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight))
127 return attr;
128 return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize(
129 defaultFieldValue);
130 });
131 return mlir::ArrayAttr::get(type.getContext(), attrs);
132}
133
134bool InsertChainBackwardFolder::pushValue(mlir::Attribute val,
135 llvm::ArrayRef<int64_t> at) {
136 if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size()))
137 return false;
138 InFlightValue &inFlight = values[at[0]];
139 if (!inFlight) {
140 if (at.size() == 1) {
141 inFlight = val;
142 return true;
143 }
144 // This is the first insert to a nested field. Create a
145 // InsertChainBackwardFolder for the current element value.
146 mlir::Type subType = getSubElementType(type, at[0]);
147 if (!subType)
148 return false;
149 InsertChainBackwardFolder &inFlightList =
150 folderStorage->emplace_back(subType, folderStorage);
151 inFlight = &inFlightList;
152 return inFlightList.pushValue(val, at.drop_front());
153 }
154 // Keep last inserted value if already set.
155 if (llvm::isa<mlir::Attribute>(inFlight))
156 return true;
157 auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight);
158 if (at.size() == 1) {
159 if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) {
160 LLVM_DEBUG(llvm::dbgs()
161 << "insert chain sub-element partially overwritten initial "
162 "value is not zero or undef\n");
163 return false;
164 }
165 inFlight = inFlightList->finalize(val);
166 return true;
167 }
168 return inFlightList->pushValue(val, at.drop_front());
169}
170
171llvm::FailureOr<mlir::Attribute>
172fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) {
173 if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
174 return cst.getValue();
175 if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
176 LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n");
177 if (auto structTy =
178 llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) {
179 mlir::LLVM::InsertValueOp currentInsert = insert;
180 mlir::LLVM::InsertValueOp lastInsert;
181 std::deque<InsertChainBackwardFolder> folderStorage;
182 InsertChainBackwardFolder inFlightList(structTy, &folderStorage);
183 while (currentInsert) {
184 mlir::Attribute attr =
185 getAttrIfConstant(currentInsert.getValue(), rewriter);
186 if (!attr)
187 return llvm::failure();
188 if (!inFlightList.pushValue(attr, currentInsert.getPosition()))
189 return llvm::failure();
190 lastInsert = currentInsert;
191 currentInsert = currentInsert.getContainer()
192 .getDefiningOp<mlir::LLVM::InsertValueOp>();
193 }
194 mlir::Attribute defaultVal;
195 if (lastInsert) {
196 if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>())
197 defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext());
198 else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>())
199 defaultVal = mlir::LLVM::UndefAttr::get(val.getContext());
200 }
201 if (!defaultVal) {
202 LLVM_DEBUG(llvm::dbgs()
203 << "insert chain initial value is not Zero or Undef\n");
204 return llvm::failure();
205 }
206 return inFlightList.finalize(defaultVal);
207 }
208 }
209 return llvm::failure();
210}
211

source code of flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp