//===-- Lower/DirectivesCommon.h --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
///
/// A place for directive-related utilities shared across multiple lowering
/// files, e.g. utilities shared between OpenMP and OpenACC lowering. The
/// header can hold both declarations and templated/inline implementations.
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_LOWER_DIRECTIVES_COMMON_H
#define FORTRAN_LOWER_DIRECTIVES_COMMON_H

#include "flang/Common/idioms.h"
#include "flang/Evaluate/tools.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/OpenACC.h"
#include "flang/Lower/OpenMP.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/Utils.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Value.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include <list>
#include <type_traits>

namespace Fortran {
namespace lower {

/// Information gathered to generate bounds operations and data entry/exit
/// operations.
struct AddrAndBoundsInfo {
  explicit AddrAndBoundsInfo() {}
  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput)
      : addr(addr), rawInput(rawInput) {}
  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
                             mlir::Value isPresent)
      : addr(addr), rawInput(rawInput), isPresent(isPresent) {}
  mlir::Value addr = nullptr;
  mlir::Value rawInput = nullptr;
  mlir::Value isPresent = nullptr;
};

/// Checks if the assignment statement has a single variable on the RHS.
static inline bool checkForSingleVariableOnRHS(
    const Fortran::parser::AssignmentStmt &assignmentStmt) {
  const Fortran::parser::Expr &expr{
      std::get<Fortran::parser::Expr>(assignmentStmt.t)};
  const Fortran::common::Indirection<Fortran::parser::Designator> *designator =
      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
          &expr.u);
  return designator != nullptr;
}
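
// For example, checkForSingleVariableOnRHS returns true for a capture-style
// statement like "v = x", whose RHS is a lone designator, and false for
// "v = x + 1".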

/// Checks if the symbol on the LHS of the assignment statement is present in
/// the RHS expression.
static inline bool
checkForSymbolMatch(const Fortran::parser::AssignmentStmt &assignmentStmt) {
  const auto &var{std::get<Fortran::parser::Variable>(assignmentStmt.t)};
  const auto &expr{std::get<Fortran::parser::Expr>(assignmentStmt.t)};
  const auto *e{Fortran::semantics::GetExpr(expr)};
  const auto *v{Fortran::semantics::GetExpr(var)};
  auto varSyms{Fortran::evaluate::GetSymbolVector(*v)};
  const Fortran::semantics::Symbol &varSymbol{*varSyms.front()};
  for (const Fortran::semantics::Symbol &symbol :
       Fortran::evaluate::GetSymbolVector(*e))
    if (varSymbol == symbol)
      return true;
  return false;
}
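
// For example, checkForSymbolMatch returns true for an update statement like
// "x = x + 1", where the LHS symbol x also appears on the RHS, and false for
// a plain write like "x = y + 1".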

/// Populates \p hint and \p memoryOrder with appropriate clause information
/// if present on the atomic construct.
static inline void genOmpAtomicHintAndMemoryOrderClauses(
    Fortran::lower::AbstractConverter &converter,
    const Fortran::parser::OmpAtomicClauseList &clauseList,
    mlir::IntegerAttr &hint,
    mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  for (const Fortran::parser::OmpAtomicClause &clause : clauseList.v) {
    if (const auto *ompClause =
            std::get_if<Fortran::parser::OmpClause>(&clause.u)) {
      if (const auto *hintClause =
              std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) {
        const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
        uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr);
        hint = firOpBuilder.getI64IntegerAttr(hintExprValue);
      }
    } else if (const auto *ompMemoryOrderClause =
                   std::get_if<Fortran::parser::OmpMemoryOrderClause>(
                       &clause.u)) {
      if (std::get_if<Fortran::parser::OmpClause::Acquire>(
              &ompMemoryOrderClause->v.u)) {
        memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
            firOpBuilder.getContext(),
            mlir::omp::ClauseMemoryOrderKind::Acquire);
      } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>(
                     &ompMemoryOrderClause->v.u)) {
        memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
            firOpBuilder.getContext(),
            mlir::omp::ClauseMemoryOrderKind::Relaxed);
      } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>(
                     &ompMemoryOrderClause->v.u)) {
        memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
            firOpBuilder.getContext(),
            mlir::omp::ClauseMemoryOrderKind::Seq_cst);
      } else if (std::get_if<Fortran::parser::OmpClause::Release>(
                     &ompMemoryOrderClause->v.u)) {
        memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get(
            firOpBuilder.getContext(),
            mlir::omp::ClauseMemoryOrderKind::Release);
      }
    }
  }
}
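
// For example, for a construct such as
//   !$omp atomic update hint(omp_sync_hint_uncontended) seq_cst
// the hint expression value is captured as an i64 integer attribute and the
// memory order is set to the Seq_cst kind.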

/// Generates an atomic.read operation at the insertion point currently set
/// on the builder.
template <typename AtomicListT>
static inline void genOmpAccAtomicCaptureStatement(
    Fortran::lower::AbstractConverter &converter, mlir::Value fromAddress,
    mlir::Value toAddress,
    [[maybe_unused]] const AtomicListT *leftHandClauseList,
    [[maybe_unused]] const AtomicListT *rightHandClauseList,
    mlir::Type elementType, mlir::Location loc) {
  // Generate `atomic.read` operation for atomic assignment statements
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // If no hint clause is specified, the effect is as if
    // hint(omp_sync_hint_none) had been specified.
    mlir::IntegerAttr hint = nullptr;

    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
    if (leftHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
                                            hint, memoryOrder);
    if (rightHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                            hint, memoryOrder);
    firOpBuilder.create<mlir::omp::AtomicReadOp>(
        loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType), hint,
        memoryOrder);
  } else {
    firOpBuilder.create<mlir::acc::AtomicReadOp>(
        loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
  }
}
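
// For example, the OpenMP construct
//   !$omp atomic read
//   v = x
// is lowered to a single omp.atomic.read operation reading from the address
// of x into the address of v (and to acc.atomic.read for the OpenACC
// equivalent).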

/// Generates an atomic.write operation at the insertion point currently set
/// on the builder.
template <typename AtomicListT>
static inline void genOmpAccAtomicWriteStatement(
    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
    mlir::Value rhsExpr, [[maybe_unused]] const AtomicListT *leftHandClauseList,
    [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc,
    mlir::Value *evaluatedExprValue = nullptr) {
  // Generate `atomic.write` operation for atomic assignment statements
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // If no hint clause is specified, the effect is as if
    // hint(omp_sync_hint_none) had been specified.
    mlir::IntegerAttr hint = nullptr;
    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
    if (leftHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
                                            hint, memoryOrder);
    if (rightHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                            hint, memoryOrder);
    firOpBuilder.create<mlir::omp::AtomicWriteOp>(loc, lhsAddr, rhsExpr, hint,
                                                  memoryOrder);
  } else {
    firOpBuilder.create<mlir::acc::AtomicWriteOp>(loc, lhsAddr, rhsExpr);
  }
}
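
// For example, the OpenMP construct
//   !$omp atomic write
//   x = 2 * a + b
// is lowered to a single omp.atomic.write (or acc.atomic.write for OpenACC)
// that stores the already-evaluated RHS value to the address of x.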

/// Generates an atomic.update operation at the insertion point currently set
/// on the builder.
template <typename AtomicListT>
static inline void genOmpAccAtomicUpdateStatement(
    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
    mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
    const Fortran::parser::Expr &assignmentStmtExpr,
    [[maybe_unused]] const AtomicListT *leftHandClauseList,
    [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc,
    mlir::Operation *atomicCaptureOp = nullptr) {
  // Generate `atomic.update` operation for atomic assignment statements
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Location currentLocation = converter.getCurrentLocation();

  // Create the omp.atomic.update or acc.atomic.update operation
  //
  //   func.func @_QPsb() {
  //     %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
  //     %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
  //     %2 = fir.load %1 : !fir.ref<i32>
  //     omp.atomic.update %0 : !fir.ref<i32> {
  //     ^bb0(%arg0: i32):
  //       %3 = arith.addi %arg0, %2 : i32
  //       omp.yield(%3 : i32)
  //     }
  //     return
  //   }

  auto getArgExpression =
      [](std::list<parser::ActualArgSpec>::const_iterator it) {
        const auto &arg{std::get<parser::ActualArg>((*it).t)};
        const auto *parserExpr{
            std::get_if<common::Indirection<parser::Expr>>(&arg.u)};
        return parserExpr;
      };

  // Lower any non-atomic sub-expressions before the atomic operation, and
  // map their lowered values to the semantic representation.
  Fortran::lower::ExprToValueMap exprValueOverrides;
  // Max and min intrinsics can have a list of Args. Hence we need a list
  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
  Fortran::common::visit(
      Fortran::common::visitors{
          [&](const common::Indirection<parser::FunctionReference> &funcRef)
              -> void {
            const auto &args{std::get<std::list<parser::ActualArgSpec>>(
                funcRef.value().v.t)};
            std::list<parser::ActualArgSpec>::const_iterator beginIt =
                args.begin();
            std::list<parser::ActualArgSpec>::const_iterator endIt = args.end();
            const auto *exprFirst{getArgExpression(beginIt)};
            if (exprFirst && exprFirst->value().source ==
                                 assignmentStmtVariable.GetSource()) {
              // Add everything except the first
              beginIt++;
            } else {
              // Add everything except the last
              endIt--;
            }
            std::list<parser::ActualArgSpec>::const_iterator it;
            for (it = beginIt; it != endIt; it++) {
              const common::Indirection<parser::Expr> *expr =
                  getArgExpression(it);
              if (expr)
                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
            }
          },
          [&](const auto &op) -> void {
            using T = std::decay_t<decltype(op)>;
            if constexpr (std::is_base_of<
                              Fortran::parser::Expr::IntrinsicBinary,
                              T>::value) {
              const auto &exprLeft{std::get<0>(op.t)};
              const auto &exprRight{std::get<1>(op.t)};
              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
                nonAtomicSubExprs.push_back(
                    Fortran::semantics::GetExpr(exprRight));
              else
                nonAtomicSubExprs.push_back(
                    Fortran::semantics::GetExpr(exprLeft));
            }
          },
      },
      assignmentStmtExpr.u);
  StatementContext nonAtomicStmtCtx;
  if (!nonAtomicSubExprs.empty()) {
    // Generate the non-atomic part before all the atomic operations.
    auto insertionPoint = firOpBuilder.saveInsertionPoint();
    if (atomicCaptureOp)
      firOpBuilder.setInsertionPoint(atomicCaptureOp);
    mlir::Value nonAtomicVal;
    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
      nonAtomicVal = fir::getBase(converter.genExprValue(
          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
    }
    if (atomicCaptureOp)
      firOpBuilder.restoreInsertionPoint(insertionPoint);
  }

  mlir::Operation *atomicUpdateOp = nullptr;
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // If no hint clause is specified, the effect is as if
    // hint(omp_sync_hint_none) had been specified.
    mlir::IntegerAttr hint = nullptr;
    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
    if (leftHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList,
                                            hint, memoryOrder);
    if (rightHandClauseList)
      genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                            hint, memoryOrder);
    atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
        currentLocation, lhsAddr, hint, memoryOrder);
  } else {
    atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
        currentLocation, lhsAddr);
  }

  llvm::SmallVector<mlir::Type> varTys = {varType};
  llvm::SmallVector<mlir::Location> locs = {currentLocation};
  firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs);
  mlir::Value val =
      fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));

  exprValueOverrides.try_emplace(
      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
  {
    // Statement context inside the atomic block.
    converter.overrideExprValues(&exprValueOverrides);
    Fortran::lower::StatementContext atomicStmtCtx;
    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
    mlir::Value convertResult =
        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
    if constexpr (std::is_same<AtomicListT,
                               Fortran::parser::OmpAtomicClauseList>()) {
      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
    } else {
      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
    }
    converter.resetExprOverrides();
  }
  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
}
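
// For example, for
//   !$omp atomic update
//   x = min(x, a, b)
// the loads of a and b are hoisted before the atomic operation (or before the
// enclosing atomic.capture, when present), and only the min computation is
// generated inside the atomic.update region.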

/// Processes an atomic construct with write clause.
template <typename AtomicT, typename AtomicListT>
void genOmpAccAtomicWrite(Fortran::lower::AbstractConverter &converter,
                          const AtomicT &atomicWrite, mlir::Location loc) {
  const AtomicListT *rightHandClauseList = nullptr;
  const AtomicListT *leftHandClauseList = nullptr;
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // Get the clause lists on each side of the atomic write construct.
    rightHandClauseList = &std::get<2>(atomicWrite.t);
    leftHandClauseList = &std::get<0>(atomicWrite.t);
  }

  const Fortran::parser::AssignmentStmt &stmt =
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicWrite.t)
          .statement;
  const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v;
  Fortran::lower::StatementContext stmtCtx;
  // Get the value and address of atomic write operands.
  mlir::Value rhsExpr =
      fir::getBase(converter.genExprValue(assign.rhs, stmtCtx));
  mlir::Value lhsAddr =
      fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx));
  genOmpAccAtomicWriteStatement(converter, lhsAddr, rhsExpr, leftHandClauseList,
                                rightHandClauseList, loc);
}

/// Processes an atomic construct with read clause.
template <typename AtomicT, typename AtomicListT>
void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
                         const AtomicT &atomicRead, mlir::Location loc) {
  const AtomicListT *rightHandClauseList = nullptr;
  const AtomicListT *leftHandClauseList = nullptr;
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // Get the clause lists on each side of the atomic read construct.
    rightHandClauseList = &std::get<2>(atomicRead.t);
    leftHandClauseList = &std::get<0>(atomicRead.t);
  }

  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicRead.t)
          .statement.t);
  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicRead.t)
          .statement.t);

  Fortran::lower::StatementContext stmtCtx;
  const Fortran::semantics::SomeExpr &fromExpr =
      *Fortran::semantics::GetExpr(assignmentStmtExpr);
  mlir::Type elementType = converter.genType(fromExpr);
  mlir::Value fromAddress =
      fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
  mlir::Value toAddress = fir::getBase(converter.genExprAddr(
      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  if (fromAddress.getType() != toAddress.getType())
    fromAddress =
        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
  genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                  leftHandClauseList, rightHandClauseList,
                                  elementType, loc);
}

/// Processes an atomic construct with update clause.
template <typename AtomicT, typename AtomicListT>
void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
                           const AtomicT &atomicUpdate, mlir::Location loc) {
  const AtomicListT *rightHandClauseList = nullptr;
  const AtomicListT *leftHandClauseList = nullptr;
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    // Get the clause lists on each side of the atomic update construct.
    rightHandClauseList = &std::get<2>(atomicUpdate.t);
    leftHandClauseList = &std::get<0>(atomicUpdate.t);
  }

  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicUpdate.t)
          .statement.t);
  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicUpdate.t)
          .statement.t);

  Fortran::lower::StatementContext stmtCtx;
  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
  genOmpAccAtomicUpdateStatement<AtomicListT>(
      converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
      leftHandClauseList, rightHandClauseList, loc);
}

/// Processes an atomic construct with no clause, which implies the update
/// clause.
template <typename AtomicT, typename AtomicListT>
void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
                  const AtomicT &atomicConstruct, mlir::Location loc) {
  const AtomicListT &atomicClauseList =
      std::get<AtomicListT>(atomicConstruct.t);
  const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicConstruct.t)
          .statement.t);
  const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
      std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
          atomicConstruct.t)
          .statement.t);
  Fortran::lower::StatementContext stmtCtx;
  mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
      *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
  // If atomic-clause is not present on the construct, the behaviour is as if
  // the update clause is specified (for both OpenMP and OpenACC).
  genOmpAccAtomicUpdateStatement<AtomicListT>(
      converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
      &atomicClauseList, nullptr, loc);
}
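
// For example, a bare
//   !$omp atomic
//   x = x + 1
// is lowered exactly like "!$omp atomic update", producing an
// omp.atomic.update region whose body adds 1 to the current value of x.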

/// Processes an atomic construct with capture clause.
template <typename AtomicT, typename AtomicListT>
void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
                            const AtomicT &atomicCapture, mlir::Location loc) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  const Fortran::parser::AssignmentStmt &stmt1 =
      std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
  const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
  const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
  const Fortran::parser::AssignmentStmt &stmt2 =
      std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
  const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
  const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};

  // Pre-evaluate expressions to be used in the various operations inside
  // `atomic.capture` since it is not desirable to have anything other than
  // an `atomic.read`, `atomic.write`, or `atomic.update` operation inside
  // `atomic.capture`.
  Fortran::lower::StatementContext stmtCtx;
  mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
  mlir::Type elementType;
  // LHS evaluations are common to all combinations of `atomic.capture`
  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));

  // Operation specific RHS evaluations
  if (checkForSingleVariableOnRHS(stmt1)) {
    // Atomic capture construct is of the form [capture-stmt, update-stmt] or
    // of the form [capture-stmt, write-stmt]
    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
    stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
  } else {
    // Atomic capture construct is of the form [update-stmt, capture-stmt]
    stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
  }
  // Type information used in generation of `atomic.update` operation
  mlir::Type stmt1VarType =
      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
  mlir::Type stmt2VarType =
      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();

  mlir::Operation *atomicCaptureOp = nullptr;
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    mlir::IntegerAttr hint = nullptr;
    mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
    const AtomicListT &rightHandClauseList = std::get<2>(atomicCapture.t);
    const AtomicListT &leftHandClauseList = std::get<0>(atomicCapture.t);
    genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
                                          memoryOrder);
    genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
                                          memoryOrder);
    atomicCaptureOp =
        firOpBuilder.create<mlir::omp::AtomicCaptureOp>(loc, hint, memoryOrder);
  } else {
    atomicCaptureOp = firOpBuilder.create<mlir::acc::AtomicCaptureOp>(loc);
  }

  firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
  mlir::Block &block = atomicCaptureOp->getRegion(0).back();
  firOpBuilder.setInsertionPointToStart(&block);
  if (checkForSingleVariableOnRHS(stmt1)) {
    if (checkForSymbolMatch(stmt2)) {
      // Atomic capture construct is of the form [capture-stmt, update-stmt]
      const Fortran::semantics::SomeExpr &fromExpr =
          *Fortran::semantics::GetExpr(stmt1Expr);
      elementType = converter.genType(fromExpr);
      genOmpAccAtomicCaptureStatement<AtomicListT>(
          converter, stmt1RHSArg, stmt1LHSArg,
          /*leftHandClauseList=*/nullptr,
          /*rightHandClauseList=*/nullptr, elementType, loc);
      genOmpAccAtomicUpdateStatement<AtomicListT>(
          converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
          /*leftHandClauseList=*/nullptr,
          /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
    } else {
      // Atomic capture construct is of the form [capture-stmt, write-stmt]
      const Fortran::semantics::SomeExpr &fromExpr =
          *Fortran::semantics::GetExpr(stmt1Expr);
      elementType = converter.genType(fromExpr);
      genOmpAccAtomicCaptureStatement<AtomicListT>(
          converter, stmt1RHSArg, stmt1LHSArg,
          /*leftHandClauseList=*/nullptr,
          /*rightHandClauseList=*/nullptr, elementType, loc);
      genOmpAccAtomicWriteStatement<AtomicListT>(
          converter, stmt1RHSArg, stmt2RHSArg,
          /*leftHandClauseList=*/nullptr,
          /*rightHandClauseList=*/nullptr, loc);
    }
  } else {
    // Atomic capture construct is of the form [update-stmt, capture-stmt]
    firOpBuilder.setInsertionPointToEnd(&block);
    const Fortran::semantics::SomeExpr &fromExpr =
        *Fortran::semantics::GetExpr(stmt2Expr);
    elementType = converter.genType(fromExpr);
    genOmpAccAtomicCaptureStatement<AtomicListT>(
        converter, stmt1LHSArg, stmt2LHSArg,
        /*leftHandClauseList=*/nullptr,
        /*rightHandClauseList=*/nullptr, elementType, loc);
    firOpBuilder.setInsertionPointToStart(&block);
    genOmpAccAtomicUpdateStatement<AtomicListT>(
        converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
        /*leftHandClauseList=*/nullptr,
        /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
  }
  firOpBuilder.setInsertionPointToEnd(&block);
  if constexpr (std::is_same<AtomicListT,
                             Fortran::parser::OmpAtomicClauseList>()) {
    firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
  } else {
    firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
  }
  firOpBuilder.setInsertionPointToStart(&block);
}
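
// The two statements of an atomic capture construct can appear as
//   [capture-stmt, update-stmt], e.g. "v = x" followed by "x = x + 1"
//   [capture-stmt, write-stmt],  e.g. "v = x" followed by "x = expr"
//   [update-stmt, capture-stmt], e.g. "x = x + 1" followed by "v = x"
// and the generated atomic.capture region holds the corresponding
// atomic.read/atomic.update or atomic.read/atomic.write pair.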

/// Create empty blocks for the current region.
/// These blocks replace blocks parented to an enclosing region.
template <typename... TerminatorOps>
void createEmptyRegionBlocks(
    fir::FirOpBuilder &builder,
    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
  mlir::Region *region = &builder.getRegion();
  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
    if (eval.block) {
      if (eval.block->empty()) {
        eval.block->erase();
        eval.block = builder.createBlock(region);
      } else {
        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
        assert(mlir::isa<TerminatorOps...>(terminatorOp) &&
               "expected terminator op");
      }
    }
    if (!eval.isDirective() && eval.hasNestedEvaluations())
      createEmptyRegionBlocks<TerminatorOps...>(builder,
                                                eval.getNestedEvaluations());
  }
}

/// Retrieve the base address for a data operand symbol, loading the
/// descriptor when needed and accounting for OPTIONAL dummy arguments.
inline AddrAndBoundsInfo
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                       fir::FirOpBuilder &builder,
                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
  mlir::Value symAddr = converter.getSymbolAddress(sym);
  mlir::Value rawInput = symAddr;
  if (auto declareOp =
          mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
    symAddr = declareOp.getResults()[0];
    rawInput = declareOp.getResults()[1];
  }

  // TODO: Might need revisiting to handle non-shared clauses.
  if (!symAddr) {
    if (const auto *details =
            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
      symAddr = converter.getSymbolAddress(details->symbol());
      rawInput = symAddr;
    }
  }

  if (!symAddr)
    llvm::report_fatal_error("could not retrieve symbol address");

  mlir::Value isPresent;
  if (Fortran::semantics::IsOptional(sym))
    isPresent =
        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);

  if (auto boxTy =
          fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
    if (boxTy.getEleTy().isa<fir::RecordType>())
      TODO(loc, "derived type");

    // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
    // `fir.ref<fir.class<T>>` type.
    if (symAddr.getType().isa<fir::ReferenceType>()) {
      if (Fortran::semantics::IsOptional(sym)) {
        mlir::Value addr =
            builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
                .genThen([&]() {
                  mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
                  builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
                })
                .genElse([&] {
                  mlir::Value absent =
                      builder.create<fir::AbsentOp>(loc, boxTy);
                  builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
                })
                .getResults()[0];
        return AddrAndBoundsInfo(addr, rawInput, isPresent);
      }
      mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
      return AddrAndBoundsInfo(addr, rawInput, isPresent);
    }
  }
  return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
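
// For example, for an OPTIONAL allocatable dummy argument such as
//   real, allocatable, optional :: a(:)
// the symbol address is typically a reference to a descriptor
// (fir.ref<fir.box<...>>); the descriptor is loaded under a fir.is_present
// guard so that the returned address is the descriptor value itself, or
// fir.absent when the argument is not present.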

/// Generate the bounds operations for each dimension of \p box from the
/// descriptor information, or, when \p collectValuesOnly is set, only collect
/// the raw bound values.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
                          fir::ExtendedValue dataExv, mlir::Value box,
                          bool collectValuesOnly = false) {
  llvm::SmallVector<mlir::Value> values;
  mlir::Value byteStride;
  mlir::Type idxTy = builder.getIndexType();
  mlir::Type boundTy = builder.getType<BoundsType>();
  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
  for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
    mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
    mlir::Value baseLb =
        fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
    auto dimInfo =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d);
    mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
    mlir::Value ub =
        builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
    if (dim == 0) // First stride is the element size.
      byteStride = dimInfo.getByteStride();
    if (collectValuesOnly) {
      values.push_back(lb);
      values.push_back(ub);
      values.push_back(dimInfo.getExtent());
      values.push_back(byteStride);
      values.push_back(baseLb);
    } else {
      mlir::Value bound = builder.create<BoundsOp>(
          loc, boundTy, lb, ub, dimInfo.getExtent(), byteStride, true, baseLb);
      values.push_back(bound);
    }
    // Compute the stride for the next dimension.
    byteStride = builder.create<mlir::arith::MulIOp>(loc, byteStride,
                                                     dimInfo.getExtent());
  }
  return values;
}

/// Generate the bounds operation from the descriptor information.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
                    Fortran::lower::AbstractConverter &converter,
                    fir::ExtendedValue dataExv,
                    Fortran::lower::AddrAndBoundsInfo &info) {
  llvm::SmallVector<mlir::Value> bounds;
  mlir::Type idxTy = builder.getIndexType();
  mlir::Type boundTy = builder.getType<BoundsType>();

  assert(info.addr.getType().isa<fir::BaseBoxType>() &&
         "expect fir.box or fir.class");

  if (info.isPresent) {
    llvm::SmallVector<mlir::Type> resTypes;
    constexpr unsigned nbValuesPerBound = 5;
    for (unsigned dim = 0; dim < dataExv.rank() * nbValuesPerBound; ++dim)
      resTypes.push_back(idxTy);

    mlir::Operation::result_range ifRes =
        builder.genIfOp(loc, resTypes, info.isPresent, /*withElseRegion=*/true)
            .genThen([&]() {
              llvm::SmallVector<mlir::Value> boundValues =
                  gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
                      builder, loc, dataExv, info.addr,
                      /*collectValuesOnly=*/true);
              builder.create<fir::ResultOp>(loc, boundValues);
            })
            .genElse([&] {
              // Box is not present. Populate bound values with default values.
              llvm::SmallVector<mlir::Value> boundValues;
              mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
              mlir::Value mOne = builder.createMinusOneInteger(loc, idxTy);
              for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
                boundValues.push_back(zero); // lb
                boundValues.push_back(mOne); // ub
                boundValues.push_back(zero); // extent
                boundValues.push_back(zero); // byteStride
                boundValues.push_back(zero); // baseLb
              }
              builder.create<fir::ResultOp>(loc, boundValues);
            })
            .getResults();
    // Create the bound operations outside the if-then-else with the if op
    // results.
    for (unsigned i = 0; i < ifRes.size(); i += nbValuesPerBound) {
      mlir::Value bound = builder.create<BoundsOp>(
          loc, boundTy, ifRes[i], ifRes[i + 1], ifRes[i + 2], ifRes[i + 3],
          true, ifRes[i + 4]);
      bounds.push_back(bound);
    }
  } else {
    bounds = gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
        builder, loc, dataExv, info.addr);
  }
  return bounds;
}

/// Generate bounds operation for base array without any subscripts
/// provided.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
                 Fortran::lower::AbstractConverter &converter,
                 fir::ExtendedValue dataExv, bool isAssumedSize) {
  mlir::Type idxTy = builder.getIndexType();
  mlir::Type boundTy = builder.getType<BoundsType>();
  llvm::SmallVector<mlir::Value> bounds;

  if (dataExv.rank() == 0)
    return bounds;

  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
  const unsigned rank = dataExv.rank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    mlir::Value baseLb =
        fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
    mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
    mlir::Value ub;
    mlir::Value lb = zero;
    mlir::Value ext = fir::factory::readExtent(builder, loc, dataExv, dim);
    if (isAssumedSize && dim + 1 == rank) {
      ext = zero;
      ub = lb;
    } else {
      // ub = extent - 1
      ub = builder.create<mlir::arith::SubIOp>(loc, ext, one);
    }

    mlir::Value bound =
        builder.create<BoundsOp>(loc, boundTy, lb, ub, ext, one, false, baseLb);
    bounds.push_back(bound);
  }
  return bounds;
}
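
// For example, for an explicit-shape array declared as
//   integer :: a(10)
// genBaseBoundsOps produces one bounds operation with lowerbound 0,
// upperbound 9, extent 10 and stride 1 (the last dimension of an assumed-size
// array instead gets a zero extent with upperbound equal to lowerbound).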

namespace detail {
template <typename T> //
static T &&AsRvalueRef(T &&t) {
  return std::move(t);
}
template <typename T> //
static T AsRvalueRef(T &t) {
  return t;
}
template <typename T> //
static T AsRvalueRef(const T &t) {
  return t;
}

// Helper class for stripping enclosing parentheses and a conversion that
// preserves type category. This is used for triplet elements, which are
// always of type integer(kind=8). The lower/upper bounds are converted to
// an "index" type, which is 64-bit, so the explicit conversion to kind=8
// (if present) is not needed. When it's present, though, it causes generated
// names to contain "int(..., kind=8)".
struct PeelConvert {
  template <Fortran::common::TypeCategory Category, int Kind>
  static Fortran::semantics::MaybeExpr visit_with_category(
      const Fortran::evaluate::Expr<Fortran::evaluate::Type<Category, Kind>>
          &expr) {
    return std::visit(
        [](auto &&s) { return visit_with_category<Category, Kind>(s); },
        expr.u);
  }
  template <Fortran::common::TypeCategory Category, int Kind>
  static Fortran::semantics::MaybeExpr visit_with_category(
      const Fortran::evaluate::Convert<Fortran::evaluate::Type<Category, Kind>,
                                       Category> &expr) {
    return AsGenericExpr(AsRvalueRef(expr.left()));
  }
  template <Fortran::common::TypeCategory Category, int Kind, typename T>
  static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
    return std::nullopt; //
  }
  template <Fortran::common::TypeCategory Category, typename T>
  static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
    return std::nullopt; //
  }

  template <Fortran::common::TypeCategory Category>
  static Fortran::semantics::MaybeExpr
  visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeKind<Category>>
            &expr) {
    return std::visit([](auto &&s) { return visit_with_category<Category>(s); },
                      expr.u);
  }
  static Fortran::semantics::MaybeExpr
  visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr) {
    return std::visit([](auto &&s) { return visit(s); }, expr.u);
  }
  template <typename T> //
  static Fortran::semantics::MaybeExpr visit(const T &) {
    return std::nullopt;
  }
};

static Fortran::semantics::SomeExpr
peelOuterConvert(Fortran::semantics::SomeExpr &expr) {
  if (auto peeled = PeelConvert::visit(expr))
    return *peeled;
  return expr;
}
} // namespace detail

/// Generate bounds operations for an array section when subscripts are
/// provided.
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
             Fortran::lower::AbstractConverter &converter,
             Fortran::lower::StatementContext &stmtCtx,
             const std::vector<Fortran::evaluate::Subscript> &subscripts,
             std::stringstream &asFortran, fir::ExtendedValue &dataExv,
             bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
             bool treatIndexAsSection = false) {
  int dimension = 0;
  mlir::Type idxTy = builder.getIndexType();
  mlir::Type boundTy = builder.getType<BoundsType>();
  llvm::SmallVector<mlir::Value> bounds;

  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
  const int dataExvRank = static_cast<int>(dataExv.rank());
  for (const auto &subscript : subscripts) {
    const auto *triplet{std::get_if<Fortran::evaluate::Triplet>(&subscript.u)};
    if (triplet || treatIndexAsSection) {
      if (dimension != 0)
        asFortran << ',';
      mlir::Value lbound, ubound, extent;
      std::optional<std::int64_t> lval, uval;
      mlir::Value baseLb =
          fir::factory::readLowerBound(builder, loc, dataExv, dimension, one);
      bool defaultLb = baseLb == one;
      mlir::Value stride = one;
      bool strideInBytes = false;

      if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
        if (info.isPresent) {
          stride =
              builder
                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
                  .genThen([&]() {
                    mlir::Value d =
                        builder.createIntegerConstant(loc, idxTy, dimension);
                    auto dimInfo = builder.create<fir::BoxDimsOp>(
                        loc, idxTy, idxTy, idxTy, info.addr, d);
                    builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
                  })
                  .genElse([&] {
                    mlir::Value zero =
                        builder.createIntegerConstant(loc, idxTy, 0);
                    builder.create<fir::ResultOp>(loc, zero);
                  })
                  .getResults()[0];
        } else {
          mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
          auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
                                                        idxTy, info.addr, d);
          stride = dimInfo.getByteStride();
        }
        strideInBytes = true;
      }

      Fortran::semantics::MaybeExpr lower;
      if (triplet) {
        lower = Fortran::evaluate::AsGenericExpr(triplet->lower());
      } else {
        // Case of IndirectSubscriptIntegerExpr
        using IndirectSubscriptIntegerExpr =
            Fortran::evaluate::IndirectSubscriptIntegerExpr;
        using SubscriptInteger = Fortran::evaluate::SubscriptInteger;
        Fortran::evaluate::Expr<SubscriptInteger> oneInt =
            std::get<IndirectSubscriptIntegerExpr>(subscript.u).value();
        lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt));
        if (lower->Rank() > 0) {
          mlir::emitError(
              loc, "vector subscript cannot be used for an array section");
          break;
        }
      }
      if (lower) {
        lval = Fortran::evaluate::ToInt64(*lower);
        if (lval) {
          if (defaultLb) {
            lbound = builder.createIntegerConstant(loc, idxTy, *lval - 1);
          } else {
            mlir::Value lb = builder.createIntegerConstant(loc, idxTy, *lval);
            lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
          }
          asFortran << *lval;
        } else {
          mlir::Value lb =
              fir::getBase(converter.genExprValue(loc, *lower, stmtCtx));
          lb = builder.createConvert(loc, baseLb.getType(), lb);
          lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
          asFortran << detail::peelOuterConvert(*lower).AsFortran();
        }
      } else {
        // If the lower bound is not specified, then the section
        // starts from offset 0 of the dimension.
        // Note that the lowerbound in the BoundsOp is always 0-based.
        lbound = zero;
      }

      if (!triplet) {
        // If it is a scalar subscript, then the upper bound
        // is equal to the lower bound, and the extent is one.
        ubound = lbound;
        extent = one;
      } else {
        asFortran << ':';
        Fortran::semantics::MaybeExpr upper =
            Fortran::evaluate::AsGenericExpr(triplet->upper());

        if (upper) {
          uval = Fortran::evaluate::ToInt64(*upper);
          if (uval) {
            if (defaultLb) {
              ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
            } else {
              mlir::Value ub = builder.createIntegerConstant(loc, idxTy, *uval);
              ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
            }
            asFortran << *uval;
          } else {
            mlir::Value ub =
                fir::getBase(converter.genExprValue(loc, *upper, stmtCtx));
            ub = builder.createConvert(loc, baseLb.getType(), ub);
            ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
            asFortran << detail::peelOuterConvert(*upper).AsFortran();
          }
        }
        if (lower && upper) {
          if (lval && uval && *uval < *lval) {
            mlir::emitError(loc, "zero sized array section");
            break;
          } else {
            // Stride is mandatory in evaluate::Triplet. Make sure it's 1.
            auto val = Fortran::evaluate::ToInt64(triplet->GetStride());
            if (!val || *val != 1) {
              mlir::emitError(loc, "stride cannot be specified on "
                                   "an array section");
              break;
            }
          }
        }

        if (info.isPresent &&
            fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
          extent =
              builder
                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
                  .genThen([&]() {
                    mlir::Value ext = fir::factory::readExtent(
                        builder, loc, dataExv, dimension);
                    builder.create<fir::ResultOp>(loc, ext);
                  })
                  .genElse([&] {
                    mlir::Value zero =
                        builder.createIntegerConstant(loc, idxTy, 0);
                    builder.create<fir::ResultOp>(loc, zero);
                  })
                  .getResults()[0];
        } else {
          extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
        }

        if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) {
          extent = zero;
          if (ubound && lbound) {
            mlir::Value diff =
                builder.create<mlir::arith::SubIOp>(loc, ubound, lbound);
            extent = builder.create<mlir::arith::AddIOp>(loc, diff, one);
          }
          if (!ubound)
            ubound = lbound;
        }

        if (!ubound) {
          // ub = extent - 1
          ubound = builder.create<mlir::arith::SubIOp>(loc, extent, one);
        }
      }
      mlir::Value bound = builder.create<BoundsOp>(
          loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb);
      bounds.push_back(bound);
      ++dimension;
    }
  }
  return bounds;
}
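
// For example, for the section a(2:5) of an array declared as a(10),
// genBoundsOps emits a bounds operation with the 0-based values lowerbound 1
// and upperbound 4, the extent read from the array's shape, and appends
// "2:5" to asFortran.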

namespace detail {
template <typename Ref, typename Expr> //
std::optional<Ref> getRef(Expr &&expr) {
  if constexpr (std::is_same_v<llvm::remove_cvref_t<Expr>,
                               Fortran::evaluate::DataRef>) {
    if (auto *ref = std::get_if<Ref>(&expr.u))
      return *ref;
    return std::nullopt;
  } else {
    auto maybeRef = Fortran::evaluate::ExtractDataRef(expr);
    if (!maybeRef || !std::holds_alternative<Ref>(maybeRef->u))
      return std::nullopt;
    return std::get<Ref>(maybeRef->u);
  }
}
} // namespace detail

/// Gather the base address and bounds operations for a data clause operand
/// described by \p symbol and, optionally, \p maybeDesignator, and render its
/// Fortran spelling into \p asFortran.
template <typename BoundsOp, typename BoundsType>
AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
    Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
    semantics::SemanticsContext &semaCtx,
    Fortran::lower::StatementContext &stmtCtx,
    Fortran::semantics::SymbolRef symbol,
    const Fortran::semantics::MaybeExpr &maybeDesignator,
    mlir::Location operandLocation, std::stringstream &asFortran,
    llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
  using namespace Fortran;

  AddrAndBoundsInfo info;

  if (!maybeDesignator) {
    info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation);
    asFortran << symbol->name().ToString();
    return info;
  }

  semantics::SomeExpr designator = *maybeDesignator;

  if ((designator.Rank() > 0 || treatIndexAsSection) &&
      IsArrayElement(designator)) {
    auto arrayRef = detail::getRef<evaluate::ArrayRef>(designator);
    // This shouldn't fail after IsArrayElement(designator).
    assert(arrayRef && "Expecting ArrayRef");

    fir::ExtendedValue dataExv;
    bool dataExvIsAssumedSize = false;

    auto toMaybeExpr = [&](auto &&base) {
      using BaseType = llvm::remove_cvref_t<decltype(base)>;
      evaluate::ExpressionAnalyzer ea{semaCtx};

      if constexpr (std::is_same_v<evaluate::NamedEntity, BaseType>) {
        if (auto *ref = base.UnwrapSymbolRef())
          return ea.Designate(evaluate::DataRef{*ref});
        if (auto *ref = base.UnwrapComponent())
          return ea.Designate(evaluate::DataRef{*ref});
        llvm_unreachable("Unexpected NamedEntity");
      } else {
        static_assert(std::is_same_v<semantics::SymbolRef, BaseType>);
        return ea.Designate(evaluate::DataRef{base});
      }
    };

    auto arrayBase = toMaybeExpr(arrayRef->base());
    assert(arrayBase);

    if (detail::getRef<evaluate::Component>(*arrayBase)) {
      dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx);
      info.addr = fir::getBase(dataExv);
      info.rawInput = info.addr;
      asFortran << arrayBase->AsFortran();
    } else {
      const semantics::Symbol &sym = arrayRef->GetLastSymbol();
      dataExvIsAssumedSize =
          Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate());
      info = getDataOperandBaseAddr(converter, builder, sym, operandLocation);
      dataExv = converter.getSymbolExtendedValue(sym);
      asFortran << sym.name().ToString();
    }

    if (!arrayRef->subscript().empty()) {
      asFortran << '(';
      bounds = genBoundsOps<BoundsOp, BoundsType>(
          builder, operandLocation, converter, stmtCtx, arrayRef->subscript(),
          asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection);
    }
    asFortran << ')';
  } else if (auto compRef = detail::getRef<evaluate::Component>(designator)) {
    fir::ExtendedValue compExv =
        converter.genExprAddr(operandLocation, designator, stmtCtx);
    info.addr = fir::getBase(compExv);
    info.rawInput = info.addr;
    if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
      bounds = genBaseBoundsOps<BoundsOp, BoundsType>(builder, operandLocation,
                                                      converter, compExv,
                                                      /*isAssumedSize=*/false);
    asFortran << designator.AsFortran();

    if (semantics::IsOptional(compRef->GetLastSymbol())) {
      info.isPresent = builder.create<fir::IsPresentOp>(
          operandLocation, builder.getI1Type(), info.rawInput);
    }

    if (auto loadOp =
            mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
      if (fir::isAllocatableType(loadOp.getType()) ||
          fir::isPointerType(loadOp.getType()))
        info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
      info.rawInput = info.addr;
    }

    // If the component is an allocatable or pointer the result of
    // genExprAddr will be the result of a fir.box_addr operation or
    // a fir.box_addr has been inserted just before.
    // Retrieve the box so we handle it like other descriptors.
    if (auto boxAddrOp =
            mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) {
      info.addr = boxAddrOp.getVal();
      info.rawInput = info.addr;
      bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
          builder, operandLocation, converter, compExv, info);
    }
  } else {
    if (detail::getRef<evaluate::ArrayRef>(designator)) {
      fir::ExtendedValue compExv =
          converter.genExprAddr(operandLocation, designator, stmtCtx);
      info.addr = fir::getBase(compExv);
      info.rawInput = info.addr;
      asFortran << designator.AsFortran();
    } else if (auto symRef = detail::getRef<semantics::SymbolRef>(designator)) {
      // Scalar or full array.
      fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef);
      info =
          getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
      if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
        bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
            builder, operandLocation, converter, dataExv, info);
      }
      bool dataExvIsAssumedSize =
          Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate());
      if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
        bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
            builder, operandLocation, converter, dataExv, dataExvIsAssumedSize);
      asFortran << symRef->get().name().ToString();
    } else { // Unsupported
      llvm::report_fatal_error("Unsupported type of OpenACC operand");
    }
  }

  return info;
}
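
// For example, for a data clause operand written as a(1:n),
// gatherDataOperandAddrAndBounds returns the base address of a, appends
// "a(1:n)" to asFortran, and fills `bounds` with one bounds operation for the
// array section.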
} // namespace lower
} // namespace Fortran

#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
1223 | |