1 | //===-- Optimizer/Builder/TemporaryStorage.cpp ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // Implementation of utility data structures to create and manipulate temporary |
9 | // storages to stack Fortran values or pointers in HLFIR. |
10 | //===----------------------------------------------------------------------===// |
11 | |
#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Runtime/TemporaryStack.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include <memory>
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // fir::factory::Counter implementation. |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | fir::factory::Counter::Counter(mlir::Location loc, fir::FirOpBuilder &builder, |
24 | mlir::Value initialValue, |
25 | bool canCountThroughLoops) |
26 | : canCountThroughLoops{canCountThroughLoops}, initialValue{initialValue} { |
27 | mlir::Type type = initialValue.getType(); |
28 | one = builder.createIntegerConstant(loc, type, 1); |
29 | if (canCountThroughLoops) { |
30 | index = builder.createTemporary(loc, type); |
31 | builder.create<fir::StoreOp>(loc, initialValue, index); |
32 | } else { |
33 | index = initialValue; |
34 | } |
35 | } |
36 | |
37 | mlir::Value |
38 | fir::factory::Counter::getAndIncrementIndex(mlir::Location loc, |
39 | fir::FirOpBuilder &builder) { |
40 | if (canCountThroughLoops) { |
41 | mlir::Value indexValue = builder.create<fir::LoadOp>(loc, index); |
42 | mlir::Value newValue = |
43 | builder.create<mlir::arith::AddIOp>(loc, indexValue, one); |
44 | builder.create<fir::StoreOp>(loc, newValue, index); |
45 | return indexValue; |
46 | } |
47 | mlir::Value indexValue = index; |
48 | index = builder.create<mlir::arith::AddIOp>(loc, indexValue, one); |
49 | return indexValue; |
50 | } |
51 | |
52 | void fir::factory::Counter::reset(mlir::Location loc, |
53 | fir::FirOpBuilder &builder) { |
54 | if (canCountThroughLoops) |
55 | builder.create<fir::StoreOp>(loc, initialValue, index); |
56 | else |
57 | index = initialValue; |
58 | } |
59 | |
60 | //===----------------------------------------------------------------------===// |
61 | // fir::factory::HomogeneousScalarStack implementation. |
62 | //===----------------------------------------------------------------------===// |
63 | |
64 | fir::factory::HomogeneousScalarStack::HomogeneousScalarStack( |
65 | mlir::Location loc, fir::FirOpBuilder &builder, |
66 | fir::SequenceType declaredType, mlir::Value extent, |
67 | llvm::ArrayRef<mlir::Value> lengths, bool allocateOnHeap, |
68 | bool stackThroughLoops, llvm::StringRef tempName) |
69 | : allocateOnHeap{allocateOnHeap}, |
70 | counter{loc, builder, |
71 | builder.createIntegerConstant(loc, builder.getIndexType(), 1), |
72 | stackThroughLoops} { |
73 | // Allocate the temporary storage. |
74 | llvm::SmallVector<mlir::Value, 1> extents{extent}; |
75 | mlir::Value tempStorage; |
76 | if (allocateOnHeap) |
77 | tempStorage = builder.createHeapTemporary(loc, declaredType, tempName, |
78 | extents, lengths); |
79 | else |
80 | tempStorage = |
81 | builder.createTemporary(loc, declaredType, tempName, extents, lengths); |
82 | |
83 | mlir::Value shape = builder.genShape(loc, extents); |
84 | temp = builder |
85 | .create<hlfir::DeclareOp>(loc, tempStorage, tempName, shape, |
86 | lengths, fir::FortranVariableFlagsAttr{}) |
87 | .getBase(); |
88 | } |
89 | |
90 | void fir::factory::HomogeneousScalarStack::pushValue(mlir::Location loc, |
91 | fir::FirOpBuilder &builder, |
92 | mlir::Value value) { |
93 | hlfir::Entity entity{value}; |
94 | assert(entity.isScalar() && "cannot use inlined temp with array" ); |
95 | mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder); |
96 | hlfir::Entity tempElement = hlfir::getElementAt( |
97 | loc, builder, hlfir::Entity{temp}, mlir::ValueRange{indexValue}); |
98 | // TODO: "copy" would probably be better than assign to ensure there are no |
99 | // side effects (user assignments, temp, lhs finalization)? |
100 | // This only makes a difference for derived types, and for now derived types |
101 | // will use the runtime strategy to avoid any bad behaviors. So the todo |
102 | // below should not get hit but is added as a remainder/safety. |
103 | if (!entity.hasIntrinsicType()) |
104 | TODO(loc, "creating inlined temporary stack for derived types" ); |
105 | builder.create<hlfir::AssignOp>(loc, value, tempElement); |
106 | } |
107 | |
108 | void fir::factory::HomogeneousScalarStack::resetFetchPosition( |
109 | mlir::Location loc, fir::FirOpBuilder &builder) { |
110 | counter.reset(loc, builder); |
111 | } |
112 | |
113 | mlir::Value |
114 | fir::factory::HomogeneousScalarStack::fetch(mlir::Location loc, |
115 | fir::FirOpBuilder &builder) { |
116 | mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder); |
117 | hlfir::Entity tempElement = hlfir::getElementAt( |
118 | loc, builder, hlfir::Entity{temp}, mlir::ValueRange{indexValue}); |
119 | return hlfir::loadTrivialScalar(loc, builder, tempElement); |
120 | } |
121 | |
122 | void fir::factory::HomogeneousScalarStack::destroy(mlir::Location loc, |
123 | fir::FirOpBuilder &builder) { |
124 | if (allocateOnHeap) { |
125 | auto declare = temp.getDefiningOp<hlfir::DeclareOp>(); |
126 | assert(declare && "temp must have been declared" ); |
127 | builder.create<fir::FreeMemOp>(loc, declare.getMemref()); |
128 | } |
129 | } |
130 | |
131 | hlfir::Entity fir::factory::HomogeneousScalarStack::moveStackAsArrayExpr( |
132 | mlir::Location loc, fir::FirOpBuilder &builder) { |
133 | mlir::Value mustFree = builder.createBool(loc, allocateOnHeap); |
134 | auto hlfirExpr = builder.create<hlfir::AsExprOp>(loc, temp, mustFree); |
135 | return hlfir::Entity{hlfirExpr}; |
136 | } |
137 | |
138 | //===----------------------------------------------------------------------===// |
139 | // fir::factory::SimpleCopy implementation. |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | fir::factory::SimpleCopy::SimpleCopy(mlir::Location loc, |
143 | fir::FirOpBuilder &builder, |
144 | hlfir::Entity source, |
145 | llvm::StringRef tempName) { |
146 | // Use hlfir.as_expr and hlfir.associate to create a copy and leave |
147 | // bufferization deals with how best to make the copy. |
148 | if (source.isVariable()) |
149 | source = hlfir::Entity{builder.create<hlfir::AsExprOp>(loc, source)}; |
150 | copy = hlfir::genAssociateExpr(loc, builder, source, |
151 | source.getFortranElementType(), tempName); |
152 | } |
153 | |
154 | void fir::factory::SimpleCopy::destroy(mlir::Location loc, |
155 | fir::FirOpBuilder &builder) { |
156 | builder.create<hlfir::EndAssociateOp>(loc, copy); |
157 | } |
158 | |
159 | //===----------------------------------------------------------------------===// |
160 | // fir::factory::AnyValueStack implementation. |
161 | //===----------------------------------------------------------------------===// |
162 | |
163 | fir::factory::AnyValueStack::AnyValueStack(mlir::Location loc, |
164 | fir::FirOpBuilder &builder, |
165 | mlir::Type valueStaticType) |
166 | : valueStaticType{valueStaticType}, |
167 | counter{loc, builder, |
168 | builder.createIntegerConstant(loc, builder.getI64Type(), 0), |
169 | /*stackThroughLoops=*/true} { |
170 | opaquePtr = fir::runtime::genCreateValueStack(loc, builder); |
171 | // Compute the storage type. I1 are stored as fir.logical<1>. This is required |
172 | // to use descriptor. |
173 | mlir::Type storageType = |
174 | hlfir::getFortranElementOrSequenceType(valueStaticType); |
175 | mlir::Type i1Type = builder.getI1Type(); |
176 | if (storageType == i1Type) |
177 | storageType = fir::LogicalType::get(builder.getContext(), 1); |
178 | assert(hlfir::getFortranElementType(storageType) != i1Type && |
179 | "array of i1 should not be used" ); |
180 | mlir::Type heapType = fir::HeapType::get(storageType); |
181 | mlir::Type boxType; |
182 | if (hlfir::isPolymorphicType(valueStaticType)) |
183 | boxType = fir::ClassType::get(heapType); |
184 | else |
185 | boxType = fir::BoxType::get(heapType); |
186 | retValueBox = builder.createTemporary(loc, boxType); |
187 | } |
188 | |
189 | void fir::factory::AnyValueStack::pushValue(mlir::Location loc, |
190 | fir::FirOpBuilder &builder, |
191 | mlir::Value value) { |
192 | hlfir::Entity entity{value}; |
193 | mlir::Type storageElementType = |
194 | hlfir::getFortranElementType(retValueBox.getType()); |
195 | auto [box, maybeCleanUp] = |
196 | hlfir::convertToBox(loc, builder, entity, storageElementType); |
197 | fir::runtime::genPushValue(loc, builder, opaquePtr, fir::getBase(box)); |
198 | if (maybeCleanUp) |
199 | (*maybeCleanUp)(); |
200 | } |
201 | |
202 | void fir::factory::AnyValueStack::resetFetchPosition( |
203 | mlir::Location loc, fir::FirOpBuilder &builder) { |
204 | counter.reset(loc, builder); |
205 | } |
206 | |
207 | mlir::Value fir::factory::AnyValueStack::fetch(mlir::Location loc, |
208 | fir::FirOpBuilder &builder) { |
209 | mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder); |
210 | fir::runtime::genValueAt(loc, builder, opaquePtr, indexValue, retValueBox); |
211 | // Dereference the allocatable "retValueBox", and load if trivial scalar |
212 | // value. |
213 | mlir::Value result = |
214 | hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{retValueBox}); |
215 | if (valueStaticType != result.getType()) { |
216 | // Cast back saved simple scalars stored with another type to their original |
217 | // type (like i1). |
218 | if (fir::isa_trivial(valueStaticType)) |
219 | return builder.createConvert(loc, valueStaticType, result); |
220 | // Memory type mismatches (e.g. fir.ref vs fir.heap) or hlfir.expr vs |
221 | // variable type mismatches are OK, but the base Fortran type must be the |
222 | // same. |
223 | assert(hlfir::getFortranElementOrSequenceType(valueStaticType) == |
224 | hlfir::getFortranElementOrSequenceType(result.getType()) && |
225 | "non trivial values must be saved with their original type" ); |
226 | } |
227 | return result; |
228 | } |
229 | |
230 | void fir::factory::AnyValueStack::destroy(mlir::Location loc, |
231 | fir::FirOpBuilder &builder) { |
232 | fir::runtime::genDestroyValueStack(loc, builder, opaquePtr); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // fir::factory::AnyVariableStack implementation. |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | fir::factory::AnyVariableStack::AnyVariableStack(mlir::Location loc, |
240 | fir::FirOpBuilder &builder, |
241 | mlir::Type variableStaticType) |
242 | : variableStaticType{variableStaticType}, |
243 | counter{loc, builder, |
244 | builder.createIntegerConstant(loc, builder.getI64Type(), 0), |
245 | /*stackThroughLoops=*/true} { |
246 | opaquePtr = fir::runtime::genCreateDescriptorStack(loc, builder); |
247 | mlir::Type storageType = |
248 | hlfir::getFortranElementOrSequenceType(variableStaticType); |
249 | mlir::Type ptrType = fir::PointerType::get(storageType); |
250 | mlir::Type boxType; |
251 | if (hlfir::isPolymorphicType(variableStaticType)) |
252 | boxType = fir::ClassType::get(ptrType); |
253 | else |
254 | boxType = fir::BoxType::get(ptrType); |
255 | retValueBox = builder.createTemporary(loc, boxType); |
256 | } |
257 | |
258 | void fir::factory::AnyVariableStack::pushValue(mlir::Location loc, |
259 | fir::FirOpBuilder &builder, |
260 | mlir::Value variable) { |
261 | hlfir::Entity entity{variable}; |
262 | mlir::Type storageElementType = |
263 | hlfir::getFortranElementType(retValueBox.getType()); |
264 | auto [box, maybeCleanUp] = |
265 | hlfir::convertToBox(loc, builder, entity, storageElementType); |
266 | fir::runtime::genPushDescriptor(loc, builder, opaquePtr, fir::getBase(box)); |
267 | if (maybeCleanUp) |
268 | (*maybeCleanUp)(); |
269 | } |
270 | |
271 | void fir::factory::AnyVariableStack::resetFetchPosition( |
272 | mlir::Location loc, fir::FirOpBuilder &builder) { |
273 | counter.reset(loc, builder); |
274 | } |
275 | |
276 | mlir::Value fir::factory::AnyVariableStack::fetch(mlir::Location loc, |
277 | fir::FirOpBuilder &builder) { |
278 | mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder); |
279 | fir::runtime::genDescriptorAt(loc, builder, opaquePtr, indexValue, |
280 | retValueBox); |
281 | hlfir::Entity retBox{builder.create<fir::LoadOp>(loc, retValueBox)}; |
282 | // The runtime always tracks variable as address, but the form of the variable |
283 | // that was saved may be different (raw address, fir.boxchar), ensure |
284 | // the returned variable has the same form of the one that was saved. |
285 | if (mlir::isa<fir::BaseBoxType>(variableStaticType)) |
286 | return builder.createConvert(loc, variableStaticType, retBox); |
287 | if (mlir::isa<fir::BoxCharType>(variableStaticType)) |
288 | return hlfir::genVariableBoxChar(loc, builder, retBox); |
289 | mlir::Value rawAddr = genVariableRawAddress(loc, builder, retBox); |
290 | return builder.createConvert(loc, variableStaticType, rawAddr); |
291 | } |
292 | |
293 | void fir::factory::AnyVariableStack::destroy(mlir::Location loc, |
294 | fir::FirOpBuilder &builder) { |
295 | fir::runtime::genDestroyDescriptorStack(loc, builder, opaquePtr); |
296 | } |
297 | |
298 | //===----------------------------------------------------------------------===// |
299 | // fir::factory::AnyVectorSubscriptStack implementation. |
300 | //===----------------------------------------------------------------------===// |
301 | |
302 | fir::factory::AnyVectorSubscriptStack::AnyVectorSubscriptStack( |
303 | mlir::Location loc, fir::FirOpBuilder &builder, |
304 | mlir::Type variableStaticType, bool shapeCanBeSavedAsRegister, int rank) |
305 | : AnyVariableStack{loc, builder, variableStaticType} { |
306 | if (shapeCanBeSavedAsRegister) { |
307 | shapeTemp = |
308 | std::unique_ptr<TemporaryStorage>(new TemporaryStorage{SSARegister{}}); |
309 | return; |
310 | } |
311 | // The shape will be tracked as the dimension inside a descriptor because |
312 | // that is the easiest from a lowering point of view, and this is an |
313 | // edge case situation that will probably not very well be exercised. |
314 | mlir::Type type = |
315 | fir::BoxType::get(builder.getVarLenSeqTy(builder.getI32Type(), rank)); |
316 | boxType = type; |
317 | shapeTemp = std::unique_ptr<TemporaryStorage>( |
318 | new TemporaryStorage{AnyVariableStack{loc, builder, type}}); |
319 | } |
320 | |
321 | void fir::factory::AnyVectorSubscriptStack::pushShape( |
322 | mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) { |
323 | if (boxType) { |
324 | // The shape is saved as a dimensions inside a descriptors. |
325 | mlir::Type refType = fir::ReferenceType::get( |
326 | hlfir::getFortranElementOrSequenceType(*boxType)); |
327 | mlir::Value null = builder.createNullConstant(loc, refType); |
328 | mlir::Value descriptor = |
329 | builder.create<fir::EmboxOp>(loc, *boxType, null, shape); |
330 | shapeTemp->pushValue(loc, builder, descriptor); |
331 | return; |
332 | } |
333 | // Otherwise, simply keep track of the fir.shape itself, it is invariant. |
334 | shapeTemp->cast<SSARegister>().pushValue(loc, builder, shape); |
335 | } |
336 | |
337 | void fir::factory::AnyVectorSubscriptStack::resetFetchPosition( |
338 | mlir::Location loc, fir::FirOpBuilder &builder) { |
339 | static_cast<AnyVariableStack *>(this)->resetFetchPosition(loc, builder); |
340 | shapeTemp->resetFetchPosition(loc, builder); |
341 | } |
342 | |
343 | mlir::Value |
344 | fir::factory::AnyVectorSubscriptStack::fetchShape(mlir::Location loc, |
345 | fir::FirOpBuilder &builder) { |
346 | if (boxType) { |
347 | hlfir::Entity descriptor{shapeTemp->fetch(loc, builder)}; |
348 | return hlfir::genShape(loc, builder, descriptor); |
349 | } |
350 | return shapeTemp->cast<SSARegister>().fetch(loc, builder); |
351 | } |
352 | |
353 | void fir::factory::AnyVectorSubscriptStack::destroy( |
354 | mlir::Location loc, fir::FirOpBuilder &builder) { |
355 | static_cast<AnyVariableStack *>(this)->destroy(loc, builder); |
356 | shapeTemp->destroy(loc, builder); |
357 | } |
358 | |