1//===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/OpenACC.h"
14
15#include "flang/Common/idioms.h"
16#include "flang/Lower/Bridge.h"
17#include "flang/Lower/ConvertType.h"
18#include "flang/Lower/DirectivesCommon.h"
19#include "flang/Lower/Mangler.h"
20#include "flang/Lower/PFTBuilder.h"
21#include "flang/Lower/StatementContext.h"
22#include "flang/Lower/Support/Utils.h"
23#include "flang/Optimizer/Builder/BoxValue.h"
24#include "flang/Optimizer/Builder/Complex.h"
25#include "flang/Optimizer/Builder/FIRBuilder.h"
26#include "flang/Optimizer/Builder/HLFIRTools.h"
27#include "flang/Optimizer/Builder/IntrinsicCall.h"
28#include "flang/Optimizer/Builder/Todo.h"
29#include "flang/Optimizer/Dialect/FIRType.h"
30#include "flang/Parser/parse-tree-visitor.h"
31#include "flang/Parser/parse-tree.h"
32#include "flang/Semantics/expression.h"
33#include "flang/Semantics/scope.h"
34#include "flang/Semantics/tools.h"
35#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
36#include "mlir/IR/MLIRContext.h"
37#include "mlir/Support/LLVM.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/Frontend/OpenACC/ACC.h.inc"
40#include "llvm/Support/CommandLine.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/ErrorHandling.h"
43
44#define DEBUG_TYPE "flang-lower-openacc"
45
46static llvm::cl::opt<bool> unwrapFirBox(
47 "openacc-unwrap-fir-box",
48 llvm::cl::desc(
49 "Whether to use the address from fix.box in data clause operations."),
50 llvm::cl::init(Val: false));
51
52static llvm::cl::opt<bool> generateDefaultBounds(
53 "openacc-generate-default-bounds",
54 llvm::cl::desc("Whether to generate default bounds for arrays."),
55 llvm::cl::init(Val: false));
56
57static llvm::cl::opt<bool> strideIncludeLowerExtent(
58 "openacc-stride-include-lower-extent",
59 llvm::cl::desc(
60 "Whether to include the lower dimensions extents in the stride."),
61 llvm::cl::init(Val: true));
62
63// Special value for * passed in device_type or gang clauses.
64static constexpr std::int64_t starCst = -1;
65
66static unsigned routineCounter = 0;
67static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
68static constexpr llvm::StringRef accPrivateInitName = "acc.private.init";
69static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init";
70static constexpr llvm::StringRef accFirDescriptorPostfix = "_desc";
71
72static mlir::Location
73genOperandLocation(Fortran::lower::AbstractConverter &converter,
74 const Fortran::parser::AccObject &accObject) {
75 mlir::Location loc = converter.genUnknownLocation();
76 Fortran::common::visit(
77 Fortran::common::visitors{
78 [&](const Fortran::parser::Designator &designator) {
79 loc = converter.genLocation(designator.source);
80 },
81 [&](const Fortran::parser::Name &name) {
82 loc = converter.genLocation(name.source);
83 }},
84 accObject.u);
85 return loc;
86}
87
88static void addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
89 llvm::SmallVectorImpl<int32_t> &operandSegments,
90 llvm::ArrayRef<mlir::Value> clauseOperands) {
91 operands.append(in_start: clauseOperands.begin(), in_end: clauseOperands.end());
92 operandSegments.push_back(Elt: clauseOperands.size());
93}
94
95static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
96 llvm::SmallVectorImpl<int32_t> &operandSegments,
97 const mlir::Value &clauseOperand) {
98 if (clauseOperand) {
99 operands.push_back(Elt: clauseOperand);
100 operandSegments.push_back(Elt: 1);
101 } else {
102 operandSegments.push_back(Elt: 0);
103 }
104}
105
/// Create a data-clause entry operation of type `Op` for `baseAddr`.
///
/// In this file it is used, for example, with acc::UpdateDeviceOp and
/// acc::GetDevicePtrOp.
///
/// \param builder     Builder used for all created operations.
/// \param loc         Location for all created operations.
/// \param baseAddr    Variable operand. When `unwrapFirBox` is set and this is
///                    a box value, or a box reference with `unwrapBoxAddr`,
///                    the data address from fir::BoxAddrOp is used instead
///                    (loading the box first if it is a reference).
/// \param name        Stream whose contents become the op's name attribute.
/// \param bounds      Bounds operands (third operand segment).
/// \param structured  Value for the op's `structured` flag.
/// \param implicit    Value for the op's `implicit` flag.
/// \param dataClause  Data clause recorded on the op.
/// \param retTy       Result type. It is replaced by the unwrapped address
///                    type when the box is unwrapped.
/// \param async       Async operands (fourth operand segment).
/// \param asyncDeviceTypes      Set as the async-operands device types when
///                              non-empty.
/// \param asyncOnlyDeviceTypes  Set as the async-only attribute when non-empty.
/// \param unwrapBoxAddr  Allow unwrapping when `baseAddr` is a box reference.
/// \param isPresent   Optional presence flag. When given during unwrapping,
///                    the extraction is guarded by a fir.if that yields
///                    fir::AbsentOp when the operand is absent.
/// \returns the created operation. Its var type is the element type if the
///          variable type is an acc::PointerLikeType; otherwise it is the
///          variable type itself, which must be an acc::MappableType.
template <typename Op>
static Op
createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
                  mlir::Value baseAddr, std::stringstream &name,
                  mlir::SmallVector<mlir::Value> bounds, bool structured,
                  bool implicit, mlir::acc::DataClause dataClause,
                  mlir::Type retTy, llvm::ArrayRef<mlir::Value> async,
                  llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
                  llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
                  bool unwrapBoxAddr = false, mlir::Value isPresent = {}) {
  // Never populated here; it only contributes an empty operand segment.
  mlir::Value varPtrPtr;
  // The data clause may apply to either the box reference itself or the
  // pointer to the data it holds. So use `unwrapBoxAddr` to decide.
  // When we have a box value - assume it refers to the data inside box.
  if (unwrapFirBox &&
      ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
       fir::isa_box_type(baseAddr.getType()))) {
    if (isPresent) {
      // The fir.if yields a reference-like type: the box element type,
      // wrapped in fir.ref unless it is already reference-like.
      mlir::Type ifRetTy =
          mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))
              .getEleTy();
      if (!fir::isa_ref_type(ifRetTy))
        ifRetTy = fir::ReferenceType::get(ifRetTy);
      // Note: the then-lambda captures `baseAddr` by reference and may
      // reassign it (to the loaded box) before the outer assignment stores
      // the fir.if result into it.
      baseAddr =
          builder
              .genIfOp(loc, {ifRetTy}, isPresent,
                       /*withElseRegion=*/true)
              .genThen([&]() {
                if (fir::isBoxAddress(baseAddr.getType()))
                  baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
                mlir::Value boxAddr =
                    builder.create<fir::BoxAddrOp>(loc, baseAddr);
                builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
              })
              .genElse([&] {
                mlir::Value absent =
                    builder.create<fir::AbsentOp>(loc, ifRetTy);
                builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
              })
              .getResults()[0];
    } else {
      if (fir::isBoxAddress(baseAddr.getType()))
        baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
      baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
    }
    // The op now produces the unwrapped address type.
    retTy = baseAddr.getType();
  }

  llvm::SmallVector<mlir::Value, 8> operands;
  llvm::SmallVector<int32_t, 8> operandSegments;

  // Operand groups, in order: var, varPtrPtr, bounds, async.
  addOperand(operands, operandSegments, baseAddr);
  addOperand(operands, operandSegments, varPtrPtr);
  addOperands(operands, operandSegments, bounds);
  addOperands(operands, operandSegments, async);

  Op op = builder.create<Op>(loc, retTy, operands);
  op.setNameAttr(builder.getStringAttr(name.str()));
  op.setStructured(structured);
  op.setImplicit(implicit);
  op.setDataClause(dataClause);
  if (auto pointerLikeTy =
          mlir::dyn_cast<mlir::acc::PointerLikeType>(baseAddr.getType())) {
    op.setVarType(pointerLikeTy.getElementType());
  } else {
    assert(mlir::isa<mlir::acc::MappableType>(baseAddr.getType()) &&
           "expected mappable");
    op.setVarType(baseAddr.getType());
  }

  op->setAttr(Op::getOperandSegmentSizeAttr(),
              builder.getDenseI32ArrayAttr(operandSegments));
  if (!asyncDeviceTypes.empty())
    op.setAsyncOperandsDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
  if (!asyncOnlyDeviceTypes.empty())
    op.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
  return op;
}
184
185static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
186 mlir::acc::DataClause clause) {
187 if (!op)
188 return;
189 op->setAttr(mlir::acc::getDeclareAttrName(),
190 mlir::acc::DeclareAttr::get(builder.getContext(),
191 mlir::acc::DataClauseAttr::get(
192 builder.getContext(), clause)));
193}
194
195static mlir::func::FuncOp
196createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
197 mlir::Location loc, llvm::StringRef funcName,
198 llvm::SmallVector<mlir::Type> argsTy = {},
199 llvm::SmallVector<mlir::Location> locs = {}) {
200 auto funcTy = mlir::FunctionType::get(context: modBuilder.getContext(), inputs: argsTy, results: {});
201 auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
202 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
203 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
204 locs);
205 builder.setInsertionPointToEnd(&funcOp.getRegion().back());
206 builder.create<mlir::func::ReturnOp>(loc);
207 builder.setInsertionPointToStart(&funcOp.getRegion().back());
208 return funcOp;
209}
210
211template <typename Op>
212static Op
213createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
214 const llvm::SmallVectorImpl<mlir::Value> &operands,
215 const llvm::SmallVectorImpl<int32_t> &operandSegments) {
216 llvm::ArrayRef<mlir::Type> argTy;
217 Op op = builder.create<Op>(loc, argTy, operands);
218 op->setAttr(Op::getOperandSegmentSizeAttr(),
219 builder.getDenseI32ArrayAttr(operandSegments));
220 return op;
221}
222
223template <typename EntryOp>
224static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
225 fir::FirOpBuilder &builder,
226 mlir::Location loc, mlir::Type descTy,
227 llvm::StringRef funcNamePrefix,
228 std::stringstream &asFortran,
229 mlir::acc::DataClause clause) {
230 auto crtInsPt = builder.saveInsertionPoint();
231 std::stringstream registerFuncName;
232 registerFuncName << funcNamePrefix.str()
233 << Fortran::lower::declarePostAllocSuffix.str();
234
235 if (!mlir::isa<fir::ReferenceType>(descTy))
236 descTy = fir::ReferenceType::get(descTy);
237 auto registerFuncOp = createDeclareFunc(
238 modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
239
240 llvm::SmallVector<mlir::Value> bounds;
241 std::stringstream asFortranDesc;
242 asFortranDesc << asFortran.str();
243 if (unwrapFirBox)
244 asFortranDesc << accFirDescriptorPostfix.str();
245
246 // Updating descriptor must occur before the mapping of the data so that
247 // attached data pointer is not overwritten.
248 mlir::acc::UpdateDeviceOp updateDeviceOp =
249 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
250 builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
251 /*structured=*/false, /*implicit=*/true,
252 mlir::acc::DataClause::acc_update_device, descTy,
253 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
254 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
255 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
256 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
257
258 if (unwrapFirBox) {
259 mlir::Value desc =
260 builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
261 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
262 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
263 EntryOp entryOp = createDataEntryOp<EntryOp>(
264 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
265 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
266 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
267 builder.create<mlir::acc::DeclareEnterOp>(
268 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
269 mlir::ValueRange(entryOp.getAccVar()));
270 }
271
272 modBuilder.setInsertionPointAfter(registerFuncOp);
273 builder.restoreInsertionPoint(crtInsPt);
274}
275
276template <typename ExitOp>
277static void createDeclareDeallocFuncWithArg(
278 mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
279 mlir::Type descTy, llvm::StringRef funcNamePrefix,
280 std::stringstream &asFortran, mlir::acc::DataClause clause) {
281 auto crtInsPt = builder.saveInsertionPoint();
282 // Generate the pre dealloc function.
283 std::stringstream preDeallocFuncName;
284 preDeallocFuncName << funcNamePrefix.str()
285 << Fortran::lower::declarePreDeallocSuffix.str();
286 if (!mlir::isa<fir::ReferenceType>(descTy))
287 descTy = fir::ReferenceType::get(descTy);
288 auto preDeallocOp = createDeclareFunc(
289 modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
290
291 mlir::Value var = preDeallocOp.getArgument(0);
292 if (unwrapFirBox) {
293 mlir::Value loadOp =
294 builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
295 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
296 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
297 var = boxAddrOp.getResult();
298 }
299
300 llvm::SmallVector<mlir::Value> bounds;
301 mlir::acc::GetDevicePtrOp entryOp =
302 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
303 builder, loc, var, asFortran, bounds,
304 /*structured=*/false, /*implicit=*/false, clause, var.getType(),
305 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
306 builder.create<mlir::acc::DeclareExitOp>(
307 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccVar()));
308
309 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
310 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
311 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
312 entryOp.getVar(), entryOp.getVarType(),
313 entryOp.getBounds(), entryOp.getAsyncOperands(),
314 entryOp.getAsyncOperandsDeviceTypeAttr(),
315 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
316 /*structured=*/false, /*implicit=*/false,
317 builder.getStringAttr(*entryOp.getName()));
318 else
319 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
320 entryOp.getBounds(), entryOp.getAsyncOperands(),
321 entryOp.getAsyncOperandsDeviceTypeAttr(),
322 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
323 /*structured=*/false, /*implicit=*/false,
324 builder.getStringAttr(*entryOp.getName()));
325
326 // Generate the post dealloc function.
327 modBuilder.setInsertionPointAfter(preDeallocOp);
328 std::stringstream postDeallocFuncName;
329 postDeallocFuncName << funcNamePrefix.str()
330 << Fortran::lower::declarePostDeallocSuffix.str();
331 auto postDeallocOp = createDeclareFunc(
332 modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
333
334 var = postDeallocOp.getArgument(0);
335 if (unwrapFirBox) {
336 var = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
337 asFortran << accFirDescriptorPostfix.str();
338 }
339
340 mlir::acc::UpdateDeviceOp updateDeviceOp =
341 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
342 builder, loc, var, asFortran, bounds,
343 /*structured=*/false, /*implicit=*/true,
344 mlir::acc::DataClause::acc_update_device, var.getType(),
345 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
346 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
347 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
348 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
349 modBuilder.setInsertionPointAfter(postDeallocOp);
350 builder.restoreInsertionPoint(crtInsPt);
351}
352
353Fortran::semantics::Symbol &
354getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
355 if (const auto *designator =
356 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
357 if (const auto *name =
358 Fortran::semantics::getDesignatorNameIfDataRef(*designator))
359 return *name->symbol;
360 if (const auto *arrayElement =
361 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
362 *designator)) {
363 const Fortran::parser::Name &name =
364 Fortran::parser::GetLastName(arrayElement->base);
365 return *name.symbol;
366 }
367 if (const auto *component =
368 Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
369 *designator)) {
370 return *component->component.symbol;
371 }
372 } else if (const auto *name =
373 std::get_if<Fortran::parser::Name>(&accObject.u)) {
374 return *name->symbol;
375 }
376 llvm::report_fatal_error(reason: "Could not find symbol");
377}
378
379/// Used to generate atomic.read operation which is created in existing
380/// location set by builder.
381static inline void
382genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter,
383 mlir::Value fromAddress, mlir::Value toAddress,
384 mlir::Type elementType, mlir::Location loc) {
385 // Generate `atomic.read` operation for atomic assigment statements
386 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
387
388 firOpBuilder.create<mlir::acc::AtomicReadOp>(
389 loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
390}
391
392/// Used to generate atomic.write operation which is created in existing
393/// location set by builder.
394static inline void
395genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter,
396 mlir::Value lhsAddr, mlir::Value rhsExpr,
397 mlir::Location loc,
398 mlir::Value *evaluatedExprValue = nullptr) {
399 // Generate `atomic.write` operation for atomic assignment statements
400 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
401
402 mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
403 // Create a conversion outside the capture block.
404 auto insertionPoint = firOpBuilder.saveInsertionPoint();
405 firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp());
406 rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
407 firOpBuilder.restoreInsertionPoint(insertionPoint);
408
409 firOpBuilder.create<mlir::acc::AtomicWriteOp>(loc, lhsAddr, rhsExpr);
410}
411
/// Emit an `acc.atomic.update` for
/// `assignmentStmtVariable = assignmentStmtExpr` at the builder's current
/// insertion point.
///
/// Some sub-expressions of the right-hand side do not involve the updated
/// variable:
///  - the other operand of an intrinsic binary operation;
///  - the other arguments of a function reference such as max/min.
/// These are lowered *before* the atomic op and substituted through expression
/// overrides. The update region therefore only computes on its block argument,
/// which stands for the current value of the variable. The updated variable is
/// identified by comparing parser source text.
///
/// \param converter  Lowering converter; its current location is used for the
///                   created ops.
/// \param lhsAddr    Address of the updated variable (operand of the op).
/// \param varType    Type of the region's block argument and of the yielded
///                   value.
/// \param assignmentStmtVariable  Left-hand side of the assignment.
/// \param assignmentStmtExpr      Right-hand side of the assignment.
/// \param loc        NOTE(review): not used; ops are created at
///                   converter.getCurrentLocation() instead.
/// \param atomicCaptureOp  When non-null, the hoisted sub-expressions are
///                   emitted before this op rather than at the current point.
/// \param atomicCaptureStmtCtx  Required with atomicCaptureOp; it receives the
///                   hoisted sub-expressions' statement-context clean-ups.
///
/// On return, the builder is positioned right after the atomic update op.
static inline void genAtomicUpdateStatement(
    Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
    mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
    const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
    mlir::Operation *atomicCaptureOp = nullptr,
    Fortran::lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
  // Generate `atomic.update` operation for atomic assignment statements
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
  mlir::Location currentLocation = converter.getCurrentLocation();

  // Create the acc.atomic.update operation. The sketch below shows the
  // analogous OpenMP form of the generated IR; the acc form has the same
  // shape with acc ops.
  //
  // func.func @_QPsb() {
  //   %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
  //   %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
  //   %2 = fir.load %1 : !fir.ref<i32>
  //   omp.atomic.update %0 : !fir.ref<i32> {
  //   ^bb0(%arg0: i32):
  //     %3 = arith.addi %arg0, %2 : i32
  //     omp.yield(%3 : i32)
  //   }
  //   return
  // }

  // Return the expression of an actual argument, or nullptr when the actual
  // argument is not an expression.
  auto getArgExpression =
      [](std::list<Fortran::parser::ActualArgSpec>::const_iterator it) {
        const auto &arg{std::get<Fortran::parser::ActualArg>((*it).t)};
        const auto *parserExpr{
            std::get_if<Fortran::common::Indirection<Fortran::parser::Expr>>(
                &arg.u)};
        return parserExpr;
      };

  // Lower any non atomic sub-expression before the atomic operation, and
  // map its lowered value to the semantic representation.
  Fortran::lower::ExprToValueMap exprValueOverrides;
  // Max and min intrinsics can have a list of Args. Hence we need a list
  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
  Fortran::common::visit(
      Fortran::common::visitors{
          [&](const Fortran::common::Indirection<
              Fortran::parser::FunctionReference> &funcRef) -> void {
            const auto &args{
                std::get<std::list<Fortran::parser::ActualArgSpec>>(
                    funcRef.value().v.t)};
            std::list<Fortran::parser::ActualArgSpec>::const_iterator beginIt =
                args.begin();
            std::list<Fortran::parser::ActualArgSpec>::const_iterator endIt =
                args.end();
            const auto *exprFirst{getArgExpression(beginIt)};
            if (exprFirst && exprFirst->value().source ==
                                 assignmentStmtVariable.GetSource()) {
              // Add everything except the first
              beginIt++;
            } else {
              // Add everything except the last (the updated variable is
              // assumed to be the last argument when it is not the first).
              endIt--;
            }
            std::list<Fortran::parser::ActualArgSpec>::const_iterator it;
            for (it = beginIt; it != endIt; it++) {
              const Fortran::common::Indirection<Fortran::parser::Expr> *expr =
                  getArgExpression(it);
              if (expr)
                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
            }
          },
          [&](const auto &op) -> void {
            using T = std::decay_t<decltype(op)>;
            // For `x = x op e` or `x = e op x`, hoist the operand that is not
            // the updated variable.
            if constexpr (std::is_base_of<
                              Fortran::parser::Expr::IntrinsicBinary,
                              T>::value) {
              const auto &exprLeft{std::get<0>(op.t)};
              const auto &exprRight{std::get<1>(op.t)};
              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
                nonAtomicSubExprs.push_back(
                    Fortran::semantics::GetExpr(exprRight));
              else
                nonAtomicSubExprs.push_back(
                    Fortran::semantics::GetExpr(exprLeft));
            }
          },
      },
      assignmentStmtExpr.u);
  Fortran::lower::StatementContext nonAtomicStmtCtx;
  Fortran::lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
  if (!nonAtomicSubExprs.empty()) {
    // Generate non atomic part before all the atomic operations.
    auto insertionPoint = firOpBuilder.saveInsertionPoint();
    if (atomicCaptureOp) {
      assert(atomicCaptureStmtCtx && "must specify statement context");
      firOpBuilder.setInsertionPoint(atomicCaptureOp);
      // Any clean-ups associated with the expression lowering
      // must also be generated outside of the atomic update operation
      // and after the atomic capture operation.
      // The atomicCaptureStmtCtx will be finalized at the end
      // of the atomic capture operation generation.
      stmtCtxPtr = atomicCaptureStmtCtx;
    }
    mlir::Value nonAtomicVal;
    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
      nonAtomicVal = fir::getBase(converter.genExprValue(
          currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
    }
    if (atomicCaptureOp)
      firOpBuilder.restoreInsertionPoint(insertionPoint);
  }

  mlir::Operation *atomicUpdateOp = nullptr;
  atomicUpdateOp =
      firOpBuilder.create<mlir::acc::AtomicUpdateOp>(currentLocation, lhsAddr);

  // The region has one block whose single argument (of type varType) stands
  // for the current value of the updated variable.
  llvm::SmallVector<mlir::Type> varTys = {varType};
  llvm::SmallVector<mlir::Location> locs = {currentLocation};
  firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs);
  mlir::Value val =
      fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));

  // References to the updated variable inside the region lower to the block
  // argument.
  exprValueOverrides.try_emplace(
      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
  {
    // statement context inside the atomic block.
    // NOTE(review): atomicStmtCtx is destroyed at the end of this scope,
    // after the yield terminator was created. Confirm it can never hold
    // clean-ups here, or they could be emitted after the terminator.
    converter.overrideExprValues(&exprValueOverrides);
    Fortran::lower::StatementContext atomicStmtCtx;
    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
    mlir::Value convertResult =
        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
    firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
    converter.resetExprOverrides();
  }
  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
}
548
549/// Processes an atomic construct with write clause.
550void genAtomicWrite(Fortran::lower::AbstractConverter &converter,
551 const Fortran::parser::AccAtomicWrite &atomicWrite,
552 mlir::Location loc) {
553 const Fortran::parser::AssignmentStmt &stmt =
554 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
555 atomicWrite.t)
556 .statement;
557 const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v;
558 Fortran::lower::StatementContext stmtCtx;
559 // Get the value and address of atomic write operands.
560 mlir::Value rhsExpr =
561 fir::getBase(converter.genExprValue(assign.rhs, stmtCtx));
562 mlir::Value lhsAddr =
563 fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx));
564 genAtomicWriteStatement(converter, lhsAddr, rhsExpr, loc);
565}
566
567/// Processes an atomic construct with read clause.
568void genAtomicRead(Fortran::lower::AbstractConverter &converter,
569 const Fortran::parser::AccAtomicRead &atomicRead,
570 mlir::Location loc) {
571 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
572 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
573 atomicRead.t)
574 .statement.t);
575 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
576 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
577 atomicRead.t)
578 .statement.t);
579
580 Fortran::lower::StatementContext stmtCtx;
581 const Fortran::semantics::SomeExpr &fromExpr =
582 *Fortran::semantics::GetExpr(assignmentStmtExpr);
583 mlir::Type elementType = converter.genType(fromExpr);
584 mlir::Value fromAddress =
585 fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
586 mlir::Value toAddress = fir::getBase(converter.genExprAddr(
587 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
588 genAtomicCaptureStatement(converter, fromAddress, toAddress, elementType,
589 loc);
590}
591
592/// Processes an atomic construct with update clause.
593void genAtomicUpdate(Fortran::lower::AbstractConverter &converter,
594 const Fortran::parser::AccAtomicUpdate &atomicUpdate,
595 mlir::Location loc) {
596 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
597 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
598 atomicUpdate.t)
599 .statement.t);
600 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
601 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
602 atomicUpdate.t)
603 .statement.t);
604
605 Fortran::lower::StatementContext stmtCtx;
606 mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
607 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
608 mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
609 genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable,
610 assignmentStmtExpr, loc);
611}
612
/// Lower an OpenACC `atomic capture` construct.
///
/// The two assignments are classified into one of three forms, each lowered
/// into a single acc::AtomicCaptureOp region:
///  - [capture-stmt, update-stmt]: `v = x; x = x op expr`
///  - [capture-stmt, write-stmt]:  `v = x; x = expr`
///  - [update-stmt, capture-stmt]: `x = x op expr; v = x`
/// Left-hand-side addresses, and for the write form the written value, are
/// computed before the capture op. The region itself then contains only the
/// atomic read, write and update operations. On return, the builder is
/// positioned after the capture op.
void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
                      const Fortran::parser::AccAtomicCapture &atomicCapture,
                      mlir::Location loc) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  const Fortran::parser::AssignmentStmt &stmt1 =
      std::get<Fortran::parser::AccAtomicCapture::Stmt1>(atomicCapture.t)
          .v.statement;
  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
  const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
  const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
  const Fortran::parser::AssignmentStmt &stmt2 =
      std::get<Fortran::parser::AccAtomicCapture::Stmt2>(atomicCapture.t)
          .v.statement;
  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
  const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
  const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};

  // Pre-evaluate expressions to be used in the various operations inside
  // `atomic.capture` since it is not desirable to have anything other than
  // a `atomic.read`, `atomic.write`, or `atomic.update` operation
  // inside `atomic.capture`
  Fortran::lower::StatementContext stmtCtx;
  // LHS evaluations are common to all combinations of `atomic.capture`
  mlir::Value stmt1LHSArg =
      fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
  mlir::Value stmt2LHSArg =
      fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));

  // Type information used in generation of `atomic.update` operation
  // NOTE(review): each lhs is lowered as a value only to read its type.
  // fir::unwrapRefType on the lhs address (as genAtomicUpdate does) may be
  // equivalent and avoid the extra IR. Confirm before changing.
  mlir::Type stmt1VarType =
      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
  mlir::Type stmt2VarType =
      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();

  mlir::Operation *atomicCaptureOp = nullptr;
  atomicCaptureOp = firOpBuilder.create<mlir::acc::AtomicCaptureOp>(loc);

  firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
  mlir::Block &block = atomicCaptureOp->getRegion(0).back();
  firOpBuilder.setInsertionPointToStart(&block);
  if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
    // stmt1 is the capture `v = x`. stmt2 is an update if its variable also
    // appears on its own right-hand side; otherwise it is a write.
    if (Fortran::evaluate::CheckForSymbolMatch(
            Fortran::semantics::GetExpr(stmt2Var),
            Fortran::semantics::GetExpr(stmt2Expr))) {
      // Atomic capture construct is of the form [capture-stmt, update-stmt]
      const Fortran::semantics::SomeExpr &fromExpr =
          *Fortran::semantics::GetExpr(stmt1Expr);
      mlir::Type elementType = converter.genType(fromExpr);
      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
                                elementType, loc);
      genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
                               stmt2Expr, loc, atomicCaptureOp, &stmtCtx);
    } else {
      // Atomic capture construct is of the form [capture-stmt, write-stmt]
      // Evaluate the written value before the capture op, outside its region.
      firOpBuilder.setInsertionPoint(atomicCaptureOp);
      mlir::Value stmt2RHSArg =
          fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
      firOpBuilder.setInsertionPointToStart(&block);
      const Fortran::semantics::SomeExpr &fromExpr =
          *Fortran::semantics::GetExpr(stmt1Expr);
      mlir::Type elementType = converter.genType(fromExpr);
      genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
                                elementType, loc);
      genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg, loc);
    }
  } else {
    // Atomic capture construct is of the form [update-stmt, capture-stmt]
    const Fortran::semantics::SomeExpr &fromExpr =
        *Fortran::semantics::GetExpr(stmt2Expr);
    mlir::Type elementType = converter.genType(fromExpr);
    genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
                             stmt1Expr, loc, atomicCaptureOp, &stmtCtx);
    genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
                              loc);
  }
  firOpBuilder.setInsertionPointToEnd(&block);
  firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
  // The clean-ups associated with the statements inside the capture
  // construct must be generated after the AtomicCaptureOp.
  firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
}
696
697template <typename Op>
698static void
699genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
700 Fortran::lower::AbstractConverter &converter,
701 Fortran::semantics::SemanticsContext &semanticsContext,
702 Fortran::lower::StatementContext &stmtCtx,
703 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
704 mlir::acc::DataClause dataClause, bool structured,
705 bool implicit, llvm::ArrayRef<mlir::Value> async,
706 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
707 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
708 bool setDeclareAttr = false) {
709 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
710 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
711 for (const auto &accObject : objectList.v) {
712 llvm::SmallVector<mlir::Value> bounds;
713 std::stringstream asFortran;
714 mlir::Location operandLocation = genOperandLocation(converter, accObject);
715 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
716 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
717 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
718 fir::factory::AddrAndBoundsInfo info =
719 Fortran::lower::gatherDataOperandAddrAndBounds<
720 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
721 converter, builder, semanticsContext, stmtCtx, symbol, designator,
722 operandLocation, asFortran, bounds,
723 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
724 /*genDefaultBounds=*/generateDefaultBounds,
725 /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
726 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
727
728 // If the input value is optional and is not a descriptor, we use the
729 // rawInput directly.
730 mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) !=
731 fir::unwrapRefType(info.rawInput.getType())) &&
732 info.isPresent)
733 ? info.rawInput
734 : info.addr;
735 Op op = createDataEntryOp<Op>(
736 builder, operandLocation, baseAddr, asFortran, bounds, structured,
737 implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
738 asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
739 dataOperands.push_back(op.getAccVar());
740 }
741}
742
/// Lower every object of an OpenACC `declare`-style clause to a data entry
/// operation of type \p EntryOp.
///
/// For each object in \p objectList this:
///  - gathers the base address and acc.bounds via
///    gatherDataOperandAddrAndBounds;
///  - creates an \p EntryOp (with no async operands) and appends its acc
///    variable to \p dataOperands;
///  - tags the operation defining the original variable with the declare
///    attribute for \p dataClause;
///  - when the address (after unwrapping a reference) is a fir box type,
///    emits module-level helper functions, named with the mangled symbol
///    name as prefix, right after the current function: an allocation helper
///    for \p EntryOp and, if \p ExitOp is a different op, a deallocation
///    helper for \p ExitOp.
template <typename EntryOp, typename ExitOp>
static void genDeclareDataOperandOperations(
    const Fortran::parser::AccObjectList &objectList,
    Fortran::lower::AbstractConverter &converter,
    Fortran::semantics::SemanticsContext &semanticsContext,
    Fortran::lower::StatementContext &stmtCtx,
    llvm::SmallVectorImpl<mlir::Value> &dataOperands,
    mlir::acc::DataClause dataClause, bool structured, bool implicit) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
  for (const auto &accObject : objectList.v) {
    llvm::SmallVector<mlir::Value> bounds;
    std::stringstream asFortran;
    mlir::Location operandLocation = genOperandLocation(converter, accObject);
    Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
    Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
        [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
    fir::factory::AddrAndBoundsInfo info =
        Fortran::lower::gatherDataOperandAddrAndBounds<
            mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
            converter, builder, semanticsContext, stmtCtx, symbol, designator,
            operandLocation, asFortran, bounds,
            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
            /*genDefaultBounds=*/generateDefaultBounds,
            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
    LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
    // Declare clauses never carry async operands.
    EntryOp op = createDataEntryOp<EntryOp>(
        builder, operandLocation, info.addr, asFortran, bounds, structured,
        implicit, dataClause, info.addr.getType(),
        /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
    dataOperands.push_back(op.getAccVar());
    // Mark the operation that defines the host variable as declared.
    addDeclareAttr(builder, op.getVar().getDefiningOp(), dataClause);
    // Descriptor-based variables additionally get module-level helper
    // functions for allocation (and deallocation when a distinct exit op
    // exists).
    if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
      mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
      modBuilder.setInsertionPointAfter(builder.getFunction());
      std::string prefix = converter.mangleName(symbol);
      createDeclareAllocFuncWithArg<EntryOp>(
          modBuilder, builder, operandLocation, info.addr.getType(), prefix,
          asFortran, dataClause);
      if constexpr (!std::is_same_v<EntryOp, ExitOp>)
        createDeclareDeallocFuncWithArg<ExitOp>(
            modBuilder, builder, operandLocation, info.addr.getType(), prefix,
            asFortran, dataClause);
    }
  }
}
789
790template <typename EntryOp, typename ExitOp, typename Clause>
791static void genDeclareDataOperandOperationsWithModifier(
792 const Clause *x, Fortran::lower::AbstractConverter &converter,
793 Fortran::semantics::SemanticsContext &semanticsContext,
794 Fortran::lower::StatementContext &stmtCtx,
795 Fortran::parser::AccDataModifier::Modifier mod,
796 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
797 const mlir::acc::DataClause clause,
798 const mlir::acc::DataClause clauseWithModifier) {
799 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
800 const auto &accObjectList =
801 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
802 const auto &modifier =
803 std::get<std::optional<Fortran::parser::AccDataModifier>>(
804 listWithModifier.t);
805 mlir::acc::DataClause dataClause =
806 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
807 genDeclareDataOperandOperations<EntryOp, ExitOp>(
808 accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
809 dataClause,
810 /*structured=*/true, /*implicit=*/false);
811}
812
813template <typename EntryOp, typename ExitOp>
814static void
815genDataExitOperations(fir::FirOpBuilder &builder,
816 llvm::SmallVector<mlir::Value> operands, bool structured,
817 std::optional<mlir::Location> exitLoc = std::nullopt) {
818 for (mlir::Value operand : operands) {
819 auto entryOp = mlir::dyn_cast_or_null<EntryOp>(operand.getDefiningOp());
820 assert(entryOp && "data entry op expected");
821 mlir::Location opLoc = exitLoc ? *exitLoc : entryOp.getLoc();
822 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
823 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
824 builder.create<ExitOp>(
825 opLoc, entryOp.getAccVar(), entryOp.getVar(), entryOp.getVarType(),
826 entryOp.getBounds(), entryOp.getAsyncOperands(),
827 entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
828 entryOp.getDataClause(), structured, entryOp.getImplicit(),
829 builder.getStringAttr(*entryOp.getName()));
830 else
831 builder.create<ExitOp>(
832 opLoc, entryOp.getAccVar(), entryOp.getBounds(),
833 entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
834 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), structured,
835 entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
836 }
837}
838
839fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, fir::SequenceType seqTy,
840 mlir::Location loc) {
841 llvm::SmallVector<mlir::Value> extents;
842 mlir::Type idxTy = builder.getIndexType();
843 for (auto extent : seqTy.getShape())
844 extents.push_back(builder.create<mlir::arith::ConstantOp>(
845 loc, idxTy, builder.getIntegerAttr(idxTy, extent)));
846 return builder.create<fir::ShapeOp>(loc, extents);
847}
848
849/// Get the initial value for reduction operator.
850template <typename R>
851static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) {
852 if (op == mlir::acc::ReductionOperator::AccMin) {
853 // min init value -> largest
854 if constexpr (std::is_same_v<R, llvm::APInt>) {
855 assert(ty.isIntOrIndex() && "expect integer or index type");
856 return llvm::APInt::getSignedMaxValue(numBits: ty.getIntOrFloatBitWidth());
857 }
858 if constexpr (std::is_same_v<R, llvm::APFloat>) {
859 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty);
860 assert(floatTy && "expect float type");
861 return llvm::APFloat::getLargest(Sem: floatTy.getFloatSemantics(),
862 /*negative=*/Negative: false);
863 }
864 } else if (op == mlir::acc::ReductionOperator::AccMax) {
865 // max init value -> smallest
866 if constexpr (std::is_same_v<R, llvm::APInt>) {
867 assert(ty.isIntOrIndex() && "expect integer or index type");
868 return llvm::APInt::getSignedMinValue(numBits: ty.getIntOrFloatBitWidth());
869 }
870 if constexpr (std::is_same_v<R, llvm::APFloat>) {
871 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty);
872 assert(floatTy && "expect float type");
873 return llvm::APFloat::getSmallest(Sem: floatTy.getFloatSemantics(),
874 /*negative=*/Negative: true);
875 }
876 } else if (op == mlir::acc::ReductionOperator::AccIand) {
877 if constexpr (std::is_same_v<R, llvm::APInt>) {
878 assert(ty.isIntOrIndex() && "expect integer type");
879 unsigned bits = ty.getIntOrFloatBitWidth();
880 return llvm::APInt::getAllOnes(numBits: bits);
881 }
882 } else {
883 assert(op != mlir::acc::ReductionOperator::AccNone);
884 // +, ior, ieor init value -> 0
885 // * init value -> 1
886 int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 1 : 0;
887 if constexpr (std::is_same_v<R, llvm::APInt>) {
888 assert(ty.isIntOrIndex() && "expect integer or index type");
889 return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true);
890 }
891
892 if constexpr (std::is_same_v<R, llvm::APFloat>) {
893 assert(mlir::isa<mlir::FloatType>(ty) && "expect float type");
894 auto floatTy = mlir::dyn_cast<mlir::FloatType>(Val&: ty);
895 return llvm::APFloat(floatTy.getFloatSemantics(), value);
896 }
897
898 if constexpr (std::is_same_v<R, int64_t>)
899 return value;
900 }
901 llvm_unreachable("OpenACC reduction unsupported type");
902}
903
904/// Return a constant with the initial value for the reduction operator and
905/// type combination.
906static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
907 mlir::Location loc, mlir::Type ty,
908 mlir::acc::ReductionOperator op) {
909 if (op == mlir::acc::ReductionOperator::AccLand ||
910 op == mlir::acc::ReductionOperator::AccLor ||
911 op == mlir::acc::ReductionOperator::AccEqv ||
912 op == mlir::acc::ReductionOperator::AccNeqv) {
913 assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type");
914 bool value = true; // .true. for .and. and .eqv.
915 if (op == mlir::acc::ReductionOperator::AccLor ||
916 op == mlir::acc::ReductionOperator::AccNeqv)
917 value = false; // .false. for .or. and .neqv.
918 return builder.createBool(loc, value);
919 }
920 if (ty.isIntOrIndex())
921 return builder.create<mlir::arith::ConstantOp>(
922 loc, ty,
923 builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
924 if (op == mlir::acc::ReductionOperator::AccMin ||
925 op == mlir::acc::ReductionOperator::AccMax) {
926 if (mlir::isa<mlir::ComplexType>(Val: ty))
927 llvm::report_fatal_error(
928 reason: "min/max reduction not supported for complex type");
929 if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
930 return builder.create<mlir::arith::ConstantOp>(
931 loc, ty,
932 builder.getFloatAttr(ty,
933 getReductionInitValue<llvm::APFloat>(op, ty)));
934 } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty)) {
935 return builder.create<mlir::arith::ConstantOp>(
936 loc, ty,
937 builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
938 } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(Val&: ty)) {
939 mlir::Type floatTy = cmplxTy.getElementType();
940 mlir::Value realInit = builder.createRealConstant(
941 loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
942 mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
943 return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit,
944 imagInit);
945 }
946
947 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
948 return getReductionInitValue(builder, loc, seqTy.getEleTy(), op);
949
950 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
951 return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);
952
953 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
954 return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);
955
956 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
957 return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);
958
959 llvm::report_fatal_error(reason: "Unsupported OpenACC reduction type");
960}
961
962template <typename RecipeOp>
963static RecipeOp genRecipeOp(
964 fir::FirOpBuilder &builder, mlir::ModuleOp mod, llvm::StringRef recipeName,
965 mlir::Location loc, mlir::Type ty,
966 mlir::acc::ReductionOperator op = mlir::acc::ReductionOperator::AccNone) {
967 mlir::OpBuilder modBuilder(mod.getBodyRegion());
968 RecipeOp recipe;
969 if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
970 recipe = modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName,
971 ty, op);
972 } else {
973 recipe = modBuilder.create<RecipeOp>(loc, recipeName, ty);
974 }
975
976 llvm::SmallVector<mlir::Type> argsTy{ty};
977 llvm::SmallVector<mlir::Location> argsLoc{loc};
978 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
979 if (auto seqTy =
980 mlir::dyn_cast_or_null<fir::SequenceType>(refTy.getEleTy())) {
981 if (seqTy.hasDynamicExtents()) {
982 mlir::Type idxTy = builder.getIndexType();
983 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
984 argsTy.push_back(Elt: idxTy);
985 argsLoc.push_back(Elt: loc);
986 }
987 }
988 }
989 }
990 auto initBlock = builder.createBlock(
991 &recipe.getInitRegion(), recipe.getInitRegion().end(), argsTy, argsLoc);
992 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
993 mlir::Value initValue;
994 if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
995 assert(op != mlir::acc::ReductionOperator::AccNone);
996 initValue = getReductionInitValue(builder, loc, fir::unwrapRefType(ty), op);
997 }
998
999 // Since we reuse the same recipe for all variables of the same type - we
1000 // cannot use the actual variable name. Thus use a temporary name.
1001 llvm::StringRef initName;
1002 if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>)
1003 initName = accReductionInitName;
1004 else
1005 initName = accPrivateInitName;
1006
1007 auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty);
1008 assert(mappableTy &&
1009 "Expected that all variable types are considered mappable");
1010 auto retVal = mappableTy.generatePrivateInit(
1011 builder, loc,
1012 mlir::cast<mlir::TypedValue<mlir::acc::MappableType>>(
1013 initBlock->getArgument(0)),
1014 initName,
1015 initBlock->getArguments().take_back(initBlock->getArguments().size() - 1),
1016 initValue);
1017 builder.create<mlir::acc::YieldOp>(loc, retVal ? retVal
1018 : initBlock->getArgument(0));
1019 return recipe;
1020}
1021
1022mlir::acc::PrivateRecipeOp
1023Fortran::lower::createOrGetPrivateRecipe(fir::FirOpBuilder &builder,
1024 llvm::StringRef recipeName,
1025 mlir::Location loc, mlir::Type ty) {
1026 mlir::ModuleOp mod =
1027 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
1028 if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName))
1029 return recipe;
1030
1031 auto ip = builder.saveInsertionPoint();
1032 auto recipe = genRecipeOp<mlir::acc::PrivateRecipeOp>(builder, mod,
1033 recipeName, loc, ty);
1034 builder.restoreInsertionPoint(ip);
1035 return recipe;
1036}
1037
1038/// Check if the DataBoundsOp is a constant bound (lb and ub are constants or
1039/// extent is a constant).
1040bool isConstantBound(mlir::acc::DataBoundsOp &op) {
1041 if (op.getLowerbound() && fir::getIntIfConstant(op.getLowerbound()) &&
1042 op.getUpperbound() && fir::getIntIfConstant(op.getUpperbound()))
1043 return true;
1044 if (op.getExtent() && fir::getIntIfConstant(op.getExtent()))
1045 return true;
1046 return false;
1047}
1048
1049/// Return true iff all the bounds are expressed with constant values.
1050bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
1051 for (auto bound : bounds) {
1052 auto dataBound =
1053 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1054 assert(dataBound && "Must be DataBoundOp operation");
1055 if (!isConstantBound(dataBound))
1056 return false;
1057 }
1058 return true;
1059}
1060
1061static llvm::SmallVector<mlir::Value>
1062genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
1063 mlir::acc::DataBoundsOp &dataBound) {
1064 mlir::Type idxTy = builder.getIndexType();
1065 mlir::Value lb, ub, step;
1066 if (dataBound.getLowerbound() &&
1067 fir::getIntIfConstant(dataBound.getLowerbound()) &&
1068 dataBound.getUpperbound() &&
1069 fir::getIntIfConstant(dataBound.getUpperbound())) {
1070 lb = builder.createIntegerConstant(
1071 loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
1072 ub = builder.createIntegerConstant(
1073 loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
1074 step = builder.createIntegerConstant(loc, idxTy, 1);
1075 } else if (dataBound.getExtent()) {
1076 lb = builder.createIntegerConstant(loc, idxTy, 0);
1077 ub = builder.createIntegerConstant(
1078 loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
1079 step = builder.createIntegerConstant(loc, idxTy, 1);
1080 } else {
1081 llvm::report_fatal_error(reason: "Expect constant lb/ub or extent");
1082 }
1083 return {lb, ub, step};
1084}
1085
1086static mlir::Value genShapeFromBoundsOrArgs(
1087 mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
1088 const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
1089 llvm::SmallVector<mlir::Value> args;
1090 if (bounds.empty() && seqTy) {
1091 if (seqTy.hasDynamicExtents()) {
1092 assert(!arguments.empty() && "arguments must hold the entity");
1093 auto entity = hlfir::Entity{arguments[0]};
1094 return hlfir::genShape(loc, builder, entity);
1095 }
1096 return genShapeOp(builder, seqTy, loc).getResult();
1097 } else if (areAllBoundConstant(bounds)) {
1098 for (auto bound : llvm::reverse(C: bounds)) {
1099 auto dataBound =
1100 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1101 args.append(genConstantBounds(builder, loc, dataBound));
1102 }
1103 } else {
1104 assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
1105 "Expect 3 block arguments per dimension");
1106 for (auto arg : arguments.drop_front(n: 2))
1107 args.push_back(Elt: arg);
1108 }
1109
1110 assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
1111 llvm::SmallVector<mlir::Value> extents;
1112 mlir::Type idxTy = builder.getIndexType();
1113 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
1114 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
1115 for (unsigned i = 0; i < args.size(); i += 3) {
1116 mlir::Value s1 =
1117 builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
1118 mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
1119 mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
1120 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
1121 loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
1122 mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
1123 extents.push_back(Elt: ext);
1124 }
1125 return builder.create<fir::ShapeOp>(loc, extents);
1126}
1127
1128static hlfir::DesignateOp::Subscripts
1129getSubscriptsFromArgs(mlir::ValueRange args) {
1130 hlfir::DesignateOp::Subscripts triplets;
1131 for (unsigned i = 2; i < args.size(); i += 3)
1132 triplets.emplace_back(
1133 hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
1134 return triplets;
1135}
1136
1137static hlfir::Entity genDesignateWithTriplets(
1138 fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
1139 hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
1140 llvm::SmallVector<mlir::Value> lenParams;
1141 hlfir::genLengthParameters(loc, builder, entity, lenParams);
1142 auto designate = builder.create<hlfir::DesignateOp>(
1143 loc, entity.getBase().getType(), entity, /*component=*/"",
1144 /*componentShape=*/mlir::Value{}, triplets,
1145 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
1146 lenParams);
1147 return hlfir::Entity{designate.getResult()};
1148}
1149
/// Return the acc.firstprivate.recipe named \p recipeName from the enclosing
/// module, creating it for type \p ty if it does not exist yet.
///
/// The init region is generated by genRecipeOp. The copy region takes two
/// arguments of type \p ty. When not all \p bounds are constant, it also
/// takes three arguments per bound (typed after the bound's lowerbound,
/// upperbound and startIdx), added in reverse bound order. The copy itself:
///  - trivial types: load argument 0 and store it into argument 1;
///  - sequence types: hlfir.declare both arguments with the computed shape,
///    designate them with the block-argument triplets, and hlfir.assign
///    between the two views (presumably the same source/destination order as
///    the scalar path, assuming hlfir::AssignOp's builder takes (rhs, lhs));
///  - boxed sequence types: the same designate/assign on the box arguments;
///    other boxed types hit a TODO.
/// The builder's insertion point is restored before returning.
mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
    fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
    mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
  mlir::ModuleOp mod =
      builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
  if (auto recipe =
          mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName))
    return recipe;

  auto ip = builder.saveInsertionPoint();
  auto recipe = genRecipeOp<mlir::acc::FirstprivateRecipeOp>(
      builder, mod, recipeName, loc, ty);
  bool allConstantBound = areAllBoundConstant(bounds);
  // Copy region arguments: (source, destination) and, for non-constant
  // bounds, one triplet of extra arguments per bound.
  llvm::SmallVector<mlir::Type> argsTy{ty, ty};
  llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
  if (!allConstantBound) {
    for (mlir::Value bound : llvm::reverse(bounds)) {
      auto dataBound =
          mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
      argsTy.push_back(dataBound.getLowerbound().getType());
      argsLoc.push_back(dataBound.getLowerbound().getLoc());
      argsTy.push_back(dataBound.getUpperbound().getType());
      argsLoc.push_back(dataBound.getUpperbound().getLoc());
      argsTy.push_back(dataBound.getStartIdx().getType());
      argsLoc.push_back(dataBound.getStartIdx().getLoc());
    }
  }
  builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(),
                      argsTy, argsLoc);

  builder.setInsertionPointToEnd(&recipe.getCopyRegion().back());
  ty = fir::unwrapRefType(ty);
  if (fir::isa_trivial(ty)) {
    // Scalar: load from argument 0, store into argument 1.
    mlir::Value initValue = builder.create<fir::LoadOp>(
        loc, recipe.getCopyRegion().front().getArgument(0));
    builder.create<fir::StoreOp>(loc, initValue,
                                 recipe.getCopyRegion().front().getArgument(1));
  } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
    fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
    auto shape = genShapeFromBoundsOrArgs(
        loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());

    auto leftDeclOp = builder.create<hlfir::DeclareOp>(
        loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, shape,
        llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
        fir::FortranVariableFlagsAttr{});
    auto rightDeclOp = builder.create<hlfir::DeclareOp>(
        loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, shape,
        llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
        fir::FortranVariableFlagsAttr{});

    // With constant bounds there are no triplet arguments, so the triplet
    // list is empty here.
    hlfir::DesignateOp::Subscripts triplets =
        getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
    auto leftEntity = hlfir::Entity{leftDeclOp.getBase()};
    auto left =
        genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
    auto rightEntity = hlfir::Entity{rightDeclOp.getBase()};
    auto right =
        genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);

    firBuilder.create<hlfir::AssignOp>(loc, left, right);

  } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
    fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
    llvm::SmallVector<mlir::Value> tripletArgs;
    mlir::Type innerTy = fir::extractSequenceType(boxTy);
    fir::SequenceType seqTy =
        mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
    if (!seqTy)
      TODO(loc, "Unsupported boxed type in OpenACC firstprivate");

    // Boxes are already Fortran entities, so no hlfir.declare is needed.
    auto shape = genShapeFromBoundsOrArgs(
        loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
    hlfir::DesignateOp::Subscripts triplets =
        getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
    auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
    auto left =
        genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
    auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
    auto right =
        genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
    firBuilder.create<hlfir::AssignOp>(loc, left, right);
  }

  builder.create<mlir::acc::TerminatorOp>(loc);
  builder.restoreInsertionPoint(ip);
  return recipe;
}
1238
1239/// Get a string representation of the bounds.
1240std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) {
1241 std::stringstream boundStr;
1242 if (!bounds.empty())
1243 boundStr << "_section_";
1244 llvm::interleave(
1245 c: bounds,
1246 each_fn: [&](mlir::Value bound) {
1247 auto boundsOp =
1248 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1249 if (boundsOp.getLowerbound() &&
1250 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
1251 boundsOp.getUpperbound() &&
1252 fir::getIntIfConstant(boundsOp.getUpperbound())) {
1253 boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound())
1254 << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound());
1255 } else if (boundsOp.getExtent() &&
1256 fir::getIntIfConstant(boundsOp.getExtent())) {
1257 boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent());
1258 } else {
1259 boundStr << "?";
1260 }
1261 },
1262 between_fn: [&] { boundStr << "x"; });
1263 return boundStr.str();
1264}
1265
1266/// Rebuild the array type from the acc.bounds operation with constant
1267/// lowerbound/upperbound or extent.
1268mlir::Type getTypeFromBounds(llvm::SmallVector<mlir::Value> &bounds,
1269 mlir::Type ty) {
1270 auto seqTy =
1271 mlir::dyn_cast_or_null<fir::SequenceType>(fir::unwrapRefType(ty));
1272 if (!bounds.empty() && seqTy) {
1273 llvm::SmallVector<int64_t> shape;
1274 for (auto b : bounds) {
1275 auto boundsOp =
1276 mlir::dyn_cast<mlir::acc::DataBoundsOp>(b.getDefiningOp());
1277 if (boundsOp.getLowerbound() &&
1278 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
1279 boundsOp.getUpperbound() &&
1280 fir::getIntIfConstant(boundsOp.getUpperbound())) {
1281 int64_t ext = *fir::getIntIfConstant(boundsOp.getUpperbound()) -
1282 *fir::getIntIfConstant(boundsOp.getLowerbound()) + 1;
1283 shape.push_back(Elt: ext);
1284 } else if (boundsOp.getExtent() &&
1285 fir::getIntIfConstant(boundsOp.getExtent())) {
1286 shape.push_back(*fir::getIntIfConstant(boundsOp.getExtent()));
1287 } else {
1288 return ty; // TODO: handle dynamic shaped array slice.
1289 }
1290 }
1291 if (shape.empty() || shape.size() != bounds.size())
1292 return ty;
1293 auto newSeqTy = fir::SequenceType::get(shape, seqTy.getEleTy());
1294 if (mlir::isa<fir::ReferenceType, fir::PointerType>(ty))
1295 return fir::ReferenceType::get(newSeqTy);
1296 return newSeqTy;
1297 }
1298 return ty;
1299}
1300
/// Lower the objects of a `private` or `firstprivate` clause.
///
/// For each object this:
///  - gathers its address and acc.bounds;
///  - computes the recipe type with getTypeFromBounds, which narrows the type
///    to the section shape when all bounds are constant;
///  - for PrivateRecipeOp: gets or creates a private recipe named from that
///    type (with privatizationRecipePrefix) and emits an acc.private op;
///  - otherwise (firstprivate): gets or creates a firstprivate recipe whose
///    name is "firstprivatization" plus, when all bounds are constant, the
///    bounds string; then emits an acc.firstprivate op;
///  - appends the entry op's acc variable to \p dataOperands and the recipe's
///    symbol reference to \p privatizationRecipes.
/// The entry ops are structured and explicit and carry the given async
/// operands.
template <typename RecipeOp>
static void genPrivatizationRecipes(
    const Fortran::parser::AccObjectList &objectList,
    Fortran::lower::AbstractConverter &converter,
    Fortran::semantics::SemanticsContext &semanticsContext,
    Fortran::lower::StatementContext &stmtCtx,
    llvm::SmallVectorImpl<mlir::Value> &dataOperands,
    llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
    llvm::ArrayRef<mlir::Value> async,
    llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
    llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
  for (const auto &accObject : objectList.v) {
    llvm::SmallVector<mlir::Value> bounds;
    std::stringstream asFortran;
    mlir::Location operandLocation = genOperandLocation(converter, accObject);
    Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
    Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
        [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
    fir::factory::AddrAndBoundsInfo info =
        Fortran::lower::gatherDataOperandAddrAndBounds<
            mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
            converter, builder, semanticsContext, stmtCtx, symbol, designator,
            operandLocation, asFortran, bounds,
            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
            /*genDefaultBounds=*/generateDefaultBounds,
            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
    LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));

    RecipeOp recipe;
    // Recipe type: the section type when all bounds are constant, otherwise
    // the address type.
    mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
    if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
      std::string recipeName =
          fir::getTypeAsString(retTy, converter.getKindMap(),
                               Fortran::lower::privatizationRecipePrefix);
      recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
                                                        operandLocation, retTy);
      auto op = createDataEntryOp<mlir::acc::PrivateOp>(
          builder, operandLocation, info.addr, asFortran, bounds,
          /*structured=*/true,
          /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
          asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
      dataOperands.push_back(op.getAccVar());
    } else {
      // Constant sections get their own recipe, distinguished by a
      // bounds-derived name suffix.
      std::string suffix =
          areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
      std::string recipeName = fir::getTypeAsString(
          retTy, converter.getKindMap(), "firstprivatization" + suffix);
      recipe = Fortran::lower::createOrGetFirstprivateRecipe(
          builder, recipeName, operandLocation, retTy, bounds);
      auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
          builder, operandLocation, info.addr, asFortran, bounds,
          /*structured=*/true,
          /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy,
          async, asyncDeviceTypes, asyncOnlyDeviceTypes,
          /*unwrapBoxAddr=*/true);
      dataOperands.push_back(op.getAccVar());
    }
    privatizationRecipes.push_back(mlir::SymbolRefAttr::get(
        builder.getContext(), recipe.getSymName().str()));
  }
}
1362
1363/// Return the corresponding enum value for the mlir::acc::ReductionOperator
1364/// from the parser representation.
1365static mlir::acc::ReductionOperator
1366getReductionOperator(const Fortran::parser::ReductionOperator &op) {
1367 switch (op.v) {
1368 case Fortran::parser::ReductionOperator::Operator::Plus:
1369 return mlir::acc::ReductionOperator::AccAdd;
1370 case Fortran::parser::ReductionOperator::Operator::Multiply:
1371 return mlir::acc::ReductionOperator::AccMul;
1372 case Fortran::parser::ReductionOperator::Operator::Max:
1373 return mlir::acc::ReductionOperator::AccMax;
1374 case Fortran::parser::ReductionOperator::Operator::Min:
1375 return mlir::acc::ReductionOperator::AccMin;
1376 case Fortran::parser::ReductionOperator::Operator::Iand:
1377 return mlir::acc::ReductionOperator::AccIand;
1378 case Fortran::parser::ReductionOperator::Operator::Ior:
1379 return mlir::acc::ReductionOperator::AccIor;
1380 case Fortran::parser::ReductionOperator::Operator::Ieor:
1381 return mlir::acc::ReductionOperator::AccXor;
1382 case Fortran::parser::ReductionOperator::Operator::And:
1383 return mlir::acc::ReductionOperator::AccLand;
1384 case Fortran::parser::ReductionOperator::Operator::Or:
1385 return mlir::acc::ReductionOperator::AccLor;
1386 case Fortran::parser::ReductionOperator::Operator::Eqv:
1387 return mlir::acc::ReductionOperator::AccEqv;
1388 case Fortran::parser::ReductionOperator::Operator::Neqv:
1389 return mlir::acc::ReductionOperator::AccNeqv;
1390 }
1391 llvm_unreachable("unexpected reduction operator");
1392}
1393
1394template <typename Op>
1395static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
1396 mlir::Location loc, mlir::Value value1,
1397 mlir::Value value2) {
1398 mlir::Type i1 = builder.getI1Type();
1399 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1400 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1401 mlir::Value combined = builder.create<Op>(loc, v1, v2);
1402 return builder.create<fir::ConvertOp>(loc, value1.getType(), combined);
1403}
1404
1405static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
1406 mlir::Location loc,
1407 mlir::arith::CmpIPredicate pred,
1408 mlir::Value value1,
1409 mlir::Value value2) {
1410 mlir::Type i1 = builder.getI1Type();
1411 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1412 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1413 mlir::Value add = builder.create<mlir::arith::CmpIOp>(loc, pred, v1, v2);
1414 return builder.create<fir::ConvertOp>(loc, value1.getType(), add);
1415}
1416
1417static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
1418 mlir::Location loc,
1419 mlir::acc::ReductionOperator op,
1420 mlir::Type ty, mlir::Value value1,
1421 mlir::Value value2) {
1422 value1 = builder.loadIfRef(loc, value1);
1423 value2 = builder.loadIfRef(loc, value2);
1424 if (op == mlir::acc::ReductionOperator::AccAdd) {
1425 if (ty.isIntOrIndex())
1426 return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
1427 if (mlir::isa<mlir::FloatType>(ty))
1428 return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
1429 if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty))
1430 return builder.create<fir::AddcOp>(loc, value1, value2);
1431 TODO(loc, "reduction add type");
1432 }
1433
1434 if (op == mlir::acc::ReductionOperator::AccMul) {
1435 if (ty.isIntOrIndex())
1436 return builder.create<mlir::arith::MulIOp>(loc, value1, value2);
1437 if (mlir::isa<mlir::FloatType>(ty))
1438 return builder.create<mlir::arith::MulFOp>(loc, value1, value2);
1439 if (mlir::isa<mlir::ComplexType>(ty))
1440 return builder.create<fir::MulcOp>(loc, value1, value2);
1441 TODO(loc, "reduction mul type");
1442 }
1443
1444 if (op == mlir::acc::ReductionOperator::AccMin)
1445 return fir::genMin(builder, loc, {value1, value2});
1446
1447 if (op == mlir::acc::ReductionOperator::AccMax)
1448 return fir::genMax(builder, loc, {value1, value2});
1449
1450 if (op == mlir::acc::ReductionOperator::AccIand)
1451 return builder.create<mlir::arith::AndIOp>(loc, value1, value2);
1452
1453 if (op == mlir::acc::ReductionOperator::AccIor)
1454 return builder.create<mlir::arith::OrIOp>(loc, value1, value2);
1455
1456 if (op == mlir::acc::ReductionOperator::AccXor)
1457 return builder.create<mlir::arith::XOrIOp>(loc, value1, value2);
1458
1459 if (op == mlir::acc::ReductionOperator::AccLand)
1460 return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1,
1461 value2);
1462
1463 if (op == mlir::acc::ReductionOperator::AccLor)
1464 return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2);
1465
1466 if (op == mlir::acc::ReductionOperator::AccEqv)
1467 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq,
1468 value1, value2);
1469
1470 if (op == mlir::acc::ReductionOperator::AccNeqv)
1471 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne,
1472 value1, value2);
1473
1474 TODO(loc, "reduction operator");
1475}
1476
1477static hlfir::DesignateOp::Subscripts
1478getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) {
1479 hlfir::DesignateOp::Subscripts triplets;
1480 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1481 i += 3)
1482 triplets.emplace_back(hlfir::DesignateOp::Triplet{
1483 recipe.getCombinerRegion().getArgument(i),
1484 recipe.getCombinerRegion().getArgument(i + 1),
1485 recipe.getCombinerRegion().getArgument(i + 2)});
1486 return triplets;
1487}
1488
1489static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
1490 mlir::acc::ReductionOperator op, mlir::Type ty,
1491 mlir::Value value1, mlir::Value value2,
1492 mlir::acc::ReductionRecipeOp &recipe,
1493 llvm::SmallVector<mlir::Value> &bounds,
1494 bool allConstantBound) {
1495 ty = fir::unwrapRefType(ty);
1496
1497 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1498 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
1499 llvm::SmallVector<fir::DoLoopOp> loops;
1500 llvm::SmallVector<mlir::Value> ivs;
1501 if (seqTy.hasDynamicExtents()) {
1502 auto shape =
1503 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1504 recipe.getCombinerRegion().getArguments());
1505 auto v1DeclareOp = builder.create<hlfir::DeclareOp>(
1506 loc, value1, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1507 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1508 auto v2DeclareOp = builder.create<hlfir::DeclareOp>(
1509 loc, value2, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1510 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1511 hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe);
1512
1513 llvm::SmallVector<mlir::Value> lenParamsLeft;
1514 auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()};
1515 hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
1516 auto leftDesignate = builder.create<hlfir::DesignateOp>(
1517 loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(),
1518 /*component=*/"",
1519 /*componentShape=*/mlir::Value{}, triplets,
1520 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1521 shape, lenParamsLeft);
1522 auto left = hlfir::Entity{leftDesignate.getResult()};
1523
1524 llvm::SmallVector<mlir::Value> lenParamsRight;
1525 auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()};
1526 hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft);
1527 auto rightDesignate = builder.create<hlfir::DesignateOp>(
1528 loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(),
1529 /*component=*/"",
1530 /*componentShape=*/mlir::Value{}, triplets,
1531 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1532 shape, lenParamsRight);
1533 auto right = hlfir::Entity{rightDesignate.getResult()};
1534
1535 llvm::SmallVector<mlir::Value, 1> typeParams;
1536 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1537 mlir::Location l, fir::FirOpBuilder &b,
1538 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1539 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1540 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1541 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1542 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1543 return hlfir::Entity{genScalarCombiner(
1544 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1545 };
1546 mlir::Value elemental = hlfir::genElementalOp(
1547 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1548 /*isUnordered=*/true);
1549 builder.create<hlfir::AssignOp>(loc, elemental, v1DeclareOp.getBase());
1550 return;
1551 }
1552 if (bounds.empty()) {
1553 llvm::SmallVector<mlir::Value> extents;
1554 mlir::Type idxTy = builder.getIndexType();
1555 for (auto extent : seqTy.getShape()) {
1556 mlir::Value lb = builder.create<mlir::arith::ConstantOp>(
1557 loc, idxTy, builder.getIntegerAttr(idxTy, 0));
1558 mlir::Value ub = builder.create<mlir::arith::ConstantOp>(
1559 loc, idxTy, builder.getIntegerAttr(idxTy, extent - 1));
1560 mlir::Value step = builder.create<mlir::arith::ConstantOp>(
1561 loc, idxTy, builder.getIntegerAttr(idxTy, 1));
1562 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1563 /*unordered=*/false);
1564 builder.setInsertionPointToStart(loop.getBody());
1565 loops.push_back(loop);
1566 ivs.push_back(loop.getInductionVar());
1567 }
1568 } else if (allConstantBound) {
1569 // Use the constant bound directly in the combiner region so they do not
1570 // need to be passed as block argument.
1571 assert(!bounds.empty() &&
1572 "seq type with constant bounds cannot have empty bounds");
1573 for (auto bound : llvm::reverse(C&: bounds)) {
1574 auto dataBound =
1575 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1576 llvm::SmallVector<mlir::Value> values =
1577 genConstantBounds(builder, loc, dataBound);
1578 auto loop =
1579 builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2],
1580 /*unordered=*/false);
1581 builder.setInsertionPointToStart(loop.getBody());
1582 loops.push_back(loop);
1583 ivs.push_back(Elt: loop.getInductionVar());
1584 }
1585 } else {
1586 // Lowerbound, upperbound and step are passed as block arguments.
1587 [[maybe_unused]] unsigned nbRangeArgs =
1588 recipe.getCombinerRegion().getArguments().size() - 2;
1589 assert((nbRangeArgs / 3 == seqTy.getDimension()) &&
1590 "Expect 3 block arguments per dimension");
1591 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1592 i += 3) {
1593 mlir::Value lb = recipe.getCombinerRegion().getArgument(i);
1594 mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1);
1595 mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2);
1596 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1597 /*unordered=*/false);
1598 builder.setInsertionPointToStart(loop.getBody());
1599 loops.push_back(loop);
1600 ivs.push_back(Elt: loop.getInductionVar());
1601 }
1602 }
1603 auto addr1 = builder.create<fir::CoordinateOp>(loc, refTy, value1, ivs);
1604 auto addr2 = builder.create<fir::CoordinateOp>(loc, refTy, value2, ivs);
1605 auto load1 = builder.create<fir::LoadOp>(loc, addr1);
1606 auto load2 = builder.create<fir::LoadOp>(loc, addr2);
1607 mlir::Value res =
1608 genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2);
1609 builder.create<fir::StoreOp>(loc, res, addr1);
1610 builder.setInsertionPointAfter(loops[0]);
1611 } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
1612 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
1613 if (fir::isa_trivial(innerTy)) {
1614 mlir::Value boxAddr1 = value1, boxAddr2 = value2;
1615 if (fir::isBoxAddress(boxAddr1.getType()))
1616 boxAddr1 = builder.create<fir::LoadOp>(loc, boxAddr1);
1617 if (fir::isBoxAddress(boxAddr2.getType()))
1618 boxAddr2 = builder.create<fir::LoadOp>(loc, boxAddr2);
1619 boxAddr1 = builder.create<fir::BoxAddrOp>(loc, boxAddr1);
1620 boxAddr2 = builder.create<fir::BoxAddrOp>(loc, boxAddr2);
1621 auto leftEntity = hlfir::Entity{boxAddr1};
1622 auto rightEntity = hlfir::Entity{boxAddr2};
1623
1624 auto leftVal = hlfir::loadTrivialScalar(loc, builder, leftEntity);
1625 auto rightVal = hlfir::loadTrivialScalar(loc, builder, rightEntity);
1626 mlir::Value res =
1627 genScalarCombiner(builder, loc, op, innerTy, leftVal, rightVal);
1628 builder.create<hlfir::AssignOp>(loc, res, boxAddr1);
1629 } else {
1630 mlir::Type innerTy = fir::extractSequenceType(boxTy);
1631 fir::SequenceType seqTy =
1632 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
1633 if (!seqTy)
1634 TODO(loc, "Unsupported boxed type in OpenACC reduction combiner");
1635
1636 auto shape =
1637 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1638 recipe.getCombinerRegion().getArguments());
1639 hlfir::DesignateOp::Subscripts triplets =
1640 getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
1641 auto leftEntity = hlfir::Entity{value1};
1642 if (fir::isBoxAddress(value1.getType()))
1643 leftEntity =
1644 hlfir::Entity{builder.create<fir::LoadOp>(loc, value1).getResult()};
1645 auto left =
1646 genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
1647 auto rightEntity = hlfir::Entity{value2};
1648 if (fir::isBoxAddress(value2.getType()))
1649 rightEntity =
1650 hlfir::Entity{builder.create<fir::LoadOp>(loc, value2).getResult()};
1651 auto right =
1652 genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
1653
1654 llvm::SmallVector<mlir::Value, 1> typeParams;
1655 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1656 mlir::Location l, fir::FirOpBuilder &b,
1657 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1658 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1659 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1660 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1661 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1662 return hlfir::Entity{genScalarCombiner(
1663 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1664 };
1665 mlir::Value elemental = hlfir::genElementalOp(
1666 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1667 /*isUnordered=*/true);
1668 builder.create<hlfir::AssignOp>(loc, elemental, value1);
1669 }
1670 } else {
1671 mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2);
1672 builder.create<fir::StoreOp>(loc, res, value1);
1673 }
1674}
1675
1676mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
1677 fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
1678 mlir::Type ty, mlir::acc::ReductionOperator op,
1679 llvm::SmallVector<mlir::Value> &bounds) {
1680 mlir::ModuleOp mod =
1681 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
1682 if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
1683 return recipe;
1684
1685 auto ip = builder.saveInsertionPoint();
1686
1687 auto recipe = genRecipeOp<mlir::acc::ReductionRecipeOp>(
1688 builder, mod, recipeName, loc, ty, op);
1689
1690 // The two first block arguments are the two values to be combined.
1691 // The next arguments are the iteration ranges (lb, ub, step) to be used
1692 // for the combiner if needed.
1693 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
1694 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
1695 bool allConstantBound = areAllBoundConstant(bounds);
1696 if (!allConstantBound) {
1697 for (mlir::Value bound : llvm::reverse(bounds)) {
1698 auto dataBound =
1699 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1700 argsTy.push_back(dataBound.getLowerbound().getType());
1701 argsLoc.push_back(dataBound.getLowerbound().getLoc());
1702 argsTy.push_back(dataBound.getUpperbound().getType());
1703 argsLoc.push_back(dataBound.getUpperbound().getLoc());
1704 argsTy.push_back(dataBound.getStartIdx().getType());
1705 argsLoc.push_back(dataBound.getStartIdx().getLoc());
1706 }
1707 }
1708 builder.createBlock(&recipe.getCombinerRegion(),
1709 recipe.getCombinerRegion().end(), argsTy, argsLoc);
1710 builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
1711 mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
1712 mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
1713 genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound);
1714 builder.create<mlir::acc::YieldOp>(loc, v1);
1715 builder.restoreInsertionPoint(ip);
1716 return recipe;
1717}
1718
1719static bool isSupportedReductionType(mlir::Type ty) {
1720 ty = fir::unwrapRefType(ty);
1721 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
1722 return isSupportedReductionType(boxTy.getEleTy());
1723 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
1724 return isSupportedReductionType(seqTy.getEleTy());
1725 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
1726 return isSupportedReductionType(heapTy.getEleTy());
1727 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
1728 return isSupportedReductionType(ptrTy.getEleTy());
1729 return fir::isa_trivial(ty);
1730}
1731
1732static void
1733genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
1734 Fortran::lower::AbstractConverter &converter,
1735 Fortran::semantics::SemanticsContext &semanticsContext,
1736 Fortran::lower::StatementContext &stmtCtx,
1737 llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
1738 llvm::SmallVector<mlir::Attribute> &reductionRecipes,
1739 llvm::ArrayRef<mlir::Value> async,
1740 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
1741 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
1742 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1743 const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
1744 const auto &op = std::get<Fortran::parser::ReductionOperator>(objectList.t);
1745 mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
1746 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
1747 for (const auto &accObject : objects.v) {
1748 llvm::SmallVector<mlir::Value> bounds;
1749 std::stringstream asFortran;
1750 mlir::Location operandLocation = genOperandLocation(converter, accObject);
1751 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
1752 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
1753 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
1754 fir::factory::AddrAndBoundsInfo info =
1755 Fortran::lower::gatherDataOperandAddrAndBounds<
1756 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
1757 converter, builder, semanticsContext, stmtCtx, symbol, designator,
1758 operandLocation, asFortran, bounds,
1759 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
1760 /*genDefaultBounds=*/generateDefaultBounds,
1761 /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
1762 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
1763
1764 mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
1765 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
1766 reductionTy = seqTy.getEleTy();
1767
1768 if (!isSupportedReductionType(reductionTy))
1769 TODO(operandLocation, "reduction with unsupported type");
1770
1771 auto op = createDataEntryOp<mlir::acc::ReductionOp>(
1772 builder, operandLocation, info.addr, asFortran, bounds,
1773 /*structured=*/true, /*implicit=*/false,
1774 mlir::acc::DataClause::acc_reduction, info.addr.getType(), async,
1775 asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
1776 mlir::Type ty = op.getAccVar().getType();
1777 if (!areAllBoundConstant(bounds) ||
1778 fir::isAssumedShape(info.addr.getType()) ||
1779 fir::isAllocatableOrPointerArray(info.addr.getType()))
1780 ty = info.addr.getType();
1781 std::string suffix =
1782 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
1783 std::string recipeName = fir::getTypeAsString(
1784 ty, converter.getKindMap(),
1785 ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
1786
1787 mlir::acc::ReductionRecipeOp recipe =
1788 Fortran::lower::createOrGetReductionRecipe(
1789 builder, recipeName, operandLocation, ty, mlirOp, bounds);
1790 reductionRecipes.push_back(mlir::SymbolRefAttr::get(
1791 builder.getContext(), recipe.getSymName().str()));
1792 reductionOperands.push_back(op.getAccVar());
1793 }
1794}
1795
1796template <typename Op, typename Terminator>
1797static Op
1798createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1799 mlir::Location returnLoc, Fortran::lower::pft::Evaluation &eval,
1800 const llvm::SmallVectorImpl<mlir::Value> &operands,
1801 const llvm::SmallVectorImpl<int32_t> &operandSegments,
1802 bool outerCombined = false,
1803 llvm::SmallVector<mlir::Type> retTy = {},
1804 mlir::Value yieldValue = {}, mlir::TypeRange argsTy = {},
1805 llvm::SmallVector<mlir::Location> locs = {}) {
1806 Op op = builder.create<Op>(loc, retTy, operands);
1807 builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
1808 mlir::Block &block = op.getRegion().back();
1809 builder.setInsertionPointToStart(&block);
1810
1811 op->setAttr(Op::getOperandSegmentSizeAttr(),
1812 builder.getDenseI32ArrayAttr(operandSegments));
1813
1814 // Place the insertion point to the start of the first block.
1815 builder.setInsertionPointToStart(&block);
1816
1817 // If it is an unstructured region and is not the outer region of a combined
1818 // construct, create empty blocks for all evaluations.
1819 if (eval.lowerAsUnstructured() && !outerCombined)
1820 Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
1821 mlir::acc::YieldOp>(
1822 builder, eval.getNestedEvaluations());
1823
1824 if (yieldValue) {
1825 if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
1826 Terminator yieldOp = builder.create<Terminator>(returnLoc, yieldValue);
1827 yieldValue.getDefiningOp()->moveBefore(yieldOp);
1828 } else {
1829 builder.create<Terminator>(returnLoc);
1830 }
1831 } else {
1832 builder.create<Terminator>(returnLoc);
1833 }
1834 builder.setInsertionPointToStart(&block);
1835 return op;
1836}
1837
1838static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1839 const Fortran::parser::AccClause::Async *asyncClause,
1840 mlir::Value &async, bool &addAsyncAttr,
1841 Fortran::lower::StatementContext &stmtCtx) {
1842 const auto &asyncClauseValue = asyncClause->v;
1843 if (asyncClauseValue) { // async has a value.
1844 async = fir::getBase(converter.genExprValue(
1845 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1846 } else {
1847 addAsyncAttr = true;
1848 }
1849}
1850
1851static void
1852genAsyncClause(Fortran::lower::AbstractConverter &converter,
1853 const Fortran::parser::AccClause::Async *asyncClause,
1854 llvm::SmallVector<mlir::Value> &async,
1855 llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1856 llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1857 llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
1858 Fortran::lower::StatementContext &stmtCtx) {
1859 const auto &asyncClauseValue = asyncClause->v;
1860 if (asyncClauseValue) { // async has a value.
1861 mlir::Value asyncValue = fir::getBase(converter.genExprValue(
1862 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1863 for (auto deviceTypeAttr : deviceTypeAttrs) {
1864 async.push_back(Elt: asyncValue);
1865 asyncDeviceTypes.push_back(Elt: deviceTypeAttr);
1866 }
1867 } else {
1868 for (auto deviceTypeAttr : deviceTypeAttrs)
1869 asyncOnlyDeviceTypes.push_back(Elt: deviceTypeAttr);
1870 }
1871}
1872
1873static mlir::acc::DeviceType
1874getDeviceType(Fortran::common::OpenACCDeviceType device) {
1875 switch (device) {
1876 case Fortran::common::OpenACCDeviceType::Star:
1877 return mlir::acc::DeviceType::Star;
1878 case Fortran::common::OpenACCDeviceType::Default:
1879 return mlir::acc::DeviceType::Default;
1880 case Fortran::common::OpenACCDeviceType::Nvidia:
1881 return mlir::acc::DeviceType::Nvidia;
1882 case Fortran::common::OpenACCDeviceType::Radeon:
1883 return mlir::acc::DeviceType::Radeon;
1884 case Fortran::common::OpenACCDeviceType::Host:
1885 return mlir::acc::DeviceType::Host;
1886 case Fortran::common::OpenACCDeviceType::Multicore:
1887 return mlir::acc::DeviceType::Multicore;
1888 case Fortran::common::OpenACCDeviceType::None:
1889 return mlir::acc::DeviceType::None;
1890 }
1891 return mlir::acc::DeviceType::None;
1892}
1893
1894static void gatherDeviceTypeAttrs(
1895 fir::FirOpBuilder &builder,
1896 const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
1897 llvm::SmallVector<mlir::Attribute> &deviceTypes) {
1898 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1899 deviceTypeClause->v;
1900 for (const auto &deviceTypeExpr : deviceTypeExprList.v)
1901 deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
1902 builder.getContext(), getDeviceType(deviceTypeExpr.v)));
1903}
1904
1905static void genIfClause(Fortran::lower::AbstractConverter &converter,
1906 mlir::Location clauseLocation,
1907 const Fortran::parser::AccClause::If *ifClause,
1908 mlir::Value &ifCond,
1909 Fortran::lower::StatementContext &stmtCtx) {
1910 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1911 mlir::Value cond = fir::getBase(converter.genExprValue(
1912 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx, &clauseLocation));
1913 ifCond = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
1914 cond);
1915}
1916
1917static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1918 const Fortran::parser::AccClause::Wait *waitClause,
1919 llvm::SmallVectorImpl<mlir::Value> &operands,
1920 mlir::Value &waitDevnum, bool &addWaitAttr,
1921 Fortran::lower::StatementContext &stmtCtx) {
1922 const auto &waitClauseValue = waitClause->v;
1923 if (waitClauseValue) { // wait has a value.
1924 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1925 const auto &waitList =
1926 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1927 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1928 mlir::Value v = fir::getBase(
1929 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
1930 operands.push_back(v);
1931 }
1932
1933 const auto &waitDevnumValue =
1934 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1935 if (waitDevnumValue)
1936 waitDevnum = fir::getBase(converter.genExprValue(
1937 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1938 } else {
1939 addWaitAttr = true;
1940 }
1941}
1942
1943static void genWaitClauseWithDeviceType(
1944 Fortran::lower::AbstractConverter &converter,
1945 const Fortran::parser::AccClause::Wait *waitClause,
1946 llvm::SmallVector<mlir::Value> &waitOperands,
1947 llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1948 llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1949 llvm::SmallVector<bool> &hasDevnums,
1950 llvm::SmallVector<int32_t> &waitOperandsSegments,
1951 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1952 Fortran::lower::StatementContext &stmtCtx) {
1953 const auto &waitClauseValue = waitClause->v;
1954 if (waitClauseValue) { // wait has a value.
1955 llvm::SmallVector<mlir::Value> waitValues;
1956
1957 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1958 const auto &waitDevnumValue =
1959 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1960 bool hasDevnum = false;
1961 if (waitDevnumValue) {
1962 waitValues.push_back(fir::getBase(converter.genExprValue(
1963 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
1964 hasDevnum = true;
1965 }
1966
1967 const auto &waitList =
1968 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1969 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1970 waitValues.push_back(fir::getBase(converter.genExprValue(
1971 *Fortran::semantics::GetExpr(value), stmtCtx)));
1972 }
1973
1974 for (auto deviceTypeAttr : deviceTypeAttrs) {
1975 for (auto value : waitValues)
1976 waitOperands.push_back(Elt: value);
1977 waitOperandsDeviceTypes.push_back(Elt: deviceTypeAttr);
1978 waitOperandsSegments.push_back(Elt: waitValues.size());
1979 hasDevnums.push_back(Elt: hasDevnum);
1980 }
1981 } else {
1982 for (auto deviceTypeAttr : deviceTypeAttrs)
1983 waitOnlyDeviceTypes.push_back(Elt: deviceTypeAttr);
1984 }
1985}
1986
1987mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
1988 const Fortran::semantics::Symbol &ivSym) {
1989 std::size_t ivTypeSize = ivSym.size();
1990 if (ivTypeSize == 0)
1991 llvm::report_fatal_error(reason: "unexpected induction variable size");
1992 // ivTypeSize is in bytes and IntegerType needs to be in bits.
1993 return builder.getIntegerType(ivTypeSize * 8);
1994}
1995
1996static void
1997privatizeIv(Fortran::lower::AbstractConverter &converter,
1998 const Fortran::semantics::Symbol &sym, mlir::Location loc,
1999 llvm::SmallVector<mlir::Type> &ivTypes,
2000 llvm::SmallVector<mlir::Location> &ivLocs,
2001 llvm::SmallVector<mlir::Value> &privateOperands,
2002 llvm::SmallVector<mlir::Value> &ivPrivate,
2003 llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
2004 bool isDoConcurrent = false) {
2005 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2006
2007 mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym);
2008 ivTypes.push_back(Elt: ivTy);
2009 ivLocs.push_back(Elt: loc);
2010 mlir::Value ivValue = converter.getSymbolAddress(sym);
2011 if (!ivValue && isDoConcurrent) {
2012 // DO CONCURRENT induction variables are not mapped yet since they are local
2013 // to the DO CONCURRENT scope.
2014 mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
2015 builder.setInsertionPointToStart(builder.getAllocaBlock());
2016 ivValue = builder.createTemporaryAlloc(loc, ivTy, toStringRef(sym.name()));
2017 builder.restoreInsertionPoint(insPt);
2018 }
2019
2020 mlir::Operation *privateOp = nullptr;
2021 for (auto privateVal : privateOperands) {
2022 if (mlir::acc::getVar(privateVal.getDefiningOp()) == ivValue) {
2023 privateOp = privateVal.getDefiningOp();
2024 break;
2025 }
2026 }
2027
2028 if (privateOp == nullptr) {
2029 std::string recipeName =
2030 fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
2031 Fortran::lower::privatizationRecipePrefix);
2032 auto recipe = Fortran::lower::createOrGetPrivateRecipe(
2033 builder, recipeName, loc, ivValue.getType());
2034
2035 std::stringstream asFortran;
2036 asFortran << Fortran::lower::mangle::demangleName(toStringRef(sym.name()));
2037 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
2038 builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true,
2039 mlir::acc::DataClause::acc_private, ivValue.getType(),
2040 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2041 privateOp = op.getOperation();
2042
2043 privateOperands.push_back(Elt: op.getAccVar());
2044 privatizationRecipes.push_back(Elt: mlir::SymbolRefAttr::get(
2045 builder.getContext(), recipe.getSymName().str()));
2046 }
2047
2048 // Map the new private iv to its symbol for the scope of the loop. bindSymbol
2049 // might create a hlfir.declare op, if so, we map its result in order to
2050 // use the sym value in the scope.
2051 converter.bindSymbol(sym, mlir::acc::getAccVar(privateOp));
2052 auto privateValue = converter.getSymbolAddress(sym);
2053 if (auto declareOp =
2054 mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp()))
2055 privateValue = declareOp.getResults()[0];
2056 ivPrivate.push_back(Elt: privateValue);
2057}
2058
2059static void determineDefaultLoopParMode(
2060 Fortran::lower::AbstractConverter &converter, mlir::acc::LoopOp &loopOp,
2061 llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
2062 llvm::SmallVector<mlir::Attribute> &independentDeviceTypes,
2063 llvm::SmallVector<mlir::Attribute> &autoDeviceTypes) {
2064 auto hasDeviceNone = [](mlir::Attribute attr) -> bool {
2065 return mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr).getValue() ==
2066 mlir::acc::DeviceType::None;
2067 };
2068 bool hasDefaultSeq = llvm::any_of(Range&: seqDeviceTypes, P: hasDeviceNone);
2069 bool hasDefaultIndependent =
2070 llvm::any_of(Range&: independentDeviceTypes, P: hasDeviceNone);
2071 bool hasDefaultAuto = llvm::any_of(Range&: autoDeviceTypes, P: hasDeviceNone);
2072 if (hasDefaultSeq || hasDefaultIndependent || hasDefaultAuto)
2073 return; // Default loop par mode is already specified.
2074
2075 mlir::Region *currentRegion =
2076 converter.getFirOpBuilder().getBlock()->getParent();
2077 mlir::Operation *parentOp = mlir::acc::getEnclosingComputeOp(*currentRegion);
2078 const bool isOrphanedLoop = !parentOp;
2079 if (isOrphanedLoop ||
2080 mlir::isa_and_present<mlir::acc::ParallelOp>(parentOp)) {
2081 // As per OpenACC 3.3 standard section 2.9.6 independent clause:
2082 // A loop construct with no auto or seq clause is treated as if it has the
2083 // independent clause when it is an orphaned loop construct or its parent
2084 // compute construct is a parallel construct.
2085 independentDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2086 converter.getFirOpBuilder().getContext(), mlir::acc::DeviceType::None));
2087 } else if (mlir::isa_and_present<mlir::acc::SerialOp>(parentOp)) {
2088 // Serial construct implies `seq` clause on loop. However, this
2089 // conflicts with parallelism assignment if already set. Therefore check
2090 // that first.
2091 bool hasDefaultGangWorkerOrVector =
2092 loopOp.hasVector() || loopOp.getVectorValue() || loopOp.hasWorker() ||
2093 loopOp.getWorkerValue() || loopOp.hasGang() ||
2094 loopOp.getGangValue(mlir::acc::GangArgType::Num) ||
2095 loopOp.getGangValue(mlir::acc::GangArgType::Dim) ||
2096 loopOp.getGangValue(mlir::acc::GangArgType::Static);
2097 if (!hasDefaultGangWorkerOrVector)
2098 seqDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2099 converter.getFirOpBuilder().getContext(),
2100 mlir::acc::DeviceType::None));
2101 // Since the loop has some parallelism assigned - we cannot assign `seq`.
2102 // However, the `acc.loop` verifier will check that one of seq, independent,
2103 // or auto is marked. Seems reasonable to mark as auto since the OpenACC
2104 // spec does say "If not, or if it is unable to make a determination, it
2105 // must treat the auto clause as if it is a seq clause, and it must
2106 // ignore any gang, worker, or vector clauses on the loop construct"
2107 else
2108 autoDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2109 converter.getFirOpBuilder().getContext(),
2110 mlir::acc::DeviceType::None));
2111 } else {
2112 // As per OpenACC 3.3 standard section 2.9.7 auto clause:
2113 // When the parent compute construct is a kernels construct, a loop
2114 // construct with no independent or seq clause is treated as if it has the
2115 // auto clause.
2116 assert(mlir::isa_and_present<mlir::acc::KernelsOp>(parentOp) &&
2117 "Expected kernels construct");
2118 autoDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2119 converter.getFirOpBuilder().getContext(), mlir::acc::DeviceType::None));
2120 }
2121}
2122
2123static mlir::acc::LoopOp createLoopOp(
2124 Fortran::lower::AbstractConverter &converter,
2125 mlir::Location currentLocation,
2126 Fortran::semantics::SemanticsContext &semanticsContext,
2127 Fortran::lower::StatementContext &stmtCtx,
2128 const Fortran::parser::DoConstruct &outerDoConstruct,
2129 Fortran::lower::pft::Evaluation &eval,
2130 const Fortran::parser::AccClauseList &accClauseList,
2131 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2132 std::nullopt,
2133 bool needEarlyReturnHandling = false) {
2134 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2135 llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
2136 reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
2137 gangOperands, lowerbounds, upperbounds, steps;
2138 llvm::SmallVector<mlir::Attribute> privatizationRecipes, reductionRecipes;
2139 llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
2140 llvm::SmallVector<int64_t> collapseValues;
2141
2142 llvm::SmallVector<mlir::Attribute> gangArgTypes;
2143 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
2144 autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
2145 vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
2146 collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
2147
2148 // device_type attribute is set to `none` until a device_type clause is
2149 // encountered.
2150 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2151 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2152 builder.getContext(), mlir::acc::DeviceType::None));
2153
2154 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2155 mlir::Location clauseLocation = converter.genLocation(clause.source);
2156 if (const auto *gangClause =
2157 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
2158 if (gangClause->v) {
2159 const Fortran::parser::AccGangArgList &x = *gangClause->v;
2160 mlir::SmallVector<mlir::Value> gangValues;
2161 mlir::SmallVector<mlir::Attribute> gangArgs;
2162 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
2163 if (const auto *num =
2164 std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
2165 gangValues.push_back(fir::getBase(converter.genExprValue(
2166 *Fortran::semantics::GetExpr(num->v), stmtCtx)));
2167 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
2168 builder.getContext(), mlir::acc::GangArgType::Num));
2169 } else if (const auto *staticArg =
2170 std::get_if<Fortran::parser::AccGangArg::Static>(
2171 &gangArg.u)) {
2172 const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
2173 if (sizeExpr.v) {
2174 gangValues.push_back(fir::getBase(converter.genExprValue(
2175 *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
2176 } else {
2177 // * was passed as value and will be represented as a special
2178 // constant.
2179 gangValues.push_back(builder.createIntegerConstant(
2180 clauseLocation, builder.getIndexType(), starCst));
2181 }
2182 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
2183 builder.getContext(), mlir::acc::GangArgType::Static));
2184 } else if (const auto *dim =
2185 std::get_if<Fortran::parser::AccGangArg::Dim>(
2186 &gangArg.u)) {
2187 gangValues.push_back(fir::getBase(converter.genExprValue(
2188 *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
2189 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
2190 builder.getContext(), mlir::acc::GangArgType::Dim));
2191 }
2192 }
2193 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2194 for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
2195 gangOperands.push_back(std::get<0>(pair));
2196 gangArgTypes.push_back(std::get<1>(pair));
2197 }
2198 gangOperandsSegments.push_back(gangValues.size());
2199 gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2200 }
2201 } else {
2202 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2203 gangDeviceTypes.push_back(crtDeviceTypeAttr);
2204 }
2205 } else if (const auto *workerClause =
2206 std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
2207 if (workerClause->v) {
2208 mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
2209 *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
2210 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2211 workerNumOperands.push_back(workerNumValue);
2212 workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2213 }
2214 } else {
2215 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2216 workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
2217 }
2218 } else if (const auto *vectorClause =
2219 std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
2220 if (vectorClause->v) {
2221 mlir::Value vectorValue = fir::getBase(converter.genExprValue(
2222 *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
2223 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2224 vectorOperands.push_back(vectorValue);
2225 vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2226 }
2227 } else {
2228 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2229 vectorDeviceTypes.push_back(crtDeviceTypeAttr);
2230 }
2231 } else if (const auto *tileClause =
2232 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
2233 const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
2234 llvm::SmallVector<mlir::Value> tileValues;
2235 for (const auto &accTileExpr : accTileExprList.v) {
2236 const auto &expr =
2237 std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
2238 accTileExpr.t);
2239 if (expr) {
2240 tileValues.push_back(fir::getBase(converter.genExprValue(
2241 *Fortran::semantics::GetExpr(*expr), stmtCtx)));
2242 } else {
2243 // * was passed as value and will be represented as a special
2244 // constant.
2245 mlir::Value tileStar = builder.createIntegerConstant(
2246 clauseLocation, builder.getIntegerType(32), starCst);
2247 tileValues.push_back(tileStar);
2248 }
2249 }
2250 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2251 for (auto value : tileValues)
2252 tileOperands.push_back(value);
2253 tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2254 tileOperandsSegments.push_back(tileValues.size());
2255 }
2256 } else if (const auto *privateClause =
2257 std::get_if<Fortran::parser::AccClause::Private>(
2258 &clause.u)) {
2259 genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>(
2260 privateClause->v, converter, semanticsContext, stmtCtx,
2261 privateOperands, privatizationRecipes, /*async=*/{},
2262 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2263 } else if (const auto *reductionClause =
2264 std::get_if<Fortran::parser::AccClause::Reduction>(
2265 &clause.u)) {
2266 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2267 reductionOperands, reductionRecipes, /*async=*/{},
2268 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2269 } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
2270 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2271 seqDeviceTypes.push_back(crtDeviceTypeAttr);
2272 } else if (std::get_if<Fortran::parser::AccClause::Independent>(
2273 &clause.u)) {
2274 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2275 independentDeviceTypes.push_back(crtDeviceTypeAttr);
2276 } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
2277 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2278 autoDeviceTypes.push_back(crtDeviceTypeAttr);
2279 } else if (const auto *deviceTypeClause =
2280 std::get_if<Fortran::parser::AccClause::DeviceType>(
2281 &clause.u)) {
2282 crtDeviceTypes.clear();
2283 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2284 } else if (const auto *collapseClause =
2285 std::get_if<Fortran::parser::AccClause::Collapse>(
2286 &clause.u)) {
2287 const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
2288 const auto &force = std::get<bool>(arg.t);
2289 if (force)
2290 TODO(clauseLocation, "OpenACC collapse force modifier");
2291
2292 const auto &intExpr =
2293 std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
2294 const auto *expr = Fortran::semantics::GetExpr(intExpr);
2295 const std::optional<int64_t> collapseValue =
2296 Fortran::evaluate::ToInt64(*expr);
2297 assert(collapseValue && "expect integer value for the collapse clause");
2298
2299 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2300 collapseValues.push_back(*collapseValue);
2301 collapseDeviceTypes.push_back(crtDeviceTypeAttr);
2302 }
2303 }
2304 }
2305
2306 llvm::SmallVector<mlir::Type> ivTypes;
2307 llvm::SmallVector<mlir::Location> ivLocs;
2308 llvm::SmallVector<bool> inclusiveBounds;
2309 llvm::SmallVector<mlir::Location> locs;
2310 locs.push_back(Elt: currentLocation); // Location of the directive
2311 Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
2312 bool isDoConcurrent = outerDoConstruct.IsDoConcurrent();
2313 if (isDoConcurrent) {
2314 locs.push_back(converter.genLocation(
2315 Fortran::parser::FindSourceLocation(outerDoConstruct)));
2316 const Fortran::parser::LoopControl *loopControl =
2317 &*outerDoConstruct.GetLoopControl();
2318 const auto &concurrent =
2319 std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
2320 if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
2321 .empty())
2322 TODO(currentLocation, "DO CONCURRENT with locality spec");
2323
2324 const auto &concurrentHeader =
2325 std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
2326 const auto &controls =
2327 std::get<std::list<Fortran::parser::ConcurrentControl>>(
2328 concurrentHeader.t);
2329 for (const auto &control : controls) {
2330 lowerbounds.push_back(fir::getBase(converter.genExprValue(
2331 *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx)));
2332 upperbounds.push_back(fir::getBase(converter.genExprValue(
2333 *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx)));
2334 if (const auto &expr =
2335 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
2336 control.t))
2337 steps.push_back(fir::getBase(converter.genExprValue(
2338 *Fortran::semantics::GetExpr(*expr), stmtCtx)));
2339 else // If `step` is not present, assume it is `1`.
2340 steps.push_back(builder.createIntegerConstant(
2341 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
2342
2343 const auto &name = std::get<Fortran::parser::Name>(control.t);
2344 privatizeIv(converter, *name.symbol, currentLocation, ivTypes, ivLocs,
2345 privateOperands, ivPrivate, privatizationRecipes,
2346 isDoConcurrent);
2347
2348 inclusiveBounds.push_back(true);
2349 }
2350 } else {
2351 int64_t loopCount =
2352 Fortran::lower::getLoopCountForCollapseAndTile(accClauseList);
2353 for (unsigned i = 0; i < loopCount; ++i) {
2354 const Fortran::parser::LoopControl *loopControl;
2355 if (i == 0) {
2356 loopControl = &*outerDoConstruct.GetLoopControl();
2357 locs.push_back(converter.genLocation(
2358 Fortran::parser::FindSourceLocation(outerDoConstruct)));
2359 } else {
2360 auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
2361 assert(doCons && "expect do construct");
2362 loopControl = &*doCons->GetLoopControl();
2363 locs.push_back(converter.genLocation(
2364 Fortran::parser::FindSourceLocation(*doCons)));
2365 }
2366
2367 const Fortran::parser::LoopControl::Bounds *bounds =
2368 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
2369 assert(bounds && "Expected bounds on the loop construct");
2370 lowerbounds.push_back(fir::getBase(converter.genExprValue(
2371 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
2372 upperbounds.push_back(fir::getBase(converter.genExprValue(
2373 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
2374 if (bounds->step)
2375 steps.push_back(fir::getBase(converter.genExprValue(
2376 *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
2377 else // If `step` is not present, assume it is `1`.
2378 steps.push_back(Elt: builder.createIntegerConstant(
2379 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
2380
2381 Fortran::semantics::Symbol &ivSym =
2382 bounds->name.thing.symbol->GetUltimate();
2383 privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
2384 privateOperands, ivPrivate, privatizationRecipes);
2385
2386 inclusiveBounds.push_back(Elt: true);
2387
2388 if (i < loopCount - 1)
2389 crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
2390 }
2391 }
2392
2393 // Prepare the operand segment size attribute and the operands value range.
2394 llvm::SmallVector<mlir::Value> operands;
2395 llvm::SmallVector<int32_t> operandSegments;
2396 addOperands(operands, operandSegments, clauseOperands: lowerbounds);
2397 addOperands(operands, operandSegments, clauseOperands: upperbounds);
2398 addOperands(operands, operandSegments, clauseOperands: steps);
2399 addOperands(operands, operandSegments, clauseOperands: gangOperands);
2400 addOperands(operands, operandSegments, clauseOperands: workerNumOperands);
2401 addOperands(operands, operandSegments, clauseOperands: vectorOperands);
2402 addOperands(operands, operandSegments, clauseOperands: tileOperands);
2403 addOperands(operands, operandSegments, clauseOperands: cacheOperands);
2404 addOperands(operands, operandSegments, clauseOperands: privateOperands);
2405 addOperands(operands, operandSegments, clauseOperands: reductionOperands);
2406
2407 llvm::SmallVector<mlir::Type> retTy;
2408 mlir::Value yieldValue;
2409 if (needEarlyReturnHandling) {
2410 mlir::Type i1Ty = builder.getI1Type();
2411 yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
2412 retTy.push_back(Elt: i1Ty);
2413 }
2414
2415 auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
2416 builder, builder.getFusedLoc(locs), currentLocation, eval, operands,
2417 operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes,
2418 ivLocs);
2419
2420 for (auto [arg, value] : llvm::zip(
2421 loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
2422 builder.create<fir::StoreOp>(currentLocation, arg, value);
2423
2424 loopOp.setInclusiveUpperbound(inclusiveBounds);
2425
2426 if (!gangDeviceTypes.empty())
2427 loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
2428 if (!gangArgTypes.empty())
2429 loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
2430 if (!gangOperandsSegments.empty())
2431 loopOp.setGangOperandsSegmentsAttr(
2432 builder.getDenseI32ArrayAttr(gangOperandsSegments));
2433 if (!gangOperandsDeviceTypes.empty())
2434 loopOp.setGangOperandsDeviceTypeAttr(
2435 builder.getArrayAttr(gangOperandsDeviceTypes));
2436
2437 if (!workerNumDeviceTypes.empty())
2438 loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
2439 if (!workerNumOperandsDeviceTypes.empty())
2440 loopOp.setWorkerNumOperandsDeviceTypeAttr(
2441 builder.getArrayAttr(workerNumOperandsDeviceTypes));
2442
2443 if (!vectorDeviceTypes.empty())
2444 loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
2445 if (!vectorOperandsDeviceTypes.empty())
2446 loopOp.setVectorOperandsDeviceTypeAttr(
2447 builder.getArrayAttr(vectorOperandsDeviceTypes));
2448
2449 if (!tileOperandsDeviceTypes.empty())
2450 loopOp.setTileOperandsDeviceTypeAttr(
2451 builder.getArrayAttr(tileOperandsDeviceTypes));
2452 if (!tileOperandsSegments.empty())
2453 loopOp.setTileOperandsSegmentsAttr(
2454 builder.getDenseI32ArrayAttr(tileOperandsSegments));
2455
2456 // Determine the loop's default par mode - either seq, independent, or auto.
2457 determineDefaultLoopParMode(converter, loopOp, seqDeviceTypes,
2458 independentDeviceTypes, autoDeviceTypes);
2459 if (!seqDeviceTypes.empty())
2460 loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
2461 if (!independentDeviceTypes.empty())
2462 loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
2463 if (!autoDeviceTypes.empty())
2464 loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
2465
2466 if (!privatizationRecipes.empty())
2467 loopOp.setPrivatizationRecipesAttr(
2468 mlir::ArrayAttr::get(context: builder.getContext(), value: privatizationRecipes));
2469
2470 if (!reductionRecipes.empty())
2471 loopOp.setReductionRecipesAttr(
2472 mlir::ArrayAttr::get(context: builder.getContext(), value: reductionRecipes));
2473
2474 if (!collapseValues.empty())
2475 loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
2476 if (!collapseDeviceTypes.empty())
2477 loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
2478
2479 if (combinedConstructs)
2480 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
2481 builder.getContext(), *combinedConstructs));
2482
2483 // TODO: retrieve directives from NonLabelDoStmt pft::Evaluation, and add them
2484 // as attribute to the acc.loop as an extra attribute. It is not quite clear
2485 // how useful these $dir are in acc contexts, but they could still provide
2486 // more information about the loop acc codegen. They can be obtained by
2487 // looking for the first lexicalSuccessor of eval that is a NonLabelDoStmt,
2488 // and using the related `dirs` member.
2489
2490 return loopOp;
2491}
2492
2493static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
2494 bool hasReturnStmt = false;
2495 for (auto &e : eval.getNestedEvaluations()) {
2496 e.visit(Fortran::common::visitors{
2497 [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
2498 [&](const auto &s) {},
2499 });
2500 if (e.hasNestedEvaluations())
2501 hasReturnStmt = hasEarlyReturn(e);
2502 }
2503 return hasReturnStmt;
2504}
2505
2506static mlir::Value
2507genACC(Fortran::lower::AbstractConverter &converter,
2508 Fortran::semantics::SemanticsContext &semanticsContext,
2509 Fortran::lower::pft::Evaluation &eval,
2510 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
2511
2512 const auto &beginLoopDirective =
2513 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
2514 const auto &loopDirective =
2515 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
2516
2517 bool needEarlyExitHandling = false;
2518 if (eval.lowerAsUnstructured())
2519 needEarlyExitHandling = hasEarlyReturn(eval);
2520
2521 mlir::Location currentLocation =
2522 converter.genLocation(beginLoopDirective.source);
2523 Fortran::lower::StatementContext stmtCtx;
2524
2525 assert(loopDirective.v == llvm::acc::ACCD_loop &&
2526 "Unsupported OpenACC loop construct");
2527 (void)loopDirective;
2528
2529 const auto &accClauseList =
2530 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
2531 const auto &outerDoConstruct =
2532 std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
2533 auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
2534 stmtCtx, *outerDoConstruct, eval, accClauseList,
2535 /*combinedConstructs=*/{}, needEarlyExitHandling);
2536 if (needEarlyExitHandling)
2537 return loopOp.getResult(0);
2538
2539 return mlir::Value{};
2540}
2541
2542template <typename Op, typename Clause>
2543static void genDataOperandOperationsWithModifier(
2544 const Clause *x, Fortran::lower::AbstractConverter &converter,
2545 Fortran::semantics::SemanticsContext &semanticsContext,
2546 Fortran::lower::StatementContext &stmtCtx,
2547 Fortran::parser::AccDataModifier::Modifier mod,
2548 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
2549 const mlir::acc::DataClause clause,
2550 const mlir::acc::DataClause clauseWithModifier,
2551 llvm::ArrayRef<mlir::Value> async,
2552 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
2553 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
2554 bool setDeclareAttr = false) {
2555 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
2556 const auto &accObjectList =
2557 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2558 const auto &modifier =
2559 std::get<std::optional<Fortran::parser::AccDataModifier>>(
2560 listWithModifier.t);
2561 mlir::acc::DataClause dataClause =
2562 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
2563 genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
2564 stmtCtx, dataClauseOperands, dataClause,
2565 /*structured=*/true, /*implicit=*/false, async,
2566 asyncDeviceTypes, asyncOnlyDeviceTypes,
2567 setDeclareAttr);
2568}
2569
2570template <typename Op>
2571static Op createComputeOp(
2572 Fortran::lower::AbstractConverter &converter,
2573 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
2574 Fortran::semantics::SemanticsContext &semanticsContext,
2575 Fortran::lower::StatementContext &stmtCtx,
2576 const Fortran::parser::AccClauseList &accClauseList,
2577 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2578 std::nullopt) {
2579
2580 // Parallel operation operands
2581 mlir::Value ifCond;
2582 mlir::Value selfCond;
2583 llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
2584 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
2585 createEntryOperands, nocreateEntryOperands, presentEntryOperands,
2586 dataClauseOperands, numGangs, numWorkers, vectorLength, async;
2587 llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
2588 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
2589 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2590 llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
2591 llvm::SmallVector<bool> hasWaitDevnums;
2592
2593 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
2594 firstprivateOperands;
2595 llvm::SmallVector<mlir::Attribute> privatizationRecipes,
2596 firstPrivatizationRecipes, reductionRecipes;
2597
2598 // Self clause has optional values but can be present with
2599 // no value as well. When there is no value, the op has an attribute to
2600 // represent the clause.
2601 bool addSelfAttr = false;
2602
2603 bool hasDefaultNone = false;
2604 bool hasDefaultPresent = false;
2605
2606 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2607
2608 // device_type attribute is set to `none` until a device_type clause is
2609 // encountered.
2610 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2611 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2612 builder.getContext(), mlir::acc::DeviceType::None);
2613 crtDeviceTypes.push_back(Elt: crtDeviceTypeAttr);
2614
2615 // Lower clauses values mapped to operands and array attributes.
2616 // Keep track of each group of operands separately as clauses can appear
2617 // more than once.
2618
2619 // Process the clauses that may have a specified device_type first.
2620 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2621 if (const auto *asyncClause =
2622 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2623 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2624 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2625 } else if (const auto *waitClause =
2626 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2627 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2628 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2629 hasWaitDevnums, waitOperandsSegments,
2630 crtDeviceTypes, stmtCtx);
2631 } else if (const auto *numGangsClause =
2632 std::get_if<Fortran::parser::AccClause::NumGangs>(
2633 &clause.u)) {
2634 llvm::SmallVector<mlir::Value> numGangValues;
2635 for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
2636 numGangValues.push_back(fir::getBase(converter.genExprValue(
2637 *Fortran::semantics::GetExpr(expr), stmtCtx)));
2638 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2639 for (auto value : numGangValues)
2640 numGangs.push_back(value);
2641 numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
2642 numGangsSegments.push_back(numGangValues.size());
2643 }
2644 } else if (const auto *numWorkersClause =
2645 std::get_if<Fortran::parser::AccClause::NumWorkers>(
2646 &clause.u)) {
2647 mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
2648 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
2649 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2650 numWorkers.push_back(numWorkerValue);
2651 numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
2652 }
2653 } else if (const auto *vectorLengthClause =
2654 std::get_if<Fortran::parser::AccClause::VectorLength>(
2655 &clause.u)) {
2656 mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
2657 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
2658 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2659 vectorLength.push_back(vectorLengthValue);
2660 vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
2661 }
2662 } else if (const auto *deviceTypeClause =
2663 std::get_if<Fortran::parser::AccClause::DeviceType>(
2664 &clause.u)) {
2665 crtDeviceTypes.clear();
2666 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2667 }
2668 }
2669
2670 // Process the clauses independent of device_type.
2671 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2672 mlir::Location clauseLocation = converter.genLocation(clause.source);
2673 if (const auto *ifClause =
2674 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2675 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2676 } else if (const auto *selfClause =
2677 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
2678 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
2679 selfClause->v;
2680 if (accSelfClause) {
2681 if (const auto *optCondition =
2682 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>(
2683 &(*accSelfClause).u)) {
2684 if (*optCondition) {
2685 mlir::Value cond = fir::getBase(converter.genExprValue(
2686 *Fortran::semantics::GetExpr(*optCondition), stmtCtx));
2687 selfCond = builder.createConvert(clauseLocation,
2688 builder.getI1Type(), cond);
2689 }
2690 } else if (const auto *accClauseList =
2691 std::get_if<Fortran::parser::AccObjectList>(
2692 &(*accSelfClause).u)) {
2693 // TODO This would be nicer to be done in canonicalization step.
2694 if (accClauseList->v.size() == 1) {
2695 const auto &accObject = accClauseList->v.front();
2696 if (const auto *designator =
2697 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
2698 if (const auto *name =
2699 Fortran::semantics::getDesignatorNameIfDataRef(
2700 *designator)) {
2701 auto cond = converter.getSymbolAddress(*name->symbol);
2702 selfCond = builder.createConvert(clauseLocation,
2703 builder.getI1Type(), cond);
2704 }
2705 }
2706 }
2707 }
2708 } else {
2709 addSelfAttr = true;
2710 }
2711 } else if (const auto *copyClause =
2712 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2713 auto crtDataStart = dataClauseOperands.size();
2714 genDataOperandOperations<mlir::acc::CopyinOp>(
2715 copyClause->v, converter, semanticsContext, stmtCtx,
2716 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2717 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2718 asyncOnlyDeviceTypes);
2719 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2720 dataClauseOperands.end());
2721 } else if (const auto *copyinClause =
2722 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2723 auto crtDataStart = dataClauseOperands.size();
2724 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2725 Fortran::parser::AccClause::Copyin>(
2726 copyinClause, converter, semanticsContext, stmtCtx,
2727 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2728 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2729 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
2730 asyncOnlyDeviceTypes);
2731 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2732 dataClauseOperands.end());
2733 } else if (const auto *copyoutClause =
2734 std::get_if<Fortran::parser::AccClause::Copyout>(
2735 &clause.u)) {
2736 auto crtDataStart = dataClauseOperands.size();
2737 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2738 Fortran::parser::AccClause::Copyout>(
2739 copyoutClause, converter, semanticsContext, stmtCtx,
2740 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2741 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2742 mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes,
2743 asyncOnlyDeviceTypes);
2744 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2745 dataClauseOperands.end());
2746 } else if (const auto *createClause =
2747 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2748 auto crtDataStart = dataClauseOperands.size();
2749 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2750 Fortran::parser::AccClause::Create>(
2751 createClause, converter, semanticsContext, stmtCtx,
2752 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2753 mlir::acc::DataClause::acc_create,
2754 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes,
2755 asyncOnlyDeviceTypes);
2756 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2757 dataClauseOperands.end());
2758 } else if (const auto *noCreateClause =
2759 std::get_if<Fortran::parser::AccClause::NoCreate>(
2760 &clause.u)) {
2761 auto crtDataStart = dataClauseOperands.size();
2762 genDataOperandOperations<mlir::acc::NoCreateOp>(
2763 noCreateClause->v, converter, semanticsContext, stmtCtx,
2764 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2765 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2766 asyncOnlyDeviceTypes);
2767 nocreateEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2768 dataClauseOperands.end());
2769 } else if (const auto *presentClause =
2770 std::get_if<Fortran::parser::AccClause::Present>(
2771 &clause.u)) {
2772 auto crtDataStart = dataClauseOperands.size();
2773 genDataOperandOperations<mlir::acc::PresentOp>(
2774 presentClause->v, converter, semanticsContext, stmtCtx,
2775 dataClauseOperands, mlir::acc::DataClause::acc_present,
2776 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2777 asyncOnlyDeviceTypes);
2778 presentEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2779 dataClauseOperands.end());
2780 } else if (const auto *devicePtrClause =
2781 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2782 &clause.u)) {
2783 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2784 devicePtrClause->v, converter, semanticsContext, stmtCtx,
2785 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2786 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2787 asyncOnlyDeviceTypes);
2788 } else if (const auto *attachClause =
2789 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2790 auto crtDataStart = dataClauseOperands.size();
2791 genDataOperandOperations<mlir::acc::AttachOp>(
2792 attachClause->v, converter, semanticsContext, stmtCtx,
2793 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2794 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2795 asyncOnlyDeviceTypes);
2796 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2797 dataClauseOperands.end());
2798 } else if (const auto *privateClause =
2799 std::get_if<Fortran::parser::AccClause::Private>(
2800 &clause.u)) {
2801 if (!combinedConstructs)
2802 genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>(
2803 privateClause->v, converter, semanticsContext, stmtCtx,
2804 privateOperands, privatizationRecipes, async, asyncDeviceTypes,
2805 asyncOnlyDeviceTypes);
2806 } else if (const auto *firstprivateClause =
2807 std::get_if<Fortran::parser::AccClause::Firstprivate>(
2808 &clause.u)) {
2809 genPrivatizationRecipes<mlir::acc::FirstprivateRecipeOp>(
2810 firstprivateClause->v, converter, semanticsContext, stmtCtx,
2811 firstprivateOperands, firstPrivatizationRecipes, async,
2812 asyncDeviceTypes, asyncOnlyDeviceTypes);
2813 } else if (const auto *reductionClause =
2814 std::get_if<Fortran::parser::AccClause::Reduction>(
2815 &clause.u)) {
2816 // A reduction clause on a combined construct is treated as if it appeared
2817 // on the loop construct. So don't generate a reduction clause when it is
2818 // combined - delay it to the loop. However, a reduction clause on a
2819 // combined construct implies a copy clause so issue an implicit copy
2820 // instead.
2821 if (!combinedConstructs) {
2822 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2823 reductionOperands, reductionRecipes, async,
2824 asyncDeviceTypes, asyncOnlyDeviceTypes);
2825 } else {
2826 auto crtDataStart = dataClauseOperands.size();
2827 genDataOperandOperations<mlir::acc::CopyinOp>(
2828 std::get<Fortran::parser::AccObjectList>(reductionClause->v.t),
2829 converter, semanticsContext, stmtCtx, dataClauseOperands,
2830 mlir::acc::DataClause::acc_reduction,
2831 /*structured=*/true, /*implicit=*/true, async, asyncDeviceTypes,
2832 asyncOnlyDeviceTypes);
2833 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2834 dataClauseOperands.end());
2835 }
2836 } else if (const auto *defaultClause =
2837 std::get_if<Fortran::parser::AccClause::Default>(
2838 &clause.u)) {
2839 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2840 hasDefaultNone = true;
2841 else if ((defaultClause->v).v ==
2842 llvm::acc::DefaultValue::ACC_Default_present)
2843 hasDefaultPresent = true;
2844 }
2845 }
2846
2847 // Prepare the operand segment size attribute and the operands value range.
2848 llvm::SmallVector<mlir::Value, 8> operands;
2849 llvm::SmallVector<int32_t, 8> operandSegments;
2850 addOperands(operands, operandSegments, clauseOperands: async);
2851 addOperands(operands, operandSegments, clauseOperands: waitOperands);
2852 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2853 addOperands(operands, operandSegments, clauseOperands: numGangs);
2854 addOperands(operands, operandSegments, clauseOperands: numWorkers);
2855 addOperands(operands, operandSegments, clauseOperands: vectorLength);
2856 }
2857 addOperand(operands, operandSegments, clauseOperand: ifCond);
2858 addOperand(operands, operandSegments, clauseOperand: selfCond);
2859 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2860 addOperands(operands, operandSegments, clauseOperands: reductionOperands);
2861 addOperands(operands, operandSegments, clauseOperands: privateOperands);
2862 addOperands(operands, operandSegments, clauseOperands: firstprivateOperands);
2863 }
2864 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
2865
2866 Op computeOp;
2867 if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
2868 computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
2869 builder, currentLocation, currentLocation, eval, operands,
2870 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2871 else
2872 computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
2873 builder, currentLocation, currentLocation, eval, operands,
2874 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2875
2876 if (addSelfAttr)
2877 computeOp.setSelfAttrAttr(builder.getUnitAttr());
2878
2879 if (hasDefaultNone)
2880 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2881 if (hasDefaultPresent)
2882 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2883
2884 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2885 if (!numWorkersDeviceTypes.empty())
2886 computeOp.setNumWorkersDeviceTypeAttr(
2887 mlir::ArrayAttr::get(context: builder.getContext(), value: numWorkersDeviceTypes));
2888 if (!vectorLengthDeviceTypes.empty())
2889 computeOp.setVectorLengthDeviceTypeAttr(
2890 mlir::ArrayAttr::get(context: builder.getContext(), value: vectorLengthDeviceTypes));
2891 if (!numGangsDeviceTypes.empty())
2892 computeOp.setNumGangsDeviceTypeAttr(
2893 mlir::ArrayAttr::get(context: builder.getContext(), value: numGangsDeviceTypes));
2894 if (!numGangsSegments.empty())
2895 computeOp.setNumGangsSegmentsAttr(
2896 builder.getDenseI32ArrayAttr(numGangsSegments));
2897 }
2898 if (!asyncDeviceTypes.empty())
2899 computeOp.setAsyncOperandsDeviceTypeAttr(
2900 builder.getArrayAttr(asyncDeviceTypes));
2901 if (!asyncOnlyDeviceTypes.empty())
2902 computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2903
2904 if (!waitOperandsDeviceTypes.empty())
2905 computeOp.setWaitOperandsDeviceTypeAttr(
2906 builder.getArrayAttr(waitOperandsDeviceTypes));
2907 if (!waitOperandsSegments.empty())
2908 computeOp.setWaitOperandsSegmentsAttr(
2909 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2910 if (!hasWaitDevnums.empty())
2911 computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
2912 if (!waitOnlyDeviceTypes.empty())
2913 computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2914
2915 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2916 if (!privatizationRecipes.empty())
2917 computeOp.setPrivatizationRecipesAttr(
2918 mlir::ArrayAttr::get(context: builder.getContext(), value: privatizationRecipes));
2919 if (!reductionRecipes.empty())
2920 computeOp.setReductionRecipesAttr(
2921 mlir::ArrayAttr::get(context: builder.getContext(), value: reductionRecipes));
2922 if (!firstPrivatizationRecipes.empty())
2923 computeOp.setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(
2924 context: builder.getContext(), value: firstPrivatizationRecipes));
2925 }
2926
2927 if (combinedConstructs)
2928 computeOp.setCombinedAttr(builder.getUnitAttr());
2929
2930 auto insPt = builder.saveInsertionPoint();
2931 builder.setInsertionPointAfter(computeOp);
2932
2933 // Create the exit operations after the region.
2934 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2935 builder, copyEntryOperands, /*structured=*/true);
2936 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
2937 builder, copyinEntryOperands, /*structured=*/true);
2938 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2939 builder, copyoutEntryOperands, /*structured=*/true);
2940 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2941 builder, attachEntryOperands, /*structured=*/true);
2942 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2943 builder, createEntryOperands, /*structured=*/true);
2944 genDataExitOperations<mlir::acc::NoCreateOp, mlir::acc::DeleteOp>(
2945 builder, nocreateEntryOperands, /*structured=*/true);
2946 genDataExitOperations<mlir::acc::PresentOp, mlir::acc::DeleteOp>(
2947 builder, presentEntryOperands, /*structured=*/true);
2948
2949 builder.restoreInsertionPoint(insPt);
2950 return computeOp;
2951}
2952
2953static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2954 mlir::Location currentLocation,
2955 mlir::Location endLocation,
2956 Fortran::lower::pft::Evaluation &eval,
2957 Fortran::semantics::SemanticsContext &semanticsContext,
2958 Fortran::lower::StatementContext &stmtCtx,
2959 const Fortran::parser::AccClauseList &accClauseList) {
2960 mlir::Value ifCond;
2961 llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2962 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
2963 nocreateEntryOperands, presentEntryOperands, dataClauseOperands,
2964 waitOperands, async;
2965 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2966 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2967 llvm::SmallVector<int32_t> waitOperandsSegments;
2968 llvm::SmallVector<bool> hasWaitDevnums;
2969
2970 bool hasDefaultNone = false;
2971 bool hasDefaultPresent = false;
2972
2973 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2974
2975 // device_type attribute is set to `none` until a device_type clause is
2976 // encountered.
2977 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2978 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2979 builder.getContext(), mlir::acc::DeviceType::None));
2980
2981 // Lower clauses values mapped to operands and array attributes.
2982 // Keep track of each group of operands separately as clauses can appear
2983 // more than once.
2984
2985 // Process the clauses that may have a specified device_type first.
2986 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2987 if (const auto *asyncClause =
2988 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2989 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2990 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2991 } else if (const auto *waitClause =
2992 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2993 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2994 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2995 hasWaitDevnums, waitOperandsSegments,
2996 crtDeviceTypes, stmtCtx);
2997 } else if (const auto *deviceTypeClause =
2998 std::get_if<Fortran::parser::AccClause::DeviceType>(
2999 &clause.u)) {
3000 crtDeviceTypes.clear();
3001 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
3002 }
3003 }
3004
3005 // Process the clauses independent of device_type.
3006 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3007 mlir::Location clauseLocation = converter.genLocation(clause.source);
3008 if (const auto *ifClause =
3009 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3010 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3011 } else if (const auto *copyClause =
3012 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
3013 auto crtDataStart = dataClauseOperands.size();
3014 genDataOperandOperations<mlir::acc::CopyinOp>(
3015 copyClause->v, converter, semanticsContext, stmtCtx,
3016 dataClauseOperands, mlir::acc::DataClause::acc_copy,
3017 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
3018 asyncOnlyDeviceTypes);
3019 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3020 dataClauseOperands.end());
3021 } else if (const auto *copyinClause =
3022 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3023 auto crtDataStart = dataClauseOperands.size();
3024 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
3025 Fortran::parser::AccClause::Copyin>(
3026 copyinClause, converter, semanticsContext, stmtCtx,
3027 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3028 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
3029 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
3030 asyncOnlyDeviceTypes);
3031 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3032 dataClauseOperands.end());
3033 } else if (const auto *copyoutClause =
3034 std::get_if<Fortran::parser::AccClause::Copyout>(
3035 &clause.u)) {
3036 auto crtDataStart = dataClauseOperands.size();
3037 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
3038 Fortran::parser::AccClause::Copyout>(
3039 copyoutClause, converter, semanticsContext, stmtCtx,
3040 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
3041 mlir::acc::DataClause::acc_copyout,
3042 mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes,
3043 asyncOnlyDeviceTypes);
3044 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3045 dataClauseOperands.end());
3046 } else if (const auto *createClause =
3047 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3048 auto crtDataStart = dataClauseOperands.size();
3049 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
3050 Fortran::parser::AccClause::Create>(
3051 createClause, converter, semanticsContext, stmtCtx,
3052 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
3053 mlir::acc::DataClause::acc_create,
3054 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes,
3055 asyncOnlyDeviceTypes);
3056 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3057 dataClauseOperands.end());
3058 } else if (const auto *noCreateClause =
3059 std::get_if<Fortran::parser::AccClause::NoCreate>(
3060 &clause.u)) {
3061 auto crtDataStart = dataClauseOperands.size();
3062 genDataOperandOperations<mlir::acc::NoCreateOp>(
3063 noCreateClause->v, converter, semanticsContext, stmtCtx,
3064 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
3065 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
3066 asyncOnlyDeviceTypes);
3067 nocreateEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3068 dataClauseOperands.end());
3069 } else if (const auto *presentClause =
3070 std::get_if<Fortran::parser::AccClause::Present>(
3071 &clause.u)) {
3072 auto crtDataStart = dataClauseOperands.size();
3073 genDataOperandOperations<mlir::acc::PresentOp>(
3074 presentClause->v, converter, semanticsContext, stmtCtx,
3075 dataClauseOperands, mlir::acc::DataClause::acc_present,
3076 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
3077 asyncOnlyDeviceTypes);
3078 presentEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3079 dataClauseOperands.end());
3080 } else if (const auto *deviceptrClause =
3081 std::get_if<Fortran::parser::AccClause::Deviceptr>(
3082 &clause.u)) {
3083 genDataOperandOperations<mlir::acc::DevicePtrOp>(
3084 deviceptrClause->v, converter, semanticsContext, stmtCtx,
3085 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
3086 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
3087 asyncOnlyDeviceTypes);
3088 } else if (const auto *attachClause =
3089 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
3090 auto crtDataStart = dataClauseOperands.size();
3091 genDataOperandOperations<mlir::acc::AttachOp>(
3092 attachClause->v, converter, semanticsContext, stmtCtx,
3093 dataClauseOperands, mlir::acc::DataClause::acc_attach,
3094 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
3095 asyncOnlyDeviceTypes);
3096 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3097 dataClauseOperands.end());
3098 } else if (const auto *defaultClause =
3099 std::get_if<Fortran::parser::AccClause::Default>(
3100 &clause.u)) {
3101 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
3102 hasDefaultNone = true;
3103 else if ((defaultClause->v).v ==
3104 llvm::acc::DefaultValue::ACC_Default_present)
3105 hasDefaultPresent = true;
3106 }
3107 }
3108
3109 // Prepare the operand segment size attribute and the operands value range.
3110 llvm::SmallVector<mlir::Value> operands;
3111 llvm::SmallVector<int32_t> operandSegments;
3112 addOperand(operands, operandSegments, clauseOperand: ifCond);
3113 addOperands(operands, operandSegments, clauseOperands: async);
3114 addOperands(operands, operandSegments, clauseOperands: waitOperands);
3115 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
3116
3117 if (dataClauseOperands.empty() && !hasDefaultNone && !hasDefaultPresent)
3118 return;
3119
3120 auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
3121 builder, currentLocation, currentLocation, eval, operands,
3122 operandSegments);
3123
3124 if (!asyncDeviceTypes.empty())
3125 dataOp.setAsyncOperandsDeviceTypeAttr(
3126 builder.getArrayAttr(asyncDeviceTypes));
3127 if (!asyncOnlyDeviceTypes.empty())
3128 dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
3129 if (!waitOperandsDeviceTypes.empty())
3130 dataOp.setWaitOperandsDeviceTypeAttr(
3131 builder.getArrayAttr(waitOperandsDeviceTypes));
3132 if (!waitOperandsSegments.empty())
3133 dataOp.setWaitOperandsSegmentsAttr(
3134 builder.getDenseI32ArrayAttr(waitOperandsSegments));
3135 if (!hasWaitDevnums.empty())
3136 dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
3137 if (!waitOnlyDeviceTypes.empty())
3138 dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
3139
3140 if (hasDefaultNone)
3141 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
3142 if (hasDefaultPresent)
3143 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
3144
3145 auto insPt = builder.saveInsertionPoint();
3146 builder.setInsertionPointAfter(dataOp);
3147
3148 // Create the exit operations after the region.
3149 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
3150 builder, copyEntryOperands, /*structured=*/true, endLocation);
3151 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
3152 builder, copyinEntryOperands, /*structured=*/true, endLocation);
3153 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
3154 builder, copyoutEntryOperands, /*structured=*/true, endLocation);
3155 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
3156 builder, attachEntryOperands, /*structured=*/true, endLocation);
3157 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3158 builder, createEntryOperands, /*structured=*/true, endLocation);
3159 genDataExitOperations<mlir::acc::NoCreateOp, mlir::acc::DeleteOp>(
3160 builder, nocreateEntryOperands, /*structured=*/true, endLocation);
3161 genDataExitOperations<mlir::acc::PresentOp, mlir::acc::DeleteOp>(
3162 builder, presentEntryOperands, /*structured=*/true, endLocation);
3163
3164 builder.restoreInsertionPoint(insPt);
3165}
3166
3167static void
3168genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
3169 mlir::Location currentLocation,
3170 Fortran::lower::pft::Evaluation &eval,
3171 Fortran::semantics::SemanticsContext &semanticsContext,
3172 Fortran::lower::StatementContext &stmtCtx,
3173 const Fortran::parser::AccClauseList &accClauseList) {
3174 mlir::Value ifCond;
3175 llvm::SmallVector<mlir::Value> dataOperands;
3176 bool addIfPresentAttr = false;
3177
3178 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3179
3180 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3181 mlir::Location clauseLocation = converter.genLocation(clause.source);
3182 if (const auto *ifClause =
3183 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3184 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3185 } else if (const auto *useDevice =
3186 std::get_if<Fortran::parser::AccClause::UseDevice>(
3187 &clause.u)) {
3188 genDataOperandOperations<mlir::acc::UseDeviceOp>(
3189 useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
3190 mlir::acc::DataClause::acc_use_device,
3191 /*structured=*/true, /*implicit=*/false, /*async=*/{},
3192 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3193 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
3194 addIfPresentAttr = true;
3195 }
3196 }
3197
3198 if (ifCond) {
3199 if (auto cst =
3200 mlir::dyn_cast<mlir::arith::ConstantOp>(ifCond.getDefiningOp()))
3201 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(cst.getValue())) {
3202 if (boolAttr.getValue()) {
3203 // get rid of the if condition if it is always true.
3204 ifCond = mlir::Value();
3205 } else {
3206 // Do not generate the acc.host_data op if the if condition is always
3207 // false.
3208 return;
3209 }
3210 }
3211 }
3212
3213 // Prepare the operand segment size attribute and the operands value range.
3214 llvm::SmallVector<mlir::Value> operands;
3215 llvm::SmallVector<int32_t> operandSegments;
3216 addOperand(operands, operandSegments, clauseOperand: ifCond);
3217 addOperands(operands, operandSegments, clauseOperands: dataOperands);
3218
3219 auto hostDataOp =
3220 createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
3221 builder, currentLocation, currentLocation, eval, operands,
3222 operandSegments);
3223
3224 if (addIfPresentAttr)
3225 hostDataOp.setIfPresentAttr(builder.getUnitAttr());
3226}
3227
3228static void
3229genACC(Fortran::lower::AbstractConverter &converter,
3230 Fortran::semantics::SemanticsContext &semanticsContext,
3231 Fortran::lower::pft::Evaluation &eval,
3232 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
3233 const auto &beginBlockDirective =
3234 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
3235 const auto &blockDirective =
3236 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
3237 const auto &accClauseList =
3238 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t);
3239 const auto &endBlockDirective =
3240 std::get<Fortran::parser::AccEndBlockDirective>(blockConstruct.t);
3241 mlir::Location endLocation = converter.genLocation(endBlockDirective.source);
3242 mlir::Location currentLocation = converter.genLocation(blockDirective.source);
3243 Fortran::lower::StatementContext stmtCtx;
3244
3245 if (blockDirective.v == llvm::acc::ACCD_parallel) {
3246 createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
3247 semanticsContext, stmtCtx,
3248 accClauseList);
3249 } else if (blockDirective.v == llvm::acc::ACCD_data) {
3250 genACCDataOp(converter, currentLocation, endLocation, eval,
3251 semanticsContext, stmtCtx, accClauseList);
3252 } else if (blockDirective.v == llvm::acc::ACCD_serial) {
3253 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
3254 semanticsContext, stmtCtx,
3255 accClauseList);
3256 } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
3257 createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
3258 semanticsContext, stmtCtx,
3259 accClauseList);
3260 } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
3261 genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
3262 stmtCtx, accClauseList);
3263 }
3264}
3265
3266static void
3267genACC(Fortran::lower::AbstractConverter &converter,
3268 Fortran::semantics::SemanticsContext &semanticsContext,
3269 Fortran::lower::pft::Evaluation &eval,
3270 const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
3271 const auto &beginCombinedDirective =
3272 std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
3273 const auto &combinedDirective =
3274 std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
3275 const auto &accClauseList =
3276 std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
3277 const auto &outerDoConstruct =
3278 std::get<std::optional<Fortran::parser::DoConstruct>>(
3279 combinedConstruct.t);
3280
3281 mlir::Location currentLocation =
3282 converter.genLocation(beginCombinedDirective.source);
3283 Fortran::lower::StatementContext stmtCtx;
3284
3285 if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
3286 createComputeOp<mlir::acc::KernelsOp>(
3287 converter, currentLocation, eval, semanticsContext, stmtCtx,
3288 accClauseList, mlir::acc::CombinedConstructsType::KernelsLoop);
3289 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
3290 *outerDoConstruct, eval, accClauseList,
3291 mlir::acc::CombinedConstructsType::KernelsLoop);
3292 } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
3293 createComputeOp<mlir::acc::ParallelOp>(
3294 converter, currentLocation, eval, semanticsContext, stmtCtx,
3295 accClauseList, mlir::acc::CombinedConstructsType::ParallelLoop);
3296 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
3297 *outerDoConstruct, eval, accClauseList,
3298 mlir::acc::CombinedConstructsType::ParallelLoop);
3299 } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
3300 createComputeOp<mlir::acc::SerialOp>(
3301 converter, currentLocation, eval, semanticsContext, stmtCtx,
3302 accClauseList, mlir::acc::CombinedConstructsType::SerialLoop);
3303 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
3304 *outerDoConstruct, eval, accClauseList,
3305 mlir::acc::CombinedConstructsType::SerialLoop);
3306 } else {
3307 llvm::report_fatal_error(reason: "Unknown combined construct encountered");
3308 }
3309}
3310
3311static void
3312genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
3313 mlir::Location currentLocation,
3314 Fortran::semantics::SemanticsContext &semanticsContext,
3315 Fortran::lower::StatementContext &stmtCtx,
3316 const Fortran::parser::AccClauseList &accClauseList) {
3317 mlir::Value ifCond, async, waitDevnum;
3318 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands;
3319
3320 // Async, wait and self clause have optional values but can be present with
3321 // no value as well. When there is no value, the op has an attribute to
3322 // represent the clause.
3323 bool addAsyncAttr = false;
3324 bool addWaitAttr = false;
3325
3326 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3327
3328 // Lower clauses values mapped to operands.
3329 // Keep track of each group of operands separately as clauses can appear
3330 // more than once.
3331
3332 // Process the async clause first.
3333 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3334 if (const auto *asyncClause =
3335 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3336 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3337 }
3338 }
3339
3340 // The async clause of 'enter data' applies to all device types,
3341 // so propagate the async clause to copyin/create/attach ops
3342 // as if it is an async clause without preceding device_type clause.
3343 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
3344 llvm::SmallVector<mlir::Value> asyncValues;
3345 auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3346 firOpBuilder.getContext(), mlir::acc::DeviceType::None);
3347 if (addAsyncAttr) {
3348 asyncOnlyDeviceTypes.push_back(Elt: noneDeviceTypeAttr);
3349 } else if (async) {
3350 asyncValues.push_back(Elt: async);
3351 asyncDeviceTypes.push_back(Elt: noneDeviceTypeAttr);
3352 }
3353
3354 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3355 mlir::Location clauseLocation = converter.genLocation(clause.source);
3356 if (const auto *ifClause =
3357 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3358 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3359 } else if (const auto *waitClause =
3360 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3361 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
3362 addWaitAttr, stmtCtx);
3363 } else if (const auto *copyinClause =
3364 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3365 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3366 copyinClause->v;
3367 const auto &accObjectList =
3368 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3369 genDataOperandOperations<mlir::acc::CopyinOp>(
3370 accObjectList, converter, semanticsContext, stmtCtx,
3371 dataClauseOperands, mlir::acc::DataClause::acc_copyin, false,
3372 /*implicit=*/false, asyncValues, asyncDeviceTypes,
3373 asyncOnlyDeviceTypes);
3374 } else if (const auto *createClause =
3375 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3376 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3377 createClause->v;
3378 const auto &accObjectList =
3379 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3380 const auto &modifier =
3381 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3382 listWithModifier.t);
3383 mlir::acc::DataClause clause = mlir::acc::DataClause::acc_create;
3384 if (modifier &&
3385 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero)
3386 clause = mlir::acc::DataClause::acc_create_zero;
3387 genDataOperandOperations<mlir::acc::CreateOp>(
3388 accObjectList, converter, semanticsContext, stmtCtx,
3389 dataClauseOperands, clause, false, /*implicit=*/false, asyncValues,
3390 asyncDeviceTypes, asyncOnlyDeviceTypes);
3391 } else if (const auto *attachClause =
3392 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
3393 genDataOperandOperations<mlir::acc::AttachOp>(
3394 attachClause->v, converter, semanticsContext, stmtCtx,
3395 dataClauseOperands, mlir::acc::DataClause::acc_attach, false,
3396 /*implicit=*/false, asyncValues, asyncDeviceTypes,
3397 asyncOnlyDeviceTypes);
3398 } else if (!std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3399 llvm::report_fatal_error(
3400 "Unknown clause in ENTER DATA directive lowering");
3401 }
3402 }
3403
3404 // Prepare the operand segment size attribute and the operands value range.
3405 llvm::SmallVector<mlir::Value, 16> operands;
3406 llvm::SmallVector<int32_t, 8> operandSegments;
3407 addOperand(operands, operandSegments, clauseOperand: ifCond);
3408 addOperand(operands, operandSegments, clauseOperand: async);
3409 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
3410 addOperands(operands, operandSegments, clauseOperands: waitOperands);
3411 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
3412
3413 mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
3414 firOpBuilder, currentLocation, operands, operandSegments);
3415
3416 if (addAsyncAttr)
3417 enterDataOp.setAsyncAttr(firOpBuilder.getUnitAttr());
3418 if (addWaitAttr)
3419 enterDataOp.setWaitAttr(firOpBuilder.getUnitAttr());
3420}
3421
3422static void
3423genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
3424 mlir::Location currentLocation,
3425 Fortran::semantics::SemanticsContext &semanticsContext,
3426 Fortran::lower::StatementContext &stmtCtx,
3427 const Fortran::parser::AccClauseList &accClauseList) {
3428 mlir::Value ifCond, async, waitDevnum;
3429 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands,
3430 copyoutOperands, deleteOperands, detachOperands;
3431
3432 // Async and wait clause have optional values but can be present with
3433 // no value as well. When there is no value, the op has an attribute to
3434 // represent the clause.
3435 bool addAsyncAttr = false;
3436 bool addWaitAttr = false;
3437 bool addFinalizeAttr = false;
3438
3439 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3440
3441 // Lower clauses values mapped to operands.
3442 // Keep track of each group of operands separately as clauses can appear
3443 // more than once.
3444
3445 // Process the async clause first.
3446 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3447 if (const auto *asyncClause =
3448 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3449 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3450 }
3451 }
3452
3453 // The async clause of 'exit data' applies to all device types,
3454 // so propagate the async clause to copyin/create/attach ops
3455 // as if it is an async clause without preceding device_type clause.
3456 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
3457 llvm::SmallVector<mlir::Value> asyncValues;
3458 auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3459 builder.getContext(), mlir::acc::DeviceType::None);
3460 if (addAsyncAttr) {
3461 asyncOnlyDeviceTypes.push_back(Elt: noneDeviceTypeAttr);
3462 } else if (async) {
3463 asyncValues.push_back(Elt: async);
3464 asyncDeviceTypes.push_back(Elt: noneDeviceTypeAttr);
3465 }
3466
3467 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3468 mlir::Location clauseLocation = converter.genLocation(clause.source);
3469 if (const auto *ifClause =
3470 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3471 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3472 } else if (const auto *waitClause =
3473 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3474 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
3475 addWaitAttr, stmtCtx);
3476 } else if (const auto *copyoutClause =
3477 std::get_if<Fortran::parser::AccClause::Copyout>(
3478 &clause.u)) {
3479 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3480 copyoutClause->v;
3481 const auto &accObjectList =
3482 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3483 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3484 accObjectList, converter, semanticsContext, stmtCtx, copyoutOperands,
3485 mlir::acc::DataClause::acc_copyout, false, /*implicit=*/false,
3486 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3487 } else if (const auto *deleteClause =
3488 std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) {
3489 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3490 deleteClause->v, converter, semanticsContext, stmtCtx, deleteOperands,
3491 mlir::acc::DataClause::acc_delete, false, /*implicit=*/false,
3492 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3493 } else if (const auto *detachClause =
3494 std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) {
3495 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3496 detachClause->v, converter, semanticsContext, stmtCtx, detachOperands,
3497 mlir::acc::DataClause::acc_detach, false, /*implicit=*/false,
3498 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3499 } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) {
3500 addFinalizeAttr = true;
3501 }
3502 }
3503
3504 dataClauseOperands.append(RHS: copyoutOperands);
3505 dataClauseOperands.append(RHS: deleteOperands);
3506 dataClauseOperands.append(RHS: detachOperands);
3507
3508 // Prepare the operand segment size attribute and the operands value range.
3509 llvm::SmallVector<mlir::Value, 14> operands;
3510 llvm::SmallVector<int32_t, 7> operandSegments;
3511 addOperand(operands, operandSegments, clauseOperand: ifCond);
3512 addOperand(operands, operandSegments, clauseOperand: async);
3513 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
3514 addOperands(operands, operandSegments, clauseOperands: waitOperands);
3515 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
3516
3517 mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>(
3518 builder, currentLocation, operands, operandSegments);
3519
3520 if (addAsyncAttr)
3521 exitDataOp.setAsyncAttr(builder.getUnitAttr());
3522 if (addWaitAttr)
3523 exitDataOp.setWaitAttr(builder.getUnitAttr());
3524 if (addFinalizeAttr)
3525 exitDataOp.setFinalizeAttr(builder.getUnitAttr());
3526
3527 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::CopyoutOp>(
3528 builder, copyoutOperands, /*structured=*/false);
3529 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DeleteOp>(
3530 builder, deleteOperands, /*structured=*/false);
3531 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DetachOp>(
3532 builder, detachOperands, /*structured=*/false);
3533}
3534
3535template <typename Op>
3536static void
3537genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
3538 mlir::Location currentLocation,
3539 const Fortran::parser::AccClauseList &accClauseList) {
3540 mlir::Value ifCond, deviceNum;
3541
3542 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3543 Fortran::lower::StatementContext stmtCtx;
3544 llvm::SmallVector<mlir::Attribute> deviceTypes;
3545
3546 // Lower clauses values mapped to operands.
3547 // Keep track of each group of operands separately as clauses can appear
3548 // more than once.
3549 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3550 mlir::Location clauseLocation = converter.genLocation(clause.source);
3551 if (const auto *ifClause =
3552 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3553 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3554 } else if (const auto *deviceNumClause =
3555 std::get_if<Fortran::parser::AccClause::DeviceNum>(
3556 &clause.u)) {
3557 deviceNum = fir::getBase(converter.genExprValue(
3558 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
3559 } else if (const auto *deviceTypeClause =
3560 std::get_if<Fortran::parser::AccClause::DeviceType>(
3561 &clause.u)) {
3562 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3563 }
3564 }
3565
3566 // Prepare the operand segment size attribute and the operands value range.
3567 llvm::SmallVector<mlir::Value, 6> operands;
3568 llvm::SmallVector<int32_t, 2> operandSegments;
3569
3570 addOperand(operands, operandSegments, clauseOperand: deviceNum);
3571 addOperand(operands, operandSegments, clauseOperand: ifCond);
3572
3573 Op op =
3574 createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
3575 if (!deviceTypes.empty())
3576 op.setDeviceTypesAttr(
3577 mlir::ArrayAttr::get(context: builder.getContext(), value: deviceTypes));
3578}
3579
3580void genACCSetOp(Fortran::lower::AbstractConverter &converter,
3581 mlir::Location currentLocation,
3582 const Fortran::parser::AccClauseList &accClauseList) {
3583 mlir::Value ifCond, deviceNum, defaultAsync;
3584 llvm::SmallVector<mlir::Value> deviceTypeOperands;
3585
3586 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3587 Fortran::lower::StatementContext stmtCtx;
3588 llvm::SmallVector<mlir::Attribute> deviceTypes;
3589
3590 // Lower clauses values mapped to operands.
3591 // Keep track of each group of operands separately as clauses can appear
3592 // more than once.
3593 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3594 mlir::Location clauseLocation = converter.genLocation(clause.source);
3595 if (const auto *ifClause =
3596 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3597 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3598 } else if (const auto *defaultAsyncClause =
3599 std::get_if<Fortran::parser::AccClause::DefaultAsync>(
3600 &clause.u)) {
3601 defaultAsync = fir::getBase(converter.genExprValue(
3602 *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx));
3603 } else if (const auto *deviceNumClause =
3604 std::get_if<Fortran::parser::AccClause::DeviceNum>(
3605 &clause.u)) {
3606 deviceNum = fir::getBase(converter.genExprValue(
3607 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
3608 } else if (const auto *deviceTypeClause =
3609 std::get_if<Fortran::parser::AccClause::DeviceType>(
3610 &clause.u)) {
3611 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3612 }
3613 }
3614
3615 // Prepare the operand segment size attribute and the operands value range.
3616 llvm::SmallVector<mlir::Value> operands;
3617 llvm::SmallVector<int32_t, 3> operandSegments;
3618 addOperand(operands, operandSegments, clauseOperand: defaultAsync);
3619 addOperand(operands, operandSegments, clauseOperand: deviceNum);
3620 addOperand(operands, operandSegments, clauseOperand: ifCond);
3621
3622 auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
3623 operandSegments);
3624 if (!deviceTypes.empty()) {
3625 assert(deviceTypes.size() == 1 && "expect only one value for acc.set");
3626 op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
3627 }
3628}
3629
3630static inline mlir::ArrayAttr
3631getArrayAttr(fir::FirOpBuilder &b,
3632 llvm::SmallVector<mlir::Attribute> &attributes) {
3633 return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
3634}
3635
3636static inline mlir::ArrayAttr
3637getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
3638 return values.empty() ? nullptr : b.getBoolArrayAttr(values);
3639}
3640
3641static inline mlir::DenseI32ArrayAttr
3642getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
3643 llvm::SmallVector<int32_t> &values) {
3644 return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
3645}
3646
3647static void
3648genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3649 mlir::Location currentLocation,
3650 Fortran::semantics::SemanticsContext &semanticsContext,
3651 Fortran::lower::StatementContext &stmtCtx,
3652 const Fortran::parser::AccClauseList &accClauseList) {
3653 mlir::Value ifCond;
3654 llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
3655 waitOperands, deviceTypeOperands, asyncOperands;
3656 llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
3657 asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3658 llvm::SmallVector<bool> hasWaitDevnums;
3659 llvm::SmallVector<int32_t> waitOperandsSegments;
3660
3661 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3662
3663 // device_type attribute is set to `none` until a device_type clause is
3664 // encountered.
3665 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
3666 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
3667 builder.getContext(), mlir::acc::DeviceType::None));
3668
3669 bool ifPresent = false;
3670
3671 // Lower clauses values mapped to operands and array attributes.
3672 // Keep track of each group of operands separately as clauses can appear
3673 // more than once.
3674
3675 // Process the clauses that may have a specified device_type first.
3676 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3677 if (const auto *asyncClause =
3678 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3679 genAsyncClause(converter, asyncClause, asyncOperands,
3680 asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
3681 crtDeviceTypes, stmtCtx);
3682 } else if (const auto *waitClause =
3683 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3684 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
3685 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3686 hasWaitDevnums, waitOperandsSegments,
3687 crtDeviceTypes, stmtCtx);
3688 } else if (const auto *deviceTypeClause =
3689 std::get_if<Fortran::parser::AccClause::DeviceType>(
3690 &clause.u)) {
3691 crtDeviceTypes.clear();
3692 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
3693 }
3694 }
3695
3696 // Process the clauses independent of device_type.
3697 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3698 mlir::Location clauseLocation = converter.genLocation(clause.source);
3699 if (const auto *ifClause =
3700 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3701 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3702 } else if (const auto *hostClause =
3703 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
3704 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3705 hostClause->v, converter, semanticsContext, stmtCtx,
3706 updateHostOperands, mlir::acc::DataClause::acc_update_host, false,
3707 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3708 asyncOnlyDeviceTypes);
3709 } else if (const auto *deviceClause =
3710 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
3711 genDataOperandOperations<mlir::acc::UpdateDeviceOp>(
3712 deviceClause->v, converter, semanticsContext, stmtCtx,
3713 dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
3714 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3715 asyncOnlyDeviceTypes);
3716 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
3717 ifPresent = true;
3718 } else if (const auto *selfClause =
3719 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
3720 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
3721 selfClause->v;
3722 const auto *accObjectList =
3723 std::get_if<Fortran::parser::AccObjectList>(&(*accSelfClause).u);
3724 assert(accObjectList && "expect AccObjectList");
3725 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3726 *accObjectList, converter, semanticsContext, stmtCtx,
3727 updateHostOperands, mlir::acc::DataClause::acc_update_self, false,
3728 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3729 asyncOnlyDeviceTypes);
3730 }
3731 }
3732
3733 dataClauseOperands.append(RHS: updateHostOperands);
3734
3735 builder.create<mlir::acc::UpdateOp>(
3736 currentLocation, ifCond, asyncOperands,
3737 getArrayAttr(builder, asyncOperandsDeviceTypes),
3738 getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
3739 getDenseI32ArrayAttr(builder, waitOperandsSegments),
3740 getArrayAttr(builder, waitOperandsDeviceTypes),
3741 getBoolArrayAttr(builder, hasWaitDevnums),
3742 getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
3743 ifPresent);
3744
3745 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
3746 builder, updateHostOperands, /*structured=*/false);
3747}
3748
3749static void
3750genACC(Fortran::lower::AbstractConverter &converter,
3751 Fortran::semantics::SemanticsContext &semanticsContext,
3752 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) {
3753 const auto &standaloneDirective =
3754 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t);
3755 const auto &accClauseList =
3756 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
3757
3758 mlir::Location currentLocation =
3759 converter.genLocation(standaloneDirective.source);
3760 Fortran::lower::StatementContext stmtCtx;
3761
3762 if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
3763 genACCEnterDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3764 accClauseList);
3765 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
3766 genACCExitDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3767 accClauseList);
3768 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
3769 genACCInitShutdownOp<mlir::acc::InitOp>(converter, currentLocation,
3770 accClauseList);
3771 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) {
3772 genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
3773 accClauseList);
3774 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
3775 genACCSetOp(converter, currentLocation, accClauseList);
3776 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
3777 genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
3778 accClauseList);
3779 }
3780}
3781
3782static void genACC(Fortran::lower::AbstractConverter &converter,
3783 const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
3784
3785 const auto &waitArgument =
3786 std::get<std::optional<Fortran::parser::AccWaitArgument>>(
3787 waitConstruct.t);
3788 const auto &accClauseList =
3789 std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
3790
3791 mlir::Value ifCond, waitDevnum, async;
3792 llvm::SmallVector<mlir::Value> waitOperands;
3793
3794 // Async clause have optional values but can be present with
3795 // no value as well. When there is no value, the op has an attribute to
3796 // represent the clause.
3797 bool addAsyncAttr = false;
3798
3799 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3800 mlir::Location currentLocation = converter.genLocation(waitConstruct.source);
3801 Fortran::lower::StatementContext stmtCtx;
3802
3803 if (waitArgument) { // wait has a value.
3804 const Fortran::parser::AccWaitArgument &waitArg = *waitArgument;
3805 const auto &waitList =
3806 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3807 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
3808 mlir::Value v = fir::getBase(
3809 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
3810 waitOperands.push_back(v);
3811 }
3812
3813 const auto &waitDevnumValue =
3814 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3815 if (waitDevnumValue)
3816 waitDevnum = fir::getBase(converter.genExprValue(
3817 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
3818 }
3819
3820 // Lower clauses values mapped to operands.
3821 // Keep track of each group of operands separately as clauses can appear
3822 // more than once.
3823 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3824 mlir::Location clauseLocation = converter.genLocation(clause.source);
3825 if (const auto *ifClause =
3826 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3827 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3828 } else if (const auto *asyncClause =
3829 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3830 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3831 }
3832 }
3833
3834 // Prepare the operand segment size attribute and the operands value range.
3835 llvm::SmallVector<mlir::Value> operands;
3836 llvm::SmallVector<int32_t> operandSegments;
3837 addOperands(operands, operandSegments, clauseOperands: waitOperands);
3838 addOperand(operands, operandSegments, clauseOperand: async);
3839 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
3840 addOperand(operands, operandSegments, clauseOperand: ifCond);
3841
3842 mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
3843 firOpBuilder, currentLocation, operands, operandSegments);
3844
3845 if (addAsyncAttr)
3846 waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
3847}
3848
3849template <typename GlobalOp, typename EntryOp, typename DeclareOp,
3850 typename ExitOp>
3851static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
3852 fir::FirOpBuilder &builder,
3853 mlir::Location loc, fir::GlobalOp globalOp,
3854 mlir::acc::DataClause clause,
3855 const std::string &declareGlobalName,
3856 bool implicit, std::stringstream &asFortran) {
3857 GlobalOp declareGlobalOp =
3858 modBuilder.create<GlobalOp>(loc, declareGlobalName);
3859 builder.createBlock(&declareGlobalOp.getRegion(),
3860 declareGlobalOp.getRegion().end(), {}, {});
3861 builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
3862
3863 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3864 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3865 addDeclareAttr(builder, addrOp, clause);
3866
3867 llvm::SmallVector<mlir::Value> bounds;
3868 EntryOp entryOp = createDataEntryOp<EntryOp>(
3869 builder, loc, addrOp.getResTy(), asFortran, bounds,
3870 /*structured=*/false, implicit, clause, addrOp.getResTy().getType(),
3871 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3872 if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
3873 builder.create<DeclareOp>(
3874 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3875 mlir::ValueRange(entryOp.getAccVar()));
3876 else
3877 builder.create<DeclareOp>(loc, mlir::Value{},
3878 mlir::ValueRange(entryOp.getAccVar()));
3879 if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
3880 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
3881 entryOp.getBounds(), entryOp.getAsyncOperands(),
3882 entryOp.getAsyncOperandsDeviceTypeAttr(),
3883 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
3884 /*structured=*/false, /*implicit=*/false,
3885 builder.getStringAttr(*entryOp.getName()));
3886 }
3887 builder.create<mlir::acc::TerminatorOp>(loc);
3888 modBuilder.setInsertionPointAfter(declareGlobalOp);
3889}
3890
3891template <typename EntryOp>
3892static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
3893 fir::FirOpBuilder &builder,
3894 mlir::Location loc, fir::GlobalOp &globalOp,
3895 mlir::acc::DataClause clause) {
3896 std::stringstream registerFuncName;
3897 registerFuncName << globalOp.getSymName().str()
3898 << Fortran::lower::declarePostAllocSuffix.str();
3899 auto registerFuncOp =
3900 createDeclareFunc(modBuilder, builder, loc, registerFuncName.str());
3901
3902 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3903 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3904
3905 std::stringstream asFortran;
3906 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3907 std::stringstream asFortranDesc;
3908 asFortranDesc << asFortran.str();
3909 if (unwrapFirBox)
3910 asFortranDesc << accFirDescriptorPostfix.str();
3911 llvm::SmallVector<mlir::Value> bounds;
3912
3913 // Updating descriptor must occur before the mapping of the data so that
3914 // attached data pointer is not overwritten.
3915 mlir::acc::UpdateDeviceOp updateDeviceOp =
3916 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3917 builder, loc, addrOp, asFortranDesc, bounds,
3918 /*structured=*/false, /*implicit=*/true,
3919 mlir::acc::DataClause::acc_update_device, addrOp.getType(),
3920 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3921 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
3922 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3923 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3924
3925 if (unwrapFirBox) {
3926 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3927 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3928 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
3929 EntryOp entryOp = createDataEntryOp<EntryOp>(
3930 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
3931 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
3932 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3933 builder.create<mlir::acc::DeclareEnterOp>(
3934 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3935 mlir::ValueRange(entryOp.getAccVar()));
3936 }
3937
3938 modBuilder.setInsertionPointAfter(registerFuncOp);
3939}
3940
3941/// Action to be performed on deallocation are split in two distinct functions.
3942/// - Pre deallocation function includes all the action to be performed before
3943/// the actual deallocation is done on the host side.
3944/// - Post deallocation function includes update to the descriptor.
3945template <typename ExitOp>
3946static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
3947 fir::FirOpBuilder &builder,
3948 mlir::Location loc,
3949 fir::GlobalOp &globalOp,
3950 mlir::acc::DataClause clause) {
3951 std::stringstream asFortran;
3952 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3953
3954 // If FIR box semantics are being unwrapped, then a pre-dealloc function
3955 // needs generated to ensure to delete the device data pointed to by the
3956 // descriptor before this information is lost.
3957 if (unwrapFirBox) {
3958 // Generate the pre dealloc function.
3959 std::stringstream preDeallocFuncName;
3960 preDeallocFuncName << globalOp.getSymName().str()
3961 << Fortran::lower::declarePreDeallocSuffix.str();
3962 auto preDeallocOp =
3963 createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
3964
3965 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3966 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3967 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3968 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3969 mlir::Value var = boxAddrOp.getResult();
3970 addDeclareAttr(builder, var.getDefiningOp(), clause);
3971
3972 llvm::SmallVector<mlir::Value> bounds;
3973 mlir::acc::GetDevicePtrOp entryOp =
3974 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
3975 builder, loc, var, asFortran, bounds,
3976 /*structured=*/false, /*implicit=*/false, clause, var.getType(),
3977 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3978
3979 builder.create<mlir::acc::DeclareExitOp>(
3980 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccVar()));
3981
3982 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
3983 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
3984 builder.create<ExitOp>(
3985 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getVar(),
3986 entryOp.getBounds(), entryOp.getAsyncOperands(),
3987 entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
3988 entryOp.getDataClause(),
3989 /*structured=*/false, /*implicit=*/false,
3990 builder.getStringAttr(*entryOp.getName()));
3991 else
3992 builder.create<ExitOp>(
3993 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getBounds(),
3994 entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
3995 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
3996 /*structured=*/false, /*implicit=*/false,
3997 builder.getStringAttr(*entryOp.getName()));
3998
3999 // Generate the post dealloc function.
4000 modBuilder.setInsertionPointAfter(preDeallocOp);
4001 }
4002
4003 std::stringstream postDeallocFuncName;
4004 postDeallocFuncName << globalOp.getSymName().str()
4005 << Fortran::lower::declarePostDeallocSuffix.str();
4006 auto postDeallocOp =
4007 createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str());
4008
4009 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
4010 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
4011 if (unwrapFirBox)
4012 asFortran << accFirDescriptorPostfix.str();
4013 llvm::SmallVector<mlir::Value> bounds;
4014 mlir::acc::UpdateDeviceOp updateDeviceOp =
4015 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
4016 builder, loc, addrOp, asFortran, bounds,
4017 /*structured=*/false, /*implicit=*/true,
4018 mlir::acc::DataClause::acc_update_device, addrOp.getType(),
4019 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
4020 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
4021 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
4022 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
4023 modBuilder.setInsertionPointAfter(postDeallocOp);
4024}
4025
4026template <typename EntryOp, typename ExitOp>
4027static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
4028 mlir::OpBuilder &modBuilder,
4029 const Fortran::parser::AccObjectList &accObjectList,
4030 mlir::acc::DataClause clause) {
4031 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4032 auto genCtors = [&](const mlir::Location operandLocation,
4033 const Fortran::semantics::Symbol &symbol) {
4034 std::string globalName = converter.mangleName(symbol);
4035 fir::GlobalOp globalOp = builder.getNamedGlobal(globalName);
4036 std::stringstream declareGlobalCtorName;
4037 declareGlobalCtorName << globalName << "_acc_ctor";
4038 std::stringstream declareGlobalDtorName;
4039 declareGlobalDtorName << globalName << "_acc_dtor";
4040 std::stringstream asFortran;
4041 asFortran << symbol.name().ToString();
4042
4043 if (builder.getModule().lookupSymbol<mlir::acc::GlobalConstructorOp>(
4044 declareGlobalCtorName.str()))
4045 return;
4046
4047 if (!globalOp) {
4048 if (Fortran::semantics::FindEquivalenceSet(symbol)) {
4049 for (Fortran::semantics::EquivalenceObject eqObj :
4050 *Fortran::semantics::FindEquivalenceSet(symbol)) {
4051 std::string eqName = converter.mangleName(eqObj.symbol);
4052 globalOp = builder.getNamedGlobal(eqName);
4053 if (globalOp)
4054 break;
4055 }
4056
4057 if (!globalOp)
4058 llvm::report_fatal_error(reason: "could not retrieve global symbol");
4059 } else {
4060 llvm::report_fatal_error(reason: "could not retrieve global symbol");
4061 }
4062 }
4063
4064 addDeclareAttr(builder, globalOp.getOperation(), clause);
4065 auto crtPos = builder.saveInsertionPoint();
4066 modBuilder.setInsertionPointAfter(globalOp);
4067 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(globalOp.getType()))) {
4068 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, mlir::acc::CopyinOp,
4069 mlir::acc::DeclareEnterOp, ExitOp>(
4070 modBuilder, builder, operandLocation, globalOp, clause,
4071 declareGlobalCtorName.str(), /*implicit=*/true, asFortran);
4072 createDeclareAllocFunc<EntryOp>(modBuilder, builder, operandLocation,
4073 globalOp, clause);
4074 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
4075 createDeclareDeallocFunc<ExitOp>(modBuilder, builder, operandLocation,
4076 globalOp, clause);
4077 } else {
4078 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
4079 mlir::acc::DeclareEnterOp, ExitOp>(
4080 modBuilder, builder, operandLocation, globalOp, clause,
4081 declareGlobalCtorName.str(), /*implicit=*/false, asFortran);
4082 }
4083 if constexpr (!std::is_same_v<EntryOp, ExitOp>) {
4084 createDeclareGlobalOp<mlir::acc::GlobalDestructorOp,
4085 mlir::acc::GetDevicePtrOp, mlir::acc::DeclareExitOp,
4086 ExitOp>(
4087 modBuilder, builder, operandLocation, globalOp, clause,
4088 declareGlobalDtorName.str(), /*implicit=*/false, asFortran);
4089 }
4090 builder.restoreInsertionPoint(crtPos);
4091 };
4092 for (const auto &accObject : accObjectList.v) {
4093 mlir::Location operandLocation = genOperandLocation(converter, accObject);
4094 Fortran::common::visit(
4095 Fortran::common::visitors{
4096 [&](const Fortran::parser::Designator &designator) {
4097 if (const auto *name =
4098 Fortran::semantics::getDesignatorNameIfDataRef(
4099 designator)) {
4100 genCtors(operandLocation, *name->symbol);
4101 }
4102 },
4103 [&](const Fortran::parser::Name &name) {
4104 if (const auto *symbol = name.symbol) {
4105 if (symbol
4106 ->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
4107 genCtors(operandLocation, *symbol);
4108 } else {
4109 TODO(operandLocation,
4110 "OpenACC Global Ctor from parser::Name");
4111 }
4112 }
4113 }},
4114 accObject.u);
4115 }
4116}
4117
4118template <typename Clause, typename EntryOp, typename ExitOp>
4119static void
4120genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
4121 mlir::OpBuilder &modBuilder, const Clause *x,
4122 Fortran::parser::AccDataModifier::Modifier mod,
4123 const mlir::acc::DataClause clause,
4124 const mlir::acc::DataClause clauseWithModifier) {
4125 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
4126 const auto &accObjectList =
4127 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4128 const auto &modifier =
4129 std::get<std::optional<Fortran::parser::AccDataModifier>>(
4130 listWithModifier.t);
4131 mlir::acc::DataClause dataClause =
4132 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
4133 genGlobalCtors<EntryOp, ExitOp>(converter, modBuilder, accObjectList,
4134 dataClause);
4135}
4136
4137static void
4138genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
4139 Fortran::semantics::SemanticsContext &semanticsContext,
4140 Fortran::lower::StatementContext &openAccCtx,
4141 mlir::Location loc,
4142 const Fortran::parser::AccClauseList &accClauseList) {
4143 llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands,
4144 copyinEntryOperands, createEntryOperands, copyoutEntryOperands,
4145 presentEntryOperands, deviceResidentEntryOperands;
4146 Fortran::lower::StatementContext stmtCtx;
4147 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4148
4149 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
4150 if (const auto *copyClause =
4151 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
4152 auto crtDataStart = dataClauseOperands.size();
4153 genDeclareDataOperandOperations<mlir::acc::CopyinOp,
4154 mlir::acc::CopyoutOp>(
4155 copyClause->v, converter, semanticsContext, stmtCtx,
4156 dataClauseOperands, mlir::acc::DataClause::acc_copy,
4157 /*structured=*/true, /*implicit=*/false);
4158 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
4159 dataClauseOperands.end());
4160 } else if (const auto *createClause =
4161 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
4162 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4163 createClause->v;
4164 const auto &accObjectList =
4165 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4166 auto crtDataStart = dataClauseOperands.size();
4167 genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
4168 accObjectList, converter, semanticsContext, stmtCtx,
4169 dataClauseOperands, mlir::acc::DataClause::acc_create,
4170 /*structured=*/true, /*implicit=*/false);
4171 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
4172 dataClauseOperands.end());
4173 } else if (const auto *presentClause =
4174 std::get_if<Fortran::parser::AccClause::Present>(
4175 &clause.u)) {
4176 auto crtDataStart = dataClauseOperands.size();
4177 genDeclareDataOperandOperations<mlir::acc::PresentOp,
4178 mlir::acc::DeleteOp>(
4179 presentClause->v, converter, semanticsContext, stmtCtx,
4180 dataClauseOperands, mlir::acc::DataClause::acc_present,
4181 /*structured=*/true, /*implicit=*/false);
4182 presentEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
4183 dataClauseOperands.end());
4184 } else if (const auto *copyinClause =
4185 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
4186 auto crtDataStart = dataClauseOperands.size();
4187 genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
4188 mlir::acc::DeleteOp>(
4189 copyinClause, converter, semanticsContext, stmtCtx,
4190 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
4191 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
4192 mlir::acc::DataClause::acc_copyin_readonly);
4193 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
4194 dataClauseOperands.end());
4195 } else if (const auto *copyoutClause =
4196 std::get_if<Fortran::parser::AccClause::Copyout>(
4197 &clause.u)) {
4198 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4199 copyoutClause->v;
4200 const auto &accObjectList =
4201 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4202 auto crtDataStart = dataClauseOperands.size();
4203 genDeclareDataOperandOperations<mlir::acc::CreateOp,
4204 mlir::acc::CopyoutOp>(
4205 accObjectList, converter, semanticsContext, stmtCtx,
4206 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
4207 /*structured=*/true, /*implicit=*/false);
4208 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
4209 dataClauseOperands.end());
4210 } else if (const auto *devicePtrClause =
4211 std::get_if<Fortran::parser::AccClause::Deviceptr>(
4212 &clause.u)) {
4213 genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
4214 mlir::acc::DevicePtrOp>(
4215 devicePtrClause->v, converter, semanticsContext, stmtCtx,
4216 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
4217 /*structured=*/true, /*implicit=*/false);
4218 } else if (const auto *linkClause =
4219 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
4220 genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
4221 mlir::acc::DeclareLinkOp>(
4222 linkClause->v, converter, semanticsContext, stmtCtx,
4223 dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
4224 /*structured=*/true, /*implicit=*/false);
4225 } else if (const auto *deviceResidentClause =
4226 std::get_if<Fortran::parser::AccClause::DeviceResident>(
4227 &clause.u)) {
4228 auto crtDataStart = dataClauseOperands.size();
4229 genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
4230 mlir::acc::DeleteOp>(
4231 deviceResidentClause->v, converter, semanticsContext, stmtCtx,
4232 dataClauseOperands,
4233 mlir::acc::DataClause::acc_declare_device_resident,
4234 /*structured=*/true, /*implicit=*/false);
4235 deviceResidentEntryOperands.append(
4236 dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
4237 } else {
4238 mlir::Location clauseLocation = converter.genLocation(clause.source);
4239 TODO(clauseLocation, "clause on declare directive");
4240 }
4241 }
4242
4243 mlir::func::FuncOp funcOp = builder.getFunction();
4244 auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
4245 mlir::Value declareToken;
4246 if (ops.empty()) {
4247 declareToken = builder.create<mlir::acc::DeclareEnterOp>(
4248 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
4249 dataClauseOperands);
4250 } else {
4251 auto declareOp = *ops.begin();
4252 auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
4253 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
4254 declareOp.getDataClauseOperands());
4255 newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
4256 declareToken = newDeclareOp.getToken();
4257 declareOp.erase();
4258 }
4259
4260 openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
4261 copyEntryOperands, copyinEntryOperands,
4262 copyoutEntryOperands, presentEntryOperands,
4263 deviceResidentEntryOperands, declareToken]() {
4264 llvm::SmallVector<mlir::Value> operands;
4265 operands.append(RHS: createEntryOperands);
4266 operands.append(RHS: deviceResidentEntryOperands);
4267 operands.append(RHS: copyEntryOperands);
4268 operands.append(RHS: copyinEntryOperands);
4269 operands.append(RHS: copyoutEntryOperands);
4270 operands.append(RHS: presentEntryOperands);
4271
4272 mlir::func::FuncOp funcOp = builder.getFunction();
4273 auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
4274 if (ops.empty()) {
4275 builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
4276 } else {
4277 auto declareOp = *ops.begin();
4278 declareOp.getDataClauseOperandsMutable().append(operands);
4279 }
4280
4281 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
4282 builder, createEntryOperands, /*structured=*/true);
4283 genDataExitOperations<mlir::acc::DeclareDeviceResidentOp,
4284 mlir::acc::DeleteOp>(
4285 builder, deviceResidentEntryOperands, /*structured=*/true);
4286 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
4287 builder, copyEntryOperands, /*structured=*/true);
4288 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
4289 builder, copyinEntryOperands, /*structured=*/true);
4290 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
4291 builder, copyoutEntryOperands, /*structured=*/true);
4292 genDataExitOperations<mlir::acc::PresentOp, mlir::acc::DeleteOp>(
4293 builder, presentEntryOperands, /*structured=*/true);
4294 });
4295}
4296
4297static void
4298genDeclareInModule(Fortran::lower::AbstractConverter &converter,
4299 mlir::ModuleOp moduleOp,
4300 const Fortran::parser::AccClauseList &accClauseList) {
4301 mlir::OpBuilder modBuilder(moduleOp.getBodyRegion());
4302 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
4303 if (const auto *createClause =
4304 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
4305 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4306 createClause->v;
4307 const auto &accObjectList =
4308 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4309 genGlobalCtors<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
4310 converter, modBuilder, accObjectList,
4311 mlir::acc::DataClause::acc_create);
4312 } else if (const auto *copyinClause =
4313 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
4314 genGlobalCtorsWithModifier<Fortran::parser::AccClause::Copyin,
4315 mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
4316 converter, modBuilder, copyinClause,
4317 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
4318 mlir::acc::DataClause::acc_copyin,
4319 mlir::acc::DataClause::acc_copyin_readonly);
4320 } else if (const auto *deviceResidentClause =
4321 std::get_if<Fortran::parser::AccClause::DeviceResident>(
4322 &clause.u)) {
4323 genGlobalCtors<mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeleteOp>(
4324 converter, modBuilder, deviceResidentClause->v,
4325 mlir::acc::DataClause::acc_declare_device_resident);
4326 } else if (const auto *linkClause =
4327 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
4328 genGlobalCtors<mlir::acc::DeclareLinkOp, mlir::acc::DeclareLinkOp>(
4329 converter, modBuilder, linkClause->v,
4330 mlir::acc::DataClause::acc_declare_link);
4331 } else {
4332 llvm::report_fatal_error("unsupported clause on DECLARE directive");
4333 }
4334 }
4335}
4336
4337static void genACC(Fortran::lower::AbstractConverter &converter,
4338 Fortran::semantics::SemanticsContext &semanticsContext,
4339 Fortran::lower::StatementContext &openAccCtx,
4340 const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
4341 &declareConstruct) {
4342
4343 const auto &declarativeDir =
4344 std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
4345 mlir::Location directiveLocation =
4346 converter.genLocation(declarativeDir.source);
4347 const auto &accClauseList =
4348 std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
4349
4350 if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) {
4351 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4352 auto moduleOp =
4353 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
4354 auto funcOp =
4355 builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
4356 if (funcOp)
4357 genDeclareInFunction(converter, semanticsContext, openAccCtx,
4358 directiveLocation, accClauseList);
4359 else if (moduleOp)
4360 genDeclareInModule(converter, moduleOp, accClauseList);
4361 return;
4362 }
4363 llvm_unreachable("unsupported declarative directive");
4364}
4365
4366static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
4367 mlir::acc::DeviceType deviceType) {
4368 for (auto attr : arrayAttr) {
4369 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4370 if (deviceTypeAttr.getValue() == deviceType)
4371 return true;
4372 }
4373 return false;
4374}
4375
4376template <typename RetTy, typename AttrTy>
4377static std::optional<RetTy>
4378getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
4379 llvm::SmallVector<mlir::Attribute> &deviceTypes,
4380 mlir::acc::DeviceType deviceType) {
4381 assert(attributes.size() == deviceTypes.size() &&
4382 "expect same number of attributes");
4383 for (auto it : llvm::enumerate(First&: deviceTypes)) {
4384 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
4385 if (deviceTypeAttr.getValue() == deviceType) {
4386 if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
4387 auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
4388 return strAttr.getValue();
4389 } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
4390 auto intAttr =
4391 mlir::dyn_cast<mlir::IntegerAttr>(Val&: attributes[it.index()]);
4392 return intAttr.getInt();
4393 }
4394 }
4395 }
4396 return std::nullopt;
4397}
4398
4399static bool compareDeviceTypeInfo(
4400 mlir::acc::RoutineOp op,
4401 llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4402 llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4403 llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
4404 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
4405 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
4406 llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
4407 llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
4408 llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
4409 for (uint32_t dtypeInt = 0;
4410 dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
4411 auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
4412 if (op.getBindNameValue(dtype) !=
4413 getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4414 bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4415 return false;
4416 if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
4417 return false;
4418 if (op.getGangDimValue(dtype) !=
4419 getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
4420 gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
4421 return false;
4422 if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
4423 return false;
4424 if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
4425 return false;
4426 if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
4427 return false;
4428 }
4429 return true;
4430}
4431
4432static void attachRoutineInfo(mlir::func::FuncOp func,
4433 mlir::SymbolRefAttr routineAttr) {
4434 llvm::SmallVector<mlir::SymbolRefAttr> routines;
4435 if (func.getOperation()->hasAttr(mlir::acc::getRoutineInfoAttrName())) {
4436 auto routineInfo =
4437 func.getOperation()->getAttrOfType<mlir::acc::RoutineInfoAttr>(
4438 mlir::acc::getRoutineInfoAttrName());
4439 routines.append(routineInfo.getAccRoutines().begin(),
4440 routineInfo.getAccRoutines().end());
4441 }
4442 routines.push_back(Elt: routineAttr);
4443 func.getOperation()->setAttr(
4444 mlir::acc::getRoutineInfoAttrName(),
4445 mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
4446}
4447
4448static mlir::ArrayAttr
4449getArrayAttrOrNull(fir::FirOpBuilder &builder,
4450 llvm::SmallVector<mlir::Attribute> &attributes) {
4451 if (attributes.empty()) {
4452 return nullptr;
4453 } else {
4454 return builder.getArrayAttr(attributes);
4455 }
4456}
4457
4458void createOpenACCRoutineConstruct(
4459 Fortran::lower::AbstractConverter &converter, mlir::Location loc,
4460 mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4461 bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4462 llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4463 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
4464 llvm::SmallVector<mlir::Attribute> &gangDimValues,
4465 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
4466 llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
4467 llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4468 llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes) {
4469
4470 for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
4471 if (routineOp.getFuncName().getLeafReference().str().compare(funcName) ==
4472 0) {
4473 // If the routine is already specified with the same clauses, just skip
4474 // the operation creation.
4475 if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
4476 gangDeviceTypes, gangDimValues,
4477 gangDimDeviceTypes, seqDeviceTypes,
4478 workerDeviceTypes, vectorDeviceTypes) &&
4479 routineOp.getNohost() == hasNohost)
4480 return;
4481 mlir::emitError(loc, "Routine already specified with different clauses");
4482 }
4483 }
4484 std::stringstream routineOpName;
4485 routineOpName << accRoutinePrefix.str() << routineCounter++;
4486 std::string routineOpStr = routineOpName.str();
4487 mlir::OpBuilder modBuilder(mod.getBodyRegion());
4488 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4489 modBuilder.create<mlir::acc::RoutineOp>(
4490 loc, routineOpStr,
4491 mlir::SymbolRefAttr::get(builder.getContext(), funcName),
4492 getArrayAttrOrNull(builder, bindNames),
4493 getArrayAttrOrNull(builder, bindNameDeviceTypes),
4494 getArrayAttrOrNull(builder, workerDeviceTypes),
4495 getArrayAttrOrNull(builder, vectorDeviceTypes),
4496 getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
4497 /*implicit=*/false, getArrayAttrOrNull(builder, gangDeviceTypes),
4498 getArrayAttrOrNull(builder, gangDimValues),
4499 getArrayAttrOrNull(builder, gangDimDeviceTypes));
4500
4501 attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpStr));
4502}
4503
4504static void interpretRoutineDeviceInfo(
4505 Fortran::lower::AbstractConverter &converter,
4506 const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo,
4507 llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
4508 llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
4509 llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4510 llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4511 llvm::SmallVector<mlir::Attribute> &bindNames,
4512 llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
4513 llvm::SmallVector<mlir::Attribute> &gangDimValues,
4514 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
4515 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4516 auto getDeviceTypeAttr = [&]() -> mlir::Attribute {
4517 auto context = builder.getContext();
4518 auto value = getDeviceType(dinfo.dType());
4519 return mlir::acc::DeviceTypeAttr::get(context, value);
4520 };
4521 if (dinfo.isSeq()) {
4522 seqDeviceTypes.push_back(Elt: getDeviceTypeAttr());
4523 }
4524 if (dinfo.isVector()) {
4525 vectorDeviceTypes.push_back(Elt: getDeviceTypeAttr());
4526 }
4527 if (dinfo.isWorker()) {
4528 workerDeviceTypes.push_back(Elt: getDeviceTypeAttr());
4529 }
4530 if (dinfo.isGang()) {
4531 unsigned gangDim = dinfo.gangDim();
4532 auto deviceType = getDeviceTypeAttr();
4533 if (!gangDim) {
4534 gangDeviceTypes.push_back(Elt: deviceType);
4535 } else {
4536 gangDimValues.push_back(
4537 Elt: builder.getIntegerAttr(builder.getI64Type(), gangDim));
4538 gangDimDeviceTypes.push_back(Elt: deviceType);
4539 }
4540 }
4541 if (dinfo.bindNameOpt().has_value()) {
4542 const auto &bindName = dinfo.bindNameOpt().value();
4543 mlir::Attribute bindNameAttr;
4544 if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4545 bindNameAttr = builder.getStringAttr(*bindStr);
4546 } else if (const auto &bindSym{
4547 std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4548 bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
4549 } else {
4550 llvm_unreachable("Unsupported bind name type");
4551 }
4552 bindNames.push_back(Elt: bindNameAttr);
4553 bindNameDeviceTypes.push_back(Elt: getDeviceTypeAttr());
4554 }
4555}
4556
4557void Fortran::lower::genOpenACCRoutineConstruct(
4558 Fortran::lower::AbstractConverter &converter, mlir::ModuleOp mod,
4559 mlir::func::FuncOp funcOp,
4560 const std::vector<Fortran::semantics::OpenACCRoutineInfo> &routineInfos) {
4561 CHECK(funcOp && "Expected a valid function operation");
4562 mlir::Location loc{funcOp.getLoc()};
4563 std::string funcName{funcOp.getName()};
4564
4565 // Collect the routine clauses
4566 bool hasNohost{false};
4567
4568 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4569 workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4570 gangDimDeviceTypes, gangDimValues;
4571
4572 for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
4573 // Device Independent Attributes
4574 if (info.isNohost()) {
4575 hasNohost = true;
4576 }
4577 // Note: Device Independent Attributes are set to the
4578 // none device type in `info`.
4579 interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
4580 vectorDeviceTypes, workerDeviceTypes,
4581 bindNameDeviceTypes, bindNames, gangDeviceTypes,
4582 gangDimValues, gangDimDeviceTypes);
4583
4584 // Device Dependent Attributes
4585 for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
4586 info.deviceTypeInfos()) {
4587 interpretRoutineDeviceInfo(
4588 converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4589 workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4590 gangDimValues, gangDimDeviceTypes);
4591 }
4592 }
4593 createOpenACCRoutineConstruct(
4594 converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4595 bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4596 seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4597}
4598
4599static void
4600genACC(Fortran::lower::AbstractConverter &converter,
4601 Fortran::lower::pft::Evaluation &eval,
4602 const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
4603
4604 mlir::Location loc = converter.genLocation(atomicConstruct.source);
4605 Fortran::common::visit(
4606 Fortran::common::visitors{
4607 [&](const Fortran::parser::AccAtomicRead &atomicRead) {
4608 genAtomicRead(converter, atomicRead, loc);
4609 },
4610 [&](const Fortran::parser::AccAtomicWrite &atomicWrite) {
4611 genAtomicWrite(converter, atomicWrite, loc);
4612 },
4613 [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) {
4614 genAtomicUpdate(converter, atomicUpdate, loc);
4615 },
4616 [&](const Fortran::parser::AccAtomicCapture &atomicCapture) {
4617 genAtomicCapture(converter, atomicCapture, loc);
4618 },
4619 },
4620 atomicConstruct.u);
4621}
4622
4623static void
4624genACC(Fortran::lower::AbstractConverter &converter,
4625 Fortran::semantics::SemanticsContext &semanticsContext,
4626 const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4627 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4628 auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>();
4629 auto crtPos = builder.saveInsertionPoint();
4630 if (loopOp) {
4631 builder.setInsertionPoint(loopOp);
4632 Fortran::lower::StatementContext stmtCtx;
4633 llvm::SmallVector<mlir::Value> cacheOperands;
4634 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4635 std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t);
4636 const auto &accObjectList =
4637 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4638 const auto &modifier =
4639 std::get<std::optional<Fortran::parser::AccDataModifier>>(
4640 listWithModifier.t);
4641
4642 mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache;
4643 if (modifier &&
4644 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly)
4645 dataClause = mlir::acc::DataClause::acc_cache_readonly;
4646 genDataOperandOperations<mlir::acc::CacheOp>(
4647 accObjectList, converter, semanticsContext, stmtCtx, cacheOperands,
4648 dataClause,
4649 /*structured=*/true, /*implicit=*/false,
4650 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{},
4651 /*setDeclareAttr*/ false);
4652 loopOp.getCacheOperandsMutable().append(cacheOperands);
4653 } else {
4654 llvm::report_fatal_error(
4655 reason: "could not find loop to attach OpenACC cache information.");
4656 }
4657 builder.restoreInsertionPoint(crtPos);
4658}
4659
4660mlir::Value Fortran::lower::genOpenACCConstruct(
4661 Fortran::lower::AbstractConverter &converter,
4662 Fortran::semantics::SemanticsContext &semanticsContext,
4663 Fortran::lower::pft::Evaluation &eval,
4664 const Fortran::parser::OpenACCConstruct &accConstruct) {
4665
4666 mlir::Value exitCond;
4667 Fortran::common::visit(
4668 common::visitors{
4669 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
4670 genACC(converter, semanticsContext, eval, blockConstruct);
4671 },
4672 [&](const Fortran::parser::OpenACCCombinedConstruct
4673 &combinedConstruct) {
4674 genACC(converter, semanticsContext, eval, combinedConstruct);
4675 },
4676 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
4677 exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
4678 },
4679 [&](const Fortran::parser::OpenACCStandaloneConstruct
4680 &standaloneConstruct) {
4681 genACC(converter, semanticsContext, standaloneConstruct);
4682 },
4683 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4684 genACC(converter, semanticsContext, cacheConstruct);
4685 },
4686 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
4687 genACC(converter, waitConstruct);
4688 },
4689 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
4690 genACC(converter, eval, atomicConstruct);
4691 },
4692 [&](const Fortran::parser::OpenACCEndConstruct &) {
4693 // No op
4694 },
4695 },
4696 accConstruct.u);
4697 return exitCond;
4698}
4699
4700void Fortran::lower::genOpenACCDeclarativeConstruct(
4701 Fortran::lower::AbstractConverter &converter,
4702 Fortran::semantics::SemanticsContext &semanticsContext,
4703 Fortran::lower::StatementContext &openAccCtx,
4704 const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct) {
4705
4706 Fortran::common::visit(
4707 common::visitors{
4708 [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
4709 &standaloneDeclarativeConstruct) {
4710 genACC(converter, semanticsContext, openAccCtx,
4711 standaloneDeclarativeConstruct);
4712 },
4713 [&](const Fortran::parser::OpenACCRoutineConstruct &x) {},
4714 },
4715 accDeclConstruct.u);
4716}
4717
4718void Fortran::lower::attachDeclarePostAllocAction(
4719 AbstractConverter &converter, fir::FirOpBuilder &builder,
4720 const Fortran::semantics::Symbol &sym) {
4721 std::stringstream fctName;
4722 fctName << converter.mangleName(sym) << declarePostAllocSuffix.str();
4723 mlir::Operation *op = &builder.getInsertionBlock()->back();
4724
4725 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) {
4726 assert(resOp.getOperands().size() == 0 &&
4727 "expect only fir.result op with no operand");
4728 op = op->getPrevNode();
4729 }
4730 assert(op && "expect operation to attach the post allocation action");
4731
4732 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4733 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4734 mlir::acc::getDeclareActionAttrName());
4735 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4736 mlir::acc::DeclareActionAttr::get(
4737 builder.getContext(), attr.getPreAlloc(),
4738 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4739 attr.getPreDealloc(), attr.getPostDealloc()));
4740 } else {
4741 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4742 mlir::acc::DeclareActionAttr::get(
4743 builder.getContext(),
4744 /*preAlloc=*/{},
4745 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4746 /*preDealloc=*/{}, /*postDealloc=*/{}));
4747 }
4748}
4749
4750void Fortran::lower::attachDeclarePreDeallocAction(
4751 AbstractConverter &converter, fir::FirOpBuilder &builder,
4752 mlir::Value beginOpValue, const Fortran::semantics::Symbol &sym) {
4753 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4754 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4755 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4756 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4757 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4758 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4759 return;
4760
4761 std::stringstream fctName;
4762 fctName << converter.mangleName(sym) << declarePreDeallocSuffix.str();
4763
4764 auto *op = beginOpValue.getDefiningOp();
4765 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4766 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4767 mlir::acc::getDeclareActionAttrName());
4768 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4769 mlir::acc::DeclareActionAttr::get(
4770 builder.getContext(), attr.getPreAlloc(),
4771 attr.getPostAlloc(),
4772 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4773 attr.getPostDealloc()));
4774 } else {
4775 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4776 mlir::acc::DeclareActionAttr::get(
4777 builder.getContext(),
4778 /*preAlloc=*/{}, /*postAlloc=*/{},
4779 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4780 /*postDealloc=*/{}));
4781 }
4782}
4783
4784void Fortran::lower::attachDeclarePostDeallocAction(
4785 AbstractConverter &converter, fir::FirOpBuilder &builder,
4786 const Fortran::semantics::Symbol &sym) {
4787 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4788 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4789 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4790 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4791 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4792 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4793 return;
4794
4795 std::stringstream fctName;
4796 fctName << converter.mangleName(sym) << declarePostDeallocSuffix.str();
4797 mlir::Operation *op = &builder.getInsertionBlock()->back();
4798 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) {
4799 assert(resOp.getOperands().size() == 0 &&
4800 "expect only fir.result op with no operand");
4801 op = op->getPrevNode();
4802 }
4803 assert(op && "expect operation to attach the post deallocation action");
4804 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4805 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4806 mlir::acc::getDeclareActionAttrName());
4807 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4808 mlir::acc::DeclareActionAttr::get(
4809 builder.getContext(), attr.getPreAlloc(),
4810 attr.getPostAlloc(), attr.getPreDealloc(),
4811 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4812 } else {
4813 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4814 mlir::acc::DeclareActionAttr::get(
4815 builder.getContext(),
4816 /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
4817 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4818 }
4819}
4820
4821void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
4822 mlir::Operation *op,
4823 mlir::Location loc) {
4824 if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
4825 builder.create<mlir::acc::YieldOp>(loc);
4826 else
4827 builder.create<mlir::acc::TerminatorOp>(loc);
4828}
4829
4830bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
4831 if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4832 return true;
4833 return false;
4834}
4835
4836void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
4837 fir::FirOpBuilder &builder) {
4838 if (auto loopOp =
4839 builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4840 builder.setInsertionPointAfter(loopOp);
4841}
4842
4843void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
4844 mlir::Location loc) {
4845 mlir::Value yieldValue =
4846 builder.createIntegerConstant(loc, builder.getI1Type(), 1);
4847 builder.create<mlir::acc::YieldOp>(loc, yieldValue);
4848}
4849
4850int64_t Fortran::lower::getLoopCountForCollapseAndTile(
4851 const Fortran::parser::AccClauseList &clauseList) {
4852 int64_t collapseLoopCount = 1;
4853 int64_t tileLoopCount = 1;
4854 for (const Fortran::parser::AccClause &clause : clauseList.v) {
4855 if (const auto *collapseClause =
4856 std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
4857 const parser::AccCollapseArg &arg = collapseClause->v;
4858 const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
4859 collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
4860 }
4861 if (const auto *tileClause =
4862 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
4863 const parser::AccTileExprList &tileExprList = tileClause->v;
4864 const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
4865 tileLoopCount = listTileExpr.size();
4866 }
4867 }
4868 if (tileLoopCount > collapseLoopCount)
4869 return tileLoopCount;
4870 return collapseLoopCount;
4871}
4872

source code of flang/lib/Lower/OpenACC.cpp