1//===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/OpenACC.h"
14#include "DirectivesCommon.h"
15#include "flang/Common/idioms.h"
16#include "flang/Lower/Bridge.h"
17#include "flang/Lower/ConvertType.h"
18#include "flang/Lower/Mangler.h"
19#include "flang/Lower/PFTBuilder.h"
20#include "flang/Lower/StatementContext.h"
21#include "flang/Lower/Support/Utils.h"
22#include "flang/Optimizer/Builder/BoxValue.h"
23#include "flang/Optimizer/Builder/Complex.h"
24#include "flang/Optimizer/Builder/FIRBuilder.h"
25#include "flang/Optimizer/Builder/HLFIRTools.h"
26#include "flang/Optimizer/Builder/IntrinsicCall.h"
27#include "flang/Optimizer/Builder/Todo.h"
28#include "flang/Parser/parse-tree-visitor.h"
29#include "flang/Parser/parse-tree.h"
30#include "flang/Semantics/expression.h"
31#include "flang/Semantics/scope.h"
32#include "flang/Semantics/tools.h"
33#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
34#include "llvm/Frontend/OpenACC/ACC.h.inc"
35
// Sentinel value standing for `*` when it is written in a device_type or gang
// clause.
static constexpr std::int64_t starCst{-1};
38
39static unsigned routineCounter = 0;
40static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
41static constexpr llvm::StringRef accPrivateInitName = "acc.private.init";
42static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init";
43static constexpr llvm::StringRef accFirDescriptorPostfix = "_desc";
44
45static mlir::Location
46genOperandLocation(Fortran::lower::AbstractConverter &converter,
47 const Fortran::parser::AccObject &accObject) {
48 mlir::Location loc = converter.genUnknownLocation();
49 std::visit(Fortran::common::visitors{
50 [&](const Fortran::parser::Designator &designator) {
51 loc = converter.genLocation(designator.source);
52 },
53 [&](const Fortran::parser::Name &name) {
54 loc = converter.genLocation(name.source);
55 }},
56 accObject.u);
57 return loc;
58}
59
60template <typename Op>
61static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
62 mlir::Value baseAddr, std::stringstream &name,
63 mlir::SmallVector<mlir::Value> bounds,
64 bool structured, bool implicit,
65 mlir::acc::DataClause dataClause, mlir::Type retTy,
66 mlir::Value isPresent = {}) {
67 mlir::Value varPtrPtr;
68 if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
69 if (isPresent) {
70 mlir::Type ifRetTy = boxTy.getEleTy();
71 if (!fir::isa_ref_type(ifRetTy))
72 ifRetTy = fir::ReferenceType::get(ifRetTy);
73 baseAddr =
74 builder
75 .genIfOp(loc, {ifRetTy}, isPresent,
76 /*withElseRegion=*/true)
77 .genThen([&]() {
78 mlir::Value boxAddr =
79 builder.create<fir::BoxAddrOp>(loc, baseAddr);
80 builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
81 })
82 .genElse([&] {
83 mlir::Value absent =
84 builder.create<fir::AbsentOp>(loc, ifRetTy);
85 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
86 })
87 .getResults()[0];
88 } else {
89 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
90 }
91 retTy = baseAddr.getType();
92 }
93
94 Op op = builder.create<Op>(loc, retTy, baseAddr);
95 op.setNameAttr(builder.getStringAttr(name.str()));
96 op.setStructured(structured);
97 op.setImplicit(implicit);
98 op.setDataClause(dataClause);
99
100 unsigned insPos = 1;
101 if (varPtrPtr)
102 op->insertOperands(insPos++, varPtrPtr);
103 if (bounds.size() > 0)
104 op->insertOperands(insPos, bounds);
105 op->setAttr(Op::getOperandSegmentSizeAttr(),
106 builder.getDenseI32ArrayAttr(
107 {1, varPtrPtr ? 1 : 0, static_cast<int32_t>(bounds.size())}));
108 return op;
109}
110
111static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
112 mlir::acc::DataClause clause) {
113 if (!op)
114 return;
115 op->setAttr(mlir::acc::getDeclareAttrName(),
116 mlir::acc::DeclareAttr::get(builder.getContext(),
117 mlir::acc::DataClauseAttr::get(
118 builder.getContext(), clause)));
119}
120
121static mlir::func::FuncOp
122createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
123 mlir::Location loc, llvm::StringRef funcName,
124 llvm::SmallVector<mlir::Type> argsTy = {},
125 llvm::SmallVector<mlir::Location> locs = {}) {
126 auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
127 auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
128 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
129 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
130 locs);
131 builder.setInsertionPointToEnd(&funcOp.getRegion().back());
132 builder.create<mlir::func::ReturnOp>(loc);
133 builder.setInsertionPointToStart(&funcOp.getRegion().back());
134 return funcOp;
135}
136
137template <typename Op>
138static Op
139createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
140 const llvm::SmallVectorImpl<mlir::Value> &operands,
141 const llvm::SmallVectorImpl<int32_t> &operandSegments) {
142 llvm::ArrayRef<mlir::Type> argTy;
143 Op op = builder.create<Op>(loc, argTy, operands);
144 op->setAttr(Op::getOperandSegmentSizeAttr(),
145 builder.getDenseI32ArrayAttr(operandSegments));
146 return op;
147}
148
149template <typename EntryOp>
150static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
151 fir::FirOpBuilder &builder,
152 mlir::Location loc, mlir::Type descTy,
153 llvm::StringRef funcNamePrefix,
154 std::stringstream &asFortran,
155 mlir::acc::DataClause clause) {
156 auto crtInsPt = builder.saveInsertionPoint();
157 std::stringstream registerFuncName;
158 registerFuncName << funcNamePrefix.str()
159 << Fortran::lower::declarePostAllocSuffix.str();
160
161 if (!mlir::isa<fir::ReferenceType>(descTy))
162 descTy = fir::ReferenceType::get(descTy);
163 auto registerFuncOp = createDeclareFunc(
164 modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
165
166 llvm::SmallVector<mlir::Value> bounds;
167 std::stringstream asFortranDesc;
168 asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
169
170 // Updating descriptor must occur before the mapping of the data so that
171 // attached data pointer is not overwritten.
172 mlir::acc::UpdateDeviceOp updateDeviceOp =
173 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
174 builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
175 /*structured=*/false, /*implicit=*/true,
176 mlir::acc::DataClause::acc_update_device, descTy);
177 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
178 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
179 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
180
181 mlir::Value desc =
182 builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
183 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
184 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
185 EntryOp entryOp = createDataEntryOp<EntryOp>(
186 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
187 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
188 builder.create<mlir::acc::DeclareEnterOp>(
189 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
190 mlir::ValueRange(entryOp.getAccPtr()));
191
192 modBuilder.setInsertionPointAfter(registerFuncOp);
193 builder.restoreInsertionPoint(crtInsPt);
194}
195
196template <typename ExitOp>
197static void createDeclareDeallocFuncWithArg(
198 mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
199 mlir::Type descTy, llvm::StringRef funcNamePrefix,
200 std::stringstream &asFortran, mlir::acc::DataClause clause) {
201 auto crtInsPt = builder.saveInsertionPoint();
202 // Generate the pre dealloc function.
203 std::stringstream preDeallocFuncName;
204 preDeallocFuncName << funcNamePrefix.str()
205 << Fortran::lower::declarePreDeallocSuffix.str();
206 if (!mlir::isa<fir::ReferenceType>(descTy))
207 descTy = fir::ReferenceType::get(descTy);
208 auto preDeallocOp = createDeclareFunc(
209 modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
210 mlir::Value loadOp =
211 builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
212 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
213 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
214
215 llvm::SmallVector<mlir::Value> bounds;
216 mlir::acc::GetDevicePtrOp entryOp =
217 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
218 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
219 /*structured=*/false, /*implicit=*/false, clause,
220 boxAddrOp.getType());
221 builder.create<mlir::acc::DeclareExitOp>(
222 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
223
224 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
225 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
226 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
227 entryOp.getVarPtr(), entryOp.getBounds(),
228 entryOp.getDataClause(),
229 /*structured=*/false, /*implicit=*/false,
230 builder.getStringAttr(*entryOp.getName()));
231 else
232 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
233 entryOp.getBounds(), entryOp.getDataClause(),
234 /*structured=*/false, /*implicit=*/false,
235 builder.getStringAttr(*entryOp.getName()));
236
237 // Generate the post dealloc function.
238 modBuilder.setInsertionPointAfter(preDeallocOp);
239 std::stringstream postDeallocFuncName;
240 postDeallocFuncName << funcNamePrefix.str()
241 << Fortran::lower::declarePostDeallocSuffix.str();
242 auto postDeallocOp = createDeclareFunc(
243 modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
244 loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
245 asFortran << accFirDescriptorPostfix.str();
246 mlir::acc::UpdateDeviceOp updateDeviceOp =
247 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
248 builder, loc, loadOp, asFortran, bounds,
249 /*structured=*/false, /*implicit=*/true,
250 mlir::acc::DataClause::acc_update_device, loadOp.getType());
251 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
252 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
253 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
254 modBuilder.setInsertionPointAfter(postDeallocOp);
255 builder.restoreInsertionPoint(crtInsPt);
256}
257
258Fortran::semantics::Symbol &
259getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
260 if (const auto *designator =
261 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
262 if (const auto *name =
263 Fortran::semantics::getDesignatorNameIfDataRef(*designator))
264 return *name->symbol;
265 if (const auto *arrayElement =
266 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
267 *designator)) {
268 const Fortran::parser::Name &name =
269 Fortran::parser::GetLastName(arrayElement->base);
270 return *name.symbol;
271 }
272 if (const auto *component =
273 Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
274 *designator)) {
275 return *component->component.symbol;
276 }
277 } else if (const auto *name =
278 std::get_if<Fortran::parser::Name>(&accObject.u)) {
279 return *name->symbol;
280 }
281 llvm::report_fatal_error(reason: "Could not find symbol");
282}
283
284template <typename Op>
285static void
286genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
287 Fortran::lower::AbstractConverter &converter,
288 Fortran::semantics::SemanticsContext &semanticsContext,
289 Fortran::lower::StatementContext &stmtCtx,
290 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
291 mlir::acc::DataClause dataClause, bool structured,
292 bool implicit, bool setDeclareAttr = false) {
293 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
294 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
295 for (const auto &accObject : objectList.v) {
296 llvm::SmallVector<mlir::Value> bounds;
297 std::stringstream asFortran;
298 mlir::Location operandLocation = genOperandLocation(converter, accObject);
299 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
300 Fortran::semantics::MaybeExpr designator =
301 std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
302 Fortran::lower::AddrAndBoundsInfo info =
303 Fortran::lower::gatherDataOperandAddrAndBounds<
304 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
305 converter, builder, semanticsContext, stmtCtx, symbol, designator,
306 operandLocation, asFortran, bounds,
307 /*treatIndexAsSection=*/true);
308
309 // If the input value is optional and is not a descriptor, we use the
310 // rawInput directly.
311 mlir::Value baseAddr =
312 ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
313 info.isPresent)
314 ? info.rawInput
315 : info.addr;
316 Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
317 bounds, structured, implicit, dataClause,
318 baseAddr.getType(), info.isPresent);
319 dataOperands.push_back(op.getAccPtr());
320 }
321}
322
323template <typename EntryOp, typename ExitOp>
324static void genDeclareDataOperandOperations(
325 const Fortran::parser::AccObjectList &objectList,
326 Fortran::lower::AbstractConverter &converter,
327 Fortran::semantics::SemanticsContext &semanticsContext,
328 Fortran::lower::StatementContext &stmtCtx,
329 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
330 mlir::acc::DataClause dataClause, bool structured, bool implicit) {
331 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
332 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
333 for (const auto &accObject : objectList.v) {
334 llvm::SmallVector<mlir::Value> bounds;
335 std::stringstream asFortran;
336 mlir::Location operandLocation = genOperandLocation(converter, accObject);
337 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
338 Fortran::semantics::MaybeExpr designator =
339 std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
340 Fortran::lower::AddrAndBoundsInfo info =
341 Fortran::lower::gatherDataOperandAddrAndBounds<
342 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
343 converter, builder, semanticsContext, stmtCtx, symbol, designator,
344 operandLocation, asFortran, bounds);
345 EntryOp op = createDataEntryOp<EntryOp>(
346 builder, operandLocation, info.addr, asFortran, bounds, structured,
347 implicit, dataClause, info.addr.getType());
348 dataOperands.push_back(op.getAccPtr());
349 addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
350 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
351 mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
352 modBuilder.setInsertionPointAfter(builder.getFunction());
353 std::string prefix = converter.mangleName(symbol);
354 createDeclareAllocFuncWithArg<EntryOp>(
355 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
356 asFortran, dataClause);
357 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
358 createDeclareDeallocFuncWithArg<ExitOp>(
359 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
360 asFortran, dataClause);
361 }
362 }
363}
364
365template <typename EntryOp, typename ExitOp, typename Clause>
366static void genDeclareDataOperandOperationsWithModifier(
367 const Clause *x, Fortran::lower::AbstractConverter &converter,
368 Fortran::semantics::SemanticsContext &semanticsContext,
369 Fortran::lower::StatementContext &stmtCtx,
370 Fortran::parser::AccDataModifier::Modifier mod,
371 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
372 const mlir::acc::DataClause clause,
373 const mlir::acc::DataClause clauseWithModifier) {
374 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
375 const auto &accObjectList =
376 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
377 const auto &modifier =
378 std::get<std::optional<Fortran::parser::AccDataModifier>>(
379 listWithModifier.t);
380 mlir::acc::DataClause dataClause =
381 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
382 genDeclareDataOperandOperations<EntryOp, ExitOp>(
383 accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
384 dataClause,
385 /*structured=*/true, /*implicit=*/false);
386}
387
388template <typename EntryOp, typename ExitOp>
389static void genDataExitOperations(fir::FirOpBuilder &builder,
390 llvm::SmallVector<mlir::Value> operands,
391 bool structured) {
392 for (mlir::Value operand : operands) {
393 auto entryOp = mlir::dyn_cast_or_null<EntryOp>(operand.getDefiningOp());
394 assert(entryOp && "data entry op expected");
395 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
396 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
397 builder.create<ExitOp>(
398 entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
399 entryOp.getBounds(), entryOp.getDataClause(), structured,
400 entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
401 else
402 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
403 entryOp.getBounds(), entryOp.getDataClause(),
404 structured, entryOp.getImplicit(),
405 builder.getStringAttr(*entryOp.getName()));
406 }
407}
408
409fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, fir::SequenceType seqTy,
410 mlir::Location loc) {
411 llvm::SmallVector<mlir::Value> extents;
412 mlir::Type idxTy = builder.getIndexType();
413 for (auto extent : seqTy.getShape())
414 extents.push_back(builder.create<mlir::arith::ConstantOp>(
415 loc, idxTy, builder.getIntegerAttr(idxTy, extent)));
416 return builder.create<fir::ShapeOp>(loc, extents);
417}
418
419template <typename RecipeOp>
420static void genPrivateLikeInitRegion(mlir::OpBuilder &builder, RecipeOp recipe,
421 mlir::Type ty, mlir::Location loc) {
422 mlir::Value retVal = recipe.getInitRegion().front().getArgument(0);
423 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
424 if (fir::isa_trivial(refTy.getEleTy())) {
425 auto alloca = builder.create<fir::AllocaOp>(loc, refTy.getEleTy());
426 auto declareOp = builder.create<hlfir::DeclareOp>(
427 loc, alloca, accPrivateInitName, /*shape=*/nullptr,
428 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
429 retVal = declareOp.getBase();
430 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
431 refTy.getEleTy())) {
432 if (fir::isa_trivial(seqTy.getEleTy())) {
433 mlir::Value shape;
434 llvm::SmallVector<mlir::Value> extents;
435 if (seqTy.hasDynamicExtents()) {
436 // Extents are passed as block arguments. First argument is the
437 // original value.
438 for (unsigned i = 1; i < recipe.getInitRegion().getArguments().size();
439 ++i)
440 extents.push_back(Elt: recipe.getInitRegion().getArgument(i));
441 shape = builder.create<fir::ShapeOp>(loc, extents);
442 } else {
443 shape = genShapeOp(builder, seqTy, loc);
444 }
445 auto alloca = builder.create<fir::AllocaOp>(
446 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
447 auto declareOp = builder.create<hlfir::DeclareOp>(
448 loc, alloca, accPrivateInitName, shape,
449 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
450 retVal = declareOp.getBase();
451 }
452 }
453 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
454 mlir::Type innerTy = fir::extractSequenceType(boxTy);
455 if (!innerTy)
456 TODO(loc, "Unsupported boxed type in OpenACC privatization");
457 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
458 hlfir::Entity source = hlfir::Entity{retVal};
459 auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source);
460 retVal = temp;
461 }
462 builder.create<mlir::acc::YieldOp>(loc, retVal);
463}
464
465mlir::acc::PrivateRecipeOp
466Fortran::lower::createOrGetPrivateRecipe(mlir::OpBuilder &builder,
467 llvm::StringRef recipeName,
468 mlir::Location loc, mlir::Type ty) {
469 mlir::ModuleOp mod =
470 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
471 if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName))
472 return recipe;
473
474 auto crtPos = builder.saveInsertionPoint();
475 mlir::OpBuilder modBuilder(mod.getBodyRegion());
476 auto recipe =
477 modBuilder.create<mlir::acc::PrivateRecipeOp>(loc, recipeName, ty);
478 llvm::SmallVector<mlir::Type> argsTy{ty};
479 llvm::SmallVector<mlir::Location> argsLoc{loc};
480 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
481 if (auto seqTy =
482 mlir::dyn_cast_or_null<fir::SequenceType>(refTy.getEleTy())) {
483 if (seqTy.hasDynamicExtents()) {
484 mlir::Type idxTy = builder.getIndexType();
485 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
486 argsTy.push_back(Elt: idxTy);
487 argsLoc.push_back(Elt: loc);
488 }
489 }
490 }
491 }
492 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
493 argsTy, argsLoc);
494 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
495 genPrivateLikeInitRegion<mlir::acc::PrivateRecipeOp>(builder, recipe, ty,
496 loc);
497 builder.restoreInsertionPoint(ip: crtPos);
498 return recipe;
499}
500
501/// Check if the DataBoundsOp is a constant bound (lb and ub are constants or
502/// extent is a constant).
503bool isConstantBound(mlir::acc::DataBoundsOp &op) {
504 if (op.getLowerbound() && fir::getIntIfConstant(op.getLowerbound()) &&
505 op.getUpperbound() && fir::getIntIfConstant(op.getUpperbound()))
506 return true;
507 if (op.getExtent() && fir::getIntIfConstant(op.getExtent()))
508 return true;
509 return false;
510}
511
512/// Return true iff all the bounds are expressed with constant values.
513bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
514 for (auto bound : bounds) {
515 auto dataBound =
516 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
517 assert(dataBound && "Must be DataBoundOp operation");
518 if (!isConstantBound(dataBound))
519 return false;
520 }
521 return true;
522}
523
524static llvm::SmallVector<mlir::Value>
525genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
526 mlir::acc::DataBoundsOp &dataBound) {
527 mlir::Type idxTy = builder.getIndexType();
528 mlir::Value lb, ub, step;
529 if (dataBound.getLowerbound() &&
530 fir::getIntIfConstant(dataBound.getLowerbound()) &&
531 dataBound.getUpperbound() &&
532 fir::getIntIfConstant(dataBound.getUpperbound())) {
533 lb = builder.createIntegerConstant(
534 loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
535 ub = builder.createIntegerConstant(
536 loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
537 step = builder.createIntegerConstant(loc, idxTy, 1);
538 } else if (dataBound.getExtent()) {
539 lb = builder.createIntegerConstant(loc, idxTy, 0);
540 ub = builder.createIntegerConstant(
541 loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
542 step = builder.createIntegerConstant(loc, idxTy, 1);
543 } else {
544 llvm::report_fatal_error(reason: "Expect constant lb/ub or extent");
545 }
546 return {lb, ub, step};
547}
548
549static fir::ShapeOp genShapeFromBoundsOrArgs(
550 mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
551 const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
552 llvm::SmallVector<mlir::Value> args;
553 if (areAllBoundConstant(bounds)) {
554 for (auto bound : llvm::reverse(C: bounds)) {
555 auto dataBound =
556 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
557 args.append(genConstantBounds(builder, loc, dataBound));
558 }
559 } else {
560 assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
561 "Expect 3 block arguments per dimension");
562 for (auto arg : arguments.drop_front(n: 2))
563 args.push_back(Elt: arg);
564 }
565
566 assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
567 llvm::SmallVector<mlir::Value> extents;
568 mlir::Type idxTy = builder.getIndexType();
569 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
570 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
571 for (unsigned i = 0; i < args.size(); i += 3) {
572 mlir::Value s1 =
573 builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
574 mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
575 mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
576 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
577 loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
578 mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
579 extents.push_back(Elt: ext);
580 }
581 return builder.create<fir::ShapeOp>(loc, extents);
582}
583
584static hlfir::DesignateOp::Subscripts
585getSubscriptsFromArgs(mlir::ValueRange args) {
586 hlfir::DesignateOp::Subscripts triplets;
587 for (unsigned i = 2; i < args.size(); i += 3)
588 triplets.emplace_back(
589 hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
590 return triplets;
591}
592
593static hlfir::Entity genDesignateWithTriplets(
594 fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
595 hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
596 llvm::SmallVector<mlir::Value> lenParams;
597 hlfir::genLengthParameters(loc, builder, entity, lenParams);
598 auto designate = builder.create<hlfir::DesignateOp>(
599 loc, entity.getBase().getType(), entity, /*component=*/"",
600 /*componentShape=*/mlir::Value{}, triplets,
601 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
602 lenParams);
603 return hlfir::Entity{designate.getResult()};
604}
605
606mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
607 mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
608 mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
609 mlir::ModuleOp mod =
610 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
611 if (auto recipe =
612 mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName))
613 return recipe;
614
615 auto crtPos = builder.saveInsertionPoint();
616 mlir::OpBuilder modBuilder(mod.getBodyRegion());
617 auto recipe =
618 modBuilder.create<mlir::acc::FirstprivateRecipeOp>(loc, recipeName, ty);
619 llvm::SmallVector<mlir::Type> initArgsTy{ty};
620 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
621 auto refTy = fir::unwrapRefType(ty);
622 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
623 if (seqTy.hasDynamicExtents()) {
624 mlir::Type idxTy = builder.getIndexType();
625 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
626 initArgsTy.push_back(Elt: idxTy);
627 initArgsLoc.push_back(Elt: loc);
628 }
629 }
630 }
631 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
632 initArgsTy, initArgsLoc);
633 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
634 genPrivateLikeInitRegion<mlir::acc::FirstprivateRecipeOp>(builder, recipe, ty,
635 loc);
636
637 bool allConstantBound = areAllBoundConstant(bounds);
638 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
639 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
640 if (!allConstantBound) {
641 for (mlir::Value bound : llvm::reverse(C&: bounds)) {
642 auto dataBound =
643 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
644 argsTy.push_back(Elt: dataBound.getLowerbound().getType());
645 argsLoc.push_back(Elt: dataBound.getLowerbound().getLoc());
646 argsTy.push_back(Elt: dataBound.getUpperbound().getType());
647 argsLoc.push_back(Elt: dataBound.getUpperbound().getLoc());
648 argsTy.push_back(Elt: dataBound.getStartIdx().getType());
649 argsLoc.push_back(Elt: dataBound.getStartIdx().getLoc());
650 }
651 }
652 builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(),
653 argsTy, argsLoc);
654
655 builder.setInsertionPointToEnd(&recipe.getCopyRegion().back());
656 ty = fir::unwrapRefType(ty);
657 if (fir::isa_trivial(ty)) {
658 mlir::Value initValue = builder.create<fir::LoadOp>(
659 loc, recipe.getCopyRegion().front().getArgument(0));
660 builder.create<fir::StoreOp>(loc, initValue,
661 recipe.getCopyRegion().front().getArgument(1));
662 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
663 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
664 auto shape = genShapeFromBoundsOrArgs(
665 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
666
667 auto leftDeclOp = builder.create<hlfir::DeclareOp>(
668 loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, shape,
669 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
670 auto rightDeclOp = builder.create<hlfir::DeclareOp>(
671 loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, shape,
672 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
673
674 hlfir::DesignateOp::Subscripts triplets =
675 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
676 auto leftEntity = hlfir::Entity{leftDeclOp.getBase()};
677 auto left =
678 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
679 auto rightEntity = hlfir::Entity{rightDeclOp.getBase()};
680 auto right =
681 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
682
683 firBuilder.create<hlfir::AssignOp>(loc, left, right);
684
685 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
686 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
687 llvm::SmallVector<mlir::Value> tripletArgs;
688 mlir::Type innerTy = fir::extractSequenceType(boxTy);
689 fir::SequenceType seqTy =
690 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
691 if (!seqTy)
692 TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
693
694 auto shape = genShapeFromBoundsOrArgs(
695 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
696 hlfir::DesignateOp::Subscripts triplets =
697 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
698 auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
699 auto left =
700 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
701 auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
702 auto right =
703 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
704 firBuilder.create<hlfir::AssignOp>(loc, left, right);
705 }
706
707 builder.create<mlir::acc::TerminatorOp>(loc);
708 builder.restoreInsertionPoint(ip: crtPos);
709 return recipe;
710}
711
712/// Get a string representation of the bounds.
713std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) {
714 std::stringstream boundStr;
715 if (!bounds.empty())
716 boundStr << "_section_";
717 llvm::interleave(
718 c: bounds,
719 each_fn: [&](mlir::Value bound) {
720 auto boundsOp =
721 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
722 if (boundsOp.getLowerbound() &&
723 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
724 boundsOp.getUpperbound() &&
725 fir::getIntIfConstant(boundsOp.getUpperbound())) {
726 boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound())
727 << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound());
728 } else if (boundsOp.getExtent() &&
729 fir::getIntIfConstant(boundsOp.getExtent())) {
730 boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent());
731 } else {
732 boundStr << "?";
733 }
734 },
735 between_fn: [&] { boundStr << "x"; });
736 return boundStr.str();
737}
738
739/// Rebuild the array type from the acc.bounds operation with constant
740/// lowerbound/upperbound or extent.
741mlir::Type getTypeFromBounds(llvm::SmallVector<mlir::Value> &bounds,
742 mlir::Type ty) {
743 auto seqTy =
744 mlir::dyn_cast_or_null<fir::SequenceType>(fir::unwrapRefType(ty));
745 if (!bounds.empty() && seqTy) {
746 llvm::SmallVector<int64_t> shape;
747 for (auto b : bounds) {
748 auto boundsOp =
749 mlir::dyn_cast<mlir::acc::DataBoundsOp>(b.getDefiningOp());
750 if (boundsOp.getLowerbound() &&
751 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
752 boundsOp.getUpperbound() &&
753 fir::getIntIfConstant(boundsOp.getUpperbound())) {
754 int64_t ext = *fir::getIntIfConstant(boundsOp.getUpperbound()) -
755 *fir::getIntIfConstant(boundsOp.getLowerbound()) + 1;
756 shape.push_back(Elt: ext);
757 } else if (boundsOp.getExtent() &&
758 fir::getIntIfConstant(boundsOp.getExtent())) {
759 shape.push_back(*fir::getIntIfConstant(boundsOp.getExtent()));
760 } else {
761 return ty; // TODO: handle dynamic shaped array slice.
762 }
763 }
764 if (shape.empty() || shape.size() != bounds.size())
765 return ty;
766 auto newSeqTy = fir::SequenceType::get(shape, seqTy.getEleTy());
767 if (mlir::isa<fir::ReferenceType, fir::PointerType>(ty))
768 return fir::ReferenceType::get(newSeqTy);
769 return newSeqTy;
770 }
771 return ty;
772}
773
774template <typename RecipeOp>
775static void
776genPrivatizations(const Fortran::parser::AccObjectList &objectList,
777 Fortran::lower::AbstractConverter &converter,
778 Fortran::semantics::SemanticsContext &semanticsContext,
779 Fortran::lower::StatementContext &stmtCtx,
780 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
781 llvm::SmallVector<mlir::Attribute> &privatizations) {
782 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
783 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
784 for (const auto &accObject : objectList.v) {
785 llvm::SmallVector<mlir::Value> bounds;
786 std::stringstream asFortran;
787 mlir::Location operandLocation = genOperandLocation(converter, accObject);
788 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
789 Fortran::semantics::MaybeExpr designator =
790 std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
791 Fortran::lower::AddrAndBoundsInfo info =
792 Fortran::lower::gatherDataOperandAddrAndBounds<
793 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
794 converter, builder, semanticsContext, stmtCtx, symbol, designator,
795 operandLocation, asFortran, bounds);
796 RecipeOp recipe;
797 mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
798 if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
799 std::string recipeName =
800 fir::getTypeAsString(retTy, converter.getKindMap(),
801 Fortran::lower::privatizationRecipePrefix);
802 recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
803 operandLocation, retTy);
804 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
805 builder, operandLocation, info.addr, asFortran, bounds, true,
806 /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy);
807 dataOperands.push_back(op.getAccPtr());
808 } else {
809 std::string suffix =
810 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
811 std::string recipeName = fir::getTypeAsString(
812 retTy, converter.getKindMap(), "firstprivatization" + suffix);
813 recipe = Fortran::lower::createOrGetFirstprivateRecipe(
814 builder, recipeName, operandLocation, retTy, bounds);
815 auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
816 builder, operandLocation, info.addr, asFortran, bounds, true,
817 /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy);
818 dataOperands.push_back(op.getAccPtr());
819 }
820 privatizations.push_back(mlir::SymbolRefAttr::get(
821 builder.getContext(), recipe.getSymName().str()));
822 }
823}
824
825/// Return the corresponding enum value for the mlir::acc::ReductionOperator
826/// from the parser representation.
827static mlir::acc::ReductionOperator
828getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
829 switch (op.v) {
830 case Fortran::parser::AccReductionOperator::Operator::Plus:
831 return mlir::acc::ReductionOperator::AccAdd;
832 case Fortran::parser::AccReductionOperator::Operator::Multiply:
833 return mlir::acc::ReductionOperator::AccMul;
834 case Fortran::parser::AccReductionOperator::Operator::Max:
835 return mlir::acc::ReductionOperator::AccMax;
836 case Fortran::parser::AccReductionOperator::Operator::Min:
837 return mlir::acc::ReductionOperator::AccMin;
838 case Fortran::parser::AccReductionOperator::Operator::Iand:
839 return mlir::acc::ReductionOperator::AccIand;
840 case Fortran::parser::AccReductionOperator::Operator::Ior:
841 return mlir::acc::ReductionOperator::AccIor;
842 case Fortran::parser::AccReductionOperator::Operator::Ieor:
843 return mlir::acc::ReductionOperator::AccXor;
844 case Fortran::parser::AccReductionOperator::Operator::And:
845 return mlir::acc::ReductionOperator::AccLand;
846 case Fortran::parser::AccReductionOperator::Operator::Or:
847 return mlir::acc::ReductionOperator::AccLor;
848 case Fortran::parser::AccReductionOperator::Operator::Eqv:
849 return mlir::acc::ReductionOperator::AccEqv;
850 case Fortran::parser::AccReductionOperator::Operator::Neqv:
851 return mlir::acc::ReductionOperator::AccNeqv;
852 }
853 llvm_unreachable("unexpected reduction operator");
854}
855
856/// Get the initial value for reduction operator.
857template <typename R>
858static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) {
859 if (op == mlir::acc::ReductionOperator::AccMin) {
860 // min init value -> largest
861 if constexpr (std::is_same_v<R, llvm::APInt>) {
862 assert(ty.isIntOrIndex() && "expect integer or index type");
863 return llvm::APInt::getSignedMaxValue(numBits: ty.getIntOrFloatBitWidth());
864 }
865 if constexpr (std::is_same_v<R, llvm::APFloat>) {
866 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty);
867 assert(floatTy && "expect float type");
868 return llvm::APFloat::getLargest(Sem: floatTy.getFloatSemantics(),
869 /*negative=*/Negative: false);
870 }
871 } else if (op == mlir::acc::ReductionOperator::AccMax) {
872 // max init value -> smallest
873 if constexpr (std::is_same_v<R, llvm::APInt>) {
874 assert(ty.isIntOrIndex() && "expect integer or index type");
875 return llvm::APInt::getSignedMinValue(numBits: ty.getIntOrFloatBitWidth());
876 }
877 if constexpr (std::is_same_v<R, llvm::APFloat>) {
878 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty);
879 assert(floatTy && "expect float type");
880 return llvm::APFloat::getSmallest(Sem: floatTy.getFloatSemantics(),
881 /*negative=*/Negative: true);
882 }
883 } else if (op == mlir::acc::ReductionOperator::AccIand) {
884 if constexpr (std::is_same_v<R, llvm::APInt>) {
885 assert(ty.isIntOrIndex() && "expect integer type");
886 unsigned bits = ty.getIntOrFloatBitWidth();
887 return llvm::APInt::getAllOnes(numBits: bits);
888 }
889 } else {
890 // +, ior, ieor init value -> 0
891 // * init value -> 1
892 int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 1 : 0;
893 if constexpr (std::is_same_v<R, llvm::APInt>) {
894 assert(ty.isIntOrIndex() && "expect integer or index type");
895 return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true);
896 }
897
898 if constexpr (std::is_same_v<R, llvm::APFloat>) {
899 assert(mlir::isa<mlir::FloatType>(ty) && "expect float type");
900 auto floatTy = mlir::dyn_cast<mlir::FloatType>(Val&: ty);
901 return llvm::APFloat(floatTy.getFloatSemantics(), value);
902 }
903
904 if constexpr (std::is_same_v<R, int64_t>)
905 return value;
906 }
907 llvm_unreachable("OpenACC reduction unsupported type");
908}
909
910/// Return a constant with the initial value for the reduction operator and
911/// type combination.
912static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
913 mlir::Location loc, mlir::Type ty,
914 mlir::acc::ReductionOperator op) {
915 if (op == mlir::acc::ReductionOperator::AccLand ||
916 op == mlir::acc::ReductionOperator::AccLor ||
917 op == mlir::acc::ReductionOperator::AccEqv ||
918 op == mlir::acc::ReductionOperator::AccNeqv) {
919 assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type");
920 bool value = true; // .true. for .and. and .eqv.
921 if (op == mlir::acc::ReductionOperator::AccLor ||
922 op == mlir::acc::ReductionOperator::AccNeqv)
923 value = false; // .false. for .or. and .neqv.
924 return builder.createBool(loc, value);
925 }
926 if (ty.isIntOrIndex())
927 return builder.create<mlir::arith::ConstantOp>(
928 loc, ty,
929 builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
930 if (op == mlir::acc::ReductionOperator::AccMin ||
931 op == mlir::acc::ReductionOperator::AccMax) {
932 if (mlir::isa<fir::ComplexType>(ty))
933 llvm::report_fatal_error(
934 reason: "min/max reduction not supported for complex type");
935 if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
936 return builder.create<mlir::arith::ConstantOp>(
937 loc, ty,
938 builder.getFloatAttr(ty,
939 getReductionInitValue<llvm::APFloat>(op, ty)));
940 } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(Val&: ty)) {
941 return builder.create<mlir::arith::ConstantOp>(
942 loc, ty,
943 builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
944 } else if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
945 mlir::Type floatTy =
946 Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind());
947 mlir::Value realInit = builder.createRealConstant(
948 loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
949 mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
950 return fir::factory::Complex{builder, loc}.createComplex(
951 cmplxTy.getFKind(), realInit, imagInit);
952 }
953
954 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
955 return getReductionInitValue(builder, loc, seqTy.getEleTy(), op);
956
957 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
958 return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);
959
960 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
961 return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);
962
963 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
964 return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);
965
966 llvm::report_fatal_error(reason: "Unsupported OpenACC reduction type");
967}
968
969static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
970 mlir::Location loc, mlir::Type ty,
971 mlir::acc::ReductionOperator op) {
972 ty = fir::unwrapRefType(ty);
973 mlir::Value initValue = getReductionInitValue(builder, loc, ty, op);
974 if (fir::isa_trivial(ty)) {
975 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
976 auto declareOp = builder.create<hlfir::DeclareOp>(
977 loc, alloca, accReductionInitName, /*shape=*/nullptr,
978 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
979 builder.create<fir::StoreOp>(loc, builder.createConvert(loc, ty, initValue),
980 declareOp.getBase());
981 return declareOp.getBase();
982 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
983 if (fir::isa_trivial(seqTy.getEleTy())) {
984 mlir::Value shape;
985 auto extents = builder.getBlock()->getArguments().drop_front(1);
986 if (seqTy.hasDynamicExtents())
987 shape = builder.create<fir::ShapeOp>(loc, extents);
988 else
989 shape = genShapeOp(builder, seqTy, loc);
990 mlir::Value alloca = builder.create<fir::AllocaOp>(
991 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
992 auto declareOp = builder.create<hlfir::DeclareOp>(
993 loc, alloca, accReductionInitName, shape,
994 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
995 mlir::Type idxTy = builder.getIndexType();
996 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
997 llvm::SmallVector<fir::DoLoopOp> loops;
998 llvm::SmallVector<mlir::Value> ivs;
999
1000 if (seqTy.hasDynamicExtents()) {
1001 builder.create<hlfir::AssignOp>(loc, initValue, declareOp.getBase());
1002 return declareOp.getBase();
1003 }
1004 for (auto ext : llvm::reverse(seqTy.getShape())) {
1005 auto lb = builder.createIntegerConstant(loc, idxTy, 0);
1006 auto ub = builder.createIntegerConstant(loc, idxTy, ext - 1);
1007 auto step = builder.createIntegerConstant(loc, idxTy, 1);
1008 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1009 /*unordered=*/false);
1010 builder.setInsertionPointToStart(loop.getBody());
1011 loops.push_back(loop);
1012 ivs.push_back(loop.getInductionVar());
1013 }
1014 auto coord = builder.create<fir::CoordinateOp>(loc, refTy,
1015 declareOp.getBase(), ivs);
1016 builder.create<fir::StoreOp>(loc, initValue, coord);
1017 builder.setInsertionPointAfter(loops[0]);
1018 return declareOp.getBase();
1019 }
1020 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
1021 mlir::Type innerTy = fir::extractSequenceType(boxTy);
1022 if (!mlir::isa<fir::SequenceType>(innerTy))
1023 TODO(loc, "Unsupported boxed type for reduction");
1024 // Create the private copy from the initial fir.box.
1025 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
1026 auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, source);
1027 builder.create<hlfir::AssignOp>(loc, initValue, temp);
1028 return temp;
1029 }
1030 llvm::report_fatal_error(reason: "Unsupported OpenACC reduction type");
1031}
1032
1033template <typename Op>
1034static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
1035 mlir::Location loc, mlir::Value value1,
1036 mlir::Value value2) {
1037 mlir::Type i1 = builder.getI1Type();
1038 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1039 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1040 mlir::Value combined = builder.create<Op>(loc, v1, v2);
1041 return builder.create<fir::ConvertOp>(loc, value1.getType(), combined);
1042}
1043
1044static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
1045 mlir::Location loc,
1046 mlir::arith::CmpIPredicate pred,
1047 mlir::Value value1,
1048 mlir::Value value2) {
1049 mlir::Type i1 = builder.getI1Type();
1050 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1051 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1052 mlir::Value add = builder.create<mlir::arith::CmpIOp>(loc, pred, v1, v2);
1053 return builder.create<fir::ConvertOp>(loc, value1.getType(), add);
1054}
1055
1056static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
1057 mlir::Location loc,
1058 mlir::acc::ReductionOperator op,
1059 mlir::Type ty, mlir::Value value1,
1060 mlir::Value value2) {
1061 value1 = builder.loadIfRef(loc, value1);
1062 value2 = builder.loadIfRef(loc, value2);
1063 if (op == mlir::acc::ReductionOperator::AccAdd) {
1064 if (ty.isIntOrIndex())
1065 return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
1066 if (mlir::isa<mlir::FloatType>(ty))
1067 return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
1068 if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty))
1069 return builder.create<fir::AddcOp>(loc, value1, value2);
1070 TODO(loc, "reduction add type");
1071 }
1072
1073 if (op == mlir::acc::ReductionOperator::AccMul) {
1074 if (ty.isIntOrIndex())
1075 return builder.create<mlir::arith::MulIOp>(loc, value1, value2);
1076 if (mlir::isa<mlir::FloatType>(ty))
1077 return builder.create<mlir::arith::MulFOp>(loc, value1, value2);
1078 if (mlir::isa<fir::ComplexType>(ty))
1079 return builder.create<fir::MulcOp>(loc, value1, value2);
1080 TODO(loc, "reduction mul type");
1081 }
1082
1083 if (op == mlir::acc::ReductionOperator::AccMin)
1084 return fir::genMin(builder, loc, {value1, value2});
1085
1086 if (op == mlir::acc::ReductionOperator::AccMax)
1087 return fir::genMax(builder, loc, {value1, value2});
1088
1089 if (op == mlir::acc::ReductionOperator::AccIand)
1090 return builder.create<mlir::arith::AndIOp>(loc, value1, value2);
1091
1092 if (op == mlir::acc::ReductionOperator::AccIor)
1093 return builder.create<mlir::arith::OrIOp>(loc, value1, value2);
1094
1095 if (op == mlir::acc::ReductionOperator::AccXor)
1096 return builder.create<mlir::arith::XOrIOp>(loc, value1, value2);
1097
1098 if (op == mlir::acc::ReductionOperator::AccLand)
1099 return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1,
1100 value2);
1101
1102 if (op == mlir::acc::ReductionOperator::AccLor)
1103 return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2);
1104
1105 if (op == mlir::acc::ReductionOperator::AccEqv)
1106 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq,
1107 value1, value2);
1108
1109 if (op == mlir::acc::ReductionOperator::AccNeqv)
1110 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne,
1111 value1, value2);
1112
1113 TODO(loc, "reduction operator");
1114}
1115
1116static hlfir::DesignateOp::Subscripts
1117getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) {
1118 hlfir::DesignateOp::Subscripts triplets;
1119 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1120 i += 3)
1121 triplets.emplace_back(hlfir::DesignateOp::Triplet{
1122 recipe.getCombinerRegion().getArgument(i),
1123 recipe.getCombinerRegion().getArgument(i + 1),
1124 recipe.getCombinerRegion().getArgument(i + 2)});
1125 return triplets;
1126}
1127
1128static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
1129 mlir::acc::ReductionOperator op, mlir::Type ty,
1130 mlir::Value value1, mlir::Value value2,
1131 mlir::acc::ReductionRecipeOp &recipe,
1132 llvm::SmallVector<mlir::Value> &bounds,
1133 bool allConstantBound) {
1134 ty = fir::unwrapRefType(ty);
1135
1136 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1137 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
1138 llvm::SmallVector<fir::DoLoopOp> loops;
1139 llvm::SmallVector<mlir::Value> ivs;
1140 if (seqTy.hasDynamicExtents()) {
1141 auto shape =
1142 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1143 recipe.getCombinerRegion().getArguments());
1144 auto v1DeclareOp = builder.create<hlfir::DeclareOp>(
1145 loc, value1, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1146 fir::FortranVariableFlagsAttr{});
1147 auto v2DeclareOp = builder.create<hlfir::DeclareOp>(
1148 loc, value2, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1149 fir::FortranVariableFlagsAttr{});
1150 hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe);
1151
1152 llvm::SmallVector<mlir::Value> lenParamsLeft;
1153 auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()};
1154 hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
1155 auto leftDesignate = builder.create<hlfir::DesignateOp>(
1156 loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(),
1157 /*component=*/"",
1158 /*componentShape=*/mlir::Value{}, triplets,
1159 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1160 shape, lenParamsLeft);
1161 auto left = hlfir::Entity{leftDesignate.getResult()};
1162
1163 llvm::SmallVector<mlir::Value> lenParamsRight;
1164 auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()};
1165 hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft);
1166 auto rightDesignate = builder.create<hlfir::DesignateOp>(
1167 loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(),
1168 /*component=*/"",
1169 /*componentShape=*/mlir::Value{}, triplets,
1170 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1171 shape, lenParamsRight);
1172 auto right = hlfir::Entity{rightDesignate.getResult()};
1173
1174 llvm::SmallVector<mlir::Value, 1> typeParams;
1175 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1176 mlir::Location l, fir::FirOpBuilder &b,
1177 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1178 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1179 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1180 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1181 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1182 return hlfir::Entity{genScalarCombiner(
1183 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1184 };
1185 mlir::Value elemental = hlfir::genElementalOp(
1186 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1187 /*isUnordered=*/true);
1188 builder.create<hlfir::AssignOp>(loc, elemental, v1DeclareOp.getBase());
1189 return;
1190 }
1191 if (allConstantBound) {
1192 // Use the constant bound directly in the combiner region so they do not
1193 // need to be passed as block argument.
1194 for (auto bound : llvm::reverse(C&: bounds)) {
1195 auto dataBound =
1196 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1197 llvm::SmallVector<mlir::Value> values =
1198 genConstantBounds(builder, loc, dataBound);
1199 auto loop =
1200 builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2],
1201 /*unordered=*/false);
1202 builder.setInsertionPointToStart(loop.getBody());
1203 loops.push_back(loop);
1204 ivs.push_back(Elt: loop.getInductionVar());
1205 }
1206 } else {
1207 // Lowerbound, upperbound and step are passed as block arguments.
1208 [[maybe_unused]] unsigned nbRangeArgs =
1209 recipe.getCombinerRegion().getArguments().size() - 2;
1210 assert((nbRangeArgs / 3 == seqTy.getDimension()) &&
1211 "Expect 3 block arguments per dimension");
1212 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1213 i += 3) {
1214 mlir::Value lb = recipe.getCombinerRegion().getArgument(i);
1215 mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1);
1216 mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2);
1217 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1218 /*unordered=*/false);
1219 builder.setInsertionPointToStart(loop.getBody());
1220 loops.push_back(loop);
1221 ivs.push_back(Elt: loop.getInductionVar());
1222 }
1223 }
1224 auto addr1 = builder.create<fir::CoordinateOp>(loc, refTy, value1, ivs);
1225 auto addr2 = builder.create<fir::CoordinateOp>(loc, refTy, value2, ivs);
1226 auto load1 = builder.create<fir::LoadOp>(loc, addr1);
1227 auto load2 = builder.create<fir::LoadOp>(loc, addr2);
1228 mlir::Value res =
1229 genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2);
1230 builder.create<fir::StoreOp>(loc, res, addr1);
1231 builder.setInsertionPointAfter(loops[0]);
1232 } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
1233 mlir::Type innerTy = fir::extractSequenceType(boxTy);
1234 fir::SequenceType seqTy =
1235 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
1236 if (!seqTy)
1237 TODO(loc, "Unsupported boxed type in OpenACC reduction");
1238
1239 auto shape = genShapeFromBoundsOrArgs(
1240 loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
1241 hlfir::DesignateOp::Subscripts triplets =
1242 getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
1243 auto leftEntity = hlfir::Entity{value1};
1244 auto left =
1245 genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
1246 auto rightEntity = hlfir::Entity{value2};
1247 auto right =
1248 genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
1249
1250 llvm::SmallVector<mlir::Value, 1> typeParams;
1251 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1252 mlir::Location l, fir::FirOpBuilder &b,
1253 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1254 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1255 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1256 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1257 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1258 return hlfir::Entity{genScalarCombiner(builder, loc, op, seqTy.getEleTy(),
1259 leftVal, rightVal)};
1260 };
1261 mlir::Value elemental = hlfir::genElementalOp(
1262 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1263 /*isUnordered=*/true);
1264 builder.create<hlfir::AssignOp>(loc, elemental, value1);
1265 } else {
1266 mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2);
1267 builder.create<fir::StoreOp>(loc, res, value1);
1268 }
1269}
1270
1271mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
1272 fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
1273 mlir::Type ty, mlir::acc::ReductionOperator op,
1274 llvm::SmallVector<mlir::Value> &bounds) {
1275 mlir::ModuleOp mod =
1276 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
1277 if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
1278 return recipe;
1279
1280 auto crtPos = builder.saveInsertionPoint();
1281 mlir::OpBuilder modBuilder(mod.getBodyRegion());
1282 auto recipe =
1283 modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
1284 llvm::SmallVector<mlir::Type> initArgsTy{ty};
1285 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
1286 mlir::Type refTy = fir::unwrapRefType(ty);
1287 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
1288 if (seqTy.hasDynamicExtents()) {
1289 mlir::Type idxTy = builder.getIndexType();
1290 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
1291 initArgsTy.push_back(Elt: idxTy);
1292 initArgsLoc.push_back(Elt: loc);
1293 }
1294 }
1295 }
1296 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
1297 initArgsTy, initArgsLoc);
1298 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
1299 mlir::Value initValue = genReductionInitRegion(builder, loc, ty, op);
1300 builder.create<mlir::acc::YieldOp>(loc, initValue);
1301
1302 // The two first block arguments are the two values to be combined.
1303 // The next arguments are the iteration ranges (lb, ub, step) to be used
1304 // for the combiner if needed.
1305 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
1306 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
1307 bool allConstantBound = areAllBoundConstant(bounds);
1308 if (!allConstantBound) {
1309 for (mlir::Value bound : llvm::reverse(C&: bounds)) {
1310 auto dataBound =
1311 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1312 argsTy.push_back(Elt: dataBound.getLowerbound().getType());
1313 argsLoc.push_back(Elt: dataBound.getLowerbound().getLoc());
1314 argsTy.push_back(Elt: dataBound.getUpperbound().getType());
1315 argsLoc.push_back(Elt: dataBound.getUpperbound().getLoc());
1316 argsTy.push_back(Elt: dataBound.getStartIdx().getType());
1317 argsLoc.push_back(Elt: dataBound.getStartIdx().getLoc());
1318 }
1319 }
1320 builder.createBlock(&recipe.getCombinerRegion(),
1321 recipe.getCombinerRegion().end(), argsTy, argsLoc);
1322 builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
1323 mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
1324 mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
1325 genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound);
1326 builder.create<mlir::acc::YieldOp>(loc, v1);
1327 builder.restoreInsertionPoint(crtPos);
1328 return recipe;
1329}
1330
1331static bool isSupportedReductionType(mlir::Type ty) {
1332 ty = fir::unwrapRefType(ty);
1333 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
1334 return isSupportedReductionType(boxTy.getEleTy());
1335 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
1336 return isSupportedReductionType(seqTy.getEleTy());
1337 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
1338 return isSupportedReductionType(heapTy.getEleTy());
1339 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
1340 return isSupportedReductionType(ptrTy.getEleTy());
1341 return fir::isa_trivial(ty);
1342}
1343
1344static void
1345genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
1346 Fortran::lower::AbstractConverter &converter,
1347 Fortran::semantics::SemanticsContext &semanticsContext,
1348 Fortran::lower::StatementContext &stmtCtx,
1349 llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
1350 llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
1351 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1352 const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
1353 const auto &op =
1354 std::get<Fortran::parser::AccReductionOperator>(objectList.t);
1355 mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
1356 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
1357 for (const auto &accObject : objects.v) {
1358 llvm::SmallVector<mlir::Value> bounds;
1359 std::stringstream asFortran;
1360 mlir::Location operandLocation = genOperandLocation(converter, accObject);
1361 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
1362 Fortran::semantics::MaybeExpr designator =
1363 std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
1364 Fortran::lower::AddrAndBoundsInfo info =
1365 Fortran::lower::gatherDataOperandAddrAndBounds<
1366 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
1367 converter, builder, semanticsContext, stmtCtx, symbol, designator,
1368 operandLocation, asFortran, bounds);
1369
1370 mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
1371 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
1372 reductionTy = seqTy.getEleTy();
1373
1374 if (!isSupportedReductionType(reductionTy))
1375 TODO(operandLocation, "reduction with unsupported type");
1376
1377 auto op = createDataEntryOp<mlir::acc::ReductionOp>(
1378 builder, operandLocation, info.addr, asFortran, bounds,
1379 /*structured=*/true, /*implicit=*/false,
1380 mlir::acc::DataClause::acc_reduction, info.addr.getType());
1381 mlir::Type ty = op.getAccPtr().getType();
1382 if (!areAllBoundConstant(bounds) ||
1383 fir::isAssumedShape(info.addr.getType()) ||
1384 fir::isAllocatableOrPointerArray(info.addr.getType()))
1385 ty = info.addr.getType();
1386 std::string suffix =
1387 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
1388 std::string recipeName = fir::getTypeAsString(
1389 ty, converter.getKindMap(),
1390 ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
1391
1392 mlir::acc::ReductionRecipeOp recipe =
1393 Fortran::lower::createOrGetReductionRecipe(
1394 builder, recipeName, operandLocation, ty, mlirOp, bounds);
1395 reductionRecipes.push_back(mlir::SymbolRefAttr::get(
1396 builder.getContext(), recipe.getSymName().str()));
1397 reductionOperands.push_back(op.getAccPtr());
1398 }
1399}
1400
1401static void
1402addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
1403 llvm::SmallVectorImpl<int32_t> &operandSegments,
1404 const llvm::SmallVectorImpl<mlir::Value> &clauseOperands) {
1405 operands.append(in_start: clauseOperands.begin(), in_end: clauseOperands.end());
1406 operandSegments.push_back(Elt: clauseOperands.size());
1407}
1408
1409static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
1410 llvm::SmallVectorImpl<int32_t> &operandSegments,
1411 const mlir::Value &clauseOperand) {
1412 if (clauseOperand) {
1413 operands.push_back(Elt: clauseOperand);
1414 operandSegments.push_back(Elt: 1);
1415 } else {
1416 operandSegments.push_back(Elt: 0);
1417 }
1418}
1419
/// Create an operation of type `Op` owning a region with one entry block, and
/// terminate that region with a `Terminator` op.
///
/// \param loc        location of the created `Op`.
/// \param returnLoc  location of the created terminator.
/// \param eval       PFT evaluation of the construct. When it is lowered as
///                   unstructured (and `outerCombined` is false), empty
///                   blocks are created for its nested evaluations through
///                   Fortran::lower::createEmptyRegionBlocks.
/// \param operands   flat operand list of the op.
/// \param operandSegments group sizes of `operands`; stored as the op's
///                   operand segment size attribute.
/// \param outerCombined true for the outer op of a combined construct;
///                   suppresses the empty-block creation described above.
/// \param retTy      result types of the op.
/// \param yieldValue optional value yielded by the terminator. Only used when
///                   `Terminator` is acc::YieldOp; its defining op is moved
///                   right before the yield, so it must be an op result (a
///                   block argument has no defining op).
/// \param argsTy, locs types and locations of the entry block arguments.
/// \returns the created op. On return, the builder's insertion point is at the
///          start of the entry block; in the structured case that is just
///          before the terminator.
template <typename Op, typename Terminator>
static Op
createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
               mlir::Location returnLoc, Fortran::lower::pft::Evaluation &eval,
               const llvm::SmallVectorImpl<mlir::Value> &operands,
               const llvm::SmallVectorImpl<int32_t> &operandSegments,
               bool outerCombined = false,
               llvm::SmallVector<mlir::Type> retTy = {},
               mlir::Value yieldValue = {}, mlir::TypeRange argsTy = {},
               llvm::SmallVector<mlir::Location> locs = {}) {
  Op op = builder.create<Op>(loc, retTy, operands);
  // Create the entry block (with the requested block arguments) at the end of
  // the op's region.
  builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
  mlir::Block &block = op.getRegion().back();
  builder.setInsertionPointToStart(&block);

  op->setAttr(Op::getOperandSegmentSizeAttr(),
              builder.getDenseI32ArrayAttr(operandSegments));

  // Place the insertion point to the start of the first block.
  // NOTE(review): this repeats the call above; setAttr does not move the
  // insertion point, so it appears redundant.
  builder.setInsertionPointToStart(&block);

  // If it is an unstructured region and is not the outer region of a combined
  // construct, create empty blocks for all evaluations.
  // NOTE(review): the terminator below is created at whatever insertion point
  // is current after this call; if createEmptyRegionBlocks moves it to a newly
  // created block, the terminator lands there rather than in `block`. Confirm
  // against DirectivesCommon.h that this is intended.
  if (eval.lowerAsUnstructured() && !outerCombined)
    Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
                                            mlir::acc::YieldOp>(
        builder, eval.getNestedEvaluations());

  if (yieldValue) {
    if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
      // The yielded value may have been created outside the region (e.g. by
      // the caller before this op); move its producer into the region just
      // ahead of the yield so the yield operand is defined in scope.
      Terminator yieldOp = builder.create<Terminator>(returnLoc, yieldValue);
      yieldValue.getDefiningOp()->moveBefore(yieldOp);
    } else {
      // Non-yield terminators take no operands; yieldValue is ignored.
      builder.create<Terminator>(returnLoc);
    }
  } else {
    builder.create<Terminator>(returnLoc);
  }
  // Let callers emit the region body from the start of the entry block.
  builder.setInsertionPointToStart(&block);
  return op;
}
1461
1462static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1463 const Fortran::parser::AccClause::Async *asyncClause,
1464 mlir::Value &async, bool &addAsyncAttr,
1465 Fortran::lower::StatementContext &stmtCtx) {
1466 const auto &asyncClauseValue = asyncClause->v;
1467 if (asyncClauseValue) { // async has a value.
1468 async = fir::getBase(converter.genExprValue(
1469 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1470 } else {
1471 addAsyncAttr = true;
1472 }
1473}
1474
1475static void
1476genAsyncClause(Fortran::lower::AbstractConverter &converter,
1477 const Fortran::parser::AccClause::Async *asyncClause,
1478 llvm::SmallVector<mlir::Value> &async,
1479 llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1480 llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1481 llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
1482 Fortran::lower::StatementContext &stmtCtx) {
1483 const auto &asyncClauseValue = asyncClause->v;
1484 if (asyncClauseValue) { // async has a value.
1485 mlir::Value asyncValue = fir::getBase(converter.genExprValue(
1486 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1487 for (auto deviceTypeAttr : deviceTypeAttrs) {
1488 async.push_back(Elt: asyncValue);
1489 asyncDeviceTypes.push_back(Elt: deviceTypeAttr);
1490 }
1491 } else {
1492 for (auto deviceTypeAttr : deviceTypeAttrs)
1493 asyncOnlyDeviceTypes.push_back(Elt: deviceTypeAttr);
1494 }
1495}
1496
1497static mlir::acc::DeviceType
1498getDeviceType(Fortran::common::OpenACCDeviceType device) {
1499 switch (device) {
1500 case Fortran::common::OpenACCDeviceType::Star:
1501 return mlir::acc::DeviceType::Star;
1502 case Fortran::common::OpenACCDeviceType::Default:
1503 return mlir::acc::DeviceType::Default;
1504 case Fortran::common::OpenACCDeviceType::Nvidia:
1505 return mlir::acc::DeviceType::Nvidia;
1506 case Fortran::common::OpenACCDeviceType::Radeon:
1507 return mlir::acc::DeviceType::Radeon;
1508 case Fortran::common::OpenACCDeviceType::Host:
1509 return mlir::acc::DeviceType::Host;
1510 case Fortran::common::OpenACCDeviceType::Multicore:
1511 return mlir::acc::DeviceType::Multicore;
1512 case Fortran::common::OpenACCDeviceType::None:
1513 return mlir::acc::DeviceType::None;
1514 }
1515 return mlir::acc::DeviceType::None;
1516}
1517
1518static void gatherDeviceTypeAttrs(
1519 fir::FirOpBuilder &builder,
1520 const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
1521 llvm::SmallVector<mlir::Attribute> &deviceTypes) {
1522 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1523 deviceTypeClause->v;
1524 for (const auto &deviceTypeExpr : deviceTypeExprList.v)
1525 deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
1526 builder.getContext(), getDeviceType(deviceTypeExpr.v)));
1527}
1528
1529static void genIfClause(Fortran::lower::AbstractConverter &converter,
1530 mlir::Location clauseLocation,
1531 const Fortran::parser::AccClause::If *ifClause,
1532 mlir::Value &ifCond,
1533 Fortran::lower::StatementContext &stmtCtx) {
1534 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1535 mlir::Value cond = fir::getBase(converter.genExprValue(
1536 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx, &clauseLocation));
1537 ifCond = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
1538 cond);
1539}
1540
1541static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1542 const Fortran::parser::AccClause::Wait *waitClause,
1543 llvm::SmallVectorImpl<mlir::Value> &operands,
1544 mlir::Value &waitDevnum, bool &addWaitAttr,
1545 Fortran::lower::StatementContext &stmtCtx) {
1546 const auto &waitClauseValue = waitClause->v;
1547 if (waitClauseValue) { // wait has a value.
1548 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1549 const auto &waitList =
1550 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1551 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1552 mlir::Value v = fir::getBase(
1553 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
1554 operands.push_back(v);
1555 }
1556
1557 const auto &waitDevnumValue =
1558 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1559 if (waitDevnumValue)
1560 waitDevnum = fir::getBase(converter.genExprValue(
1561 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1562 } else {
1563 addWaitAttr = true;
1564 }
1565}
1566
1567static void genWaitClauseWithDeviceType(
1568 Fortran::lower::AbstractConverter &converter,
1569 const Fortran::parser::AccClause::Wait *waitClause,
1570 llvm::SmallVector<mlir::Value> &waitOperands,
1571 llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1572 llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1573 llvm::SmallVector<bool> &hasDevnums,
1574 llvm::SmallVector<int32_t> &waitOperandsSegments,
1575 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1576 Fortran::lower::StatementContext &stmtCtx) {
1577 const auto &waitClauseValue = waitClause->v;
1578 if (waitClauseValue) { // wait has a value.
1579 llvm::SmallVector<mlir::Value> waitValues;
1580
1581 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1582 const auto &waitDevnumValue =
1583 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1584 bool hasDevnum = false;
1585 if (waitDevnumValue) {
1586 waitValues.push_back(fir::getBase(converter.genExprValue(
1587 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
1588 hasDevnum = true;
1589 }
1590
1591 const auto &waitList =
1592 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1593 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1594 waitValues.push_back(fir::getBase(converter.genExprValue(
1595 *Fortran::semantics::GetExpr(value), stmtCtx)));
1596 }
1597
1598 for (auto deviceTypeAttr : deviceTypeAttrs) {
1599 for (auto value : waitValues)
1600 waitOperands.push_back(Elt: value);
1601 waitOperandsDeviceTypes.push_back(Elt: deviceTypeAttr);
1602 waitOperandsSegments.push_back(Elt: waitValues.size());
1603 hasDevnums.push_back(Elt: hasDevnum);
1604 }
1605 } else {
1606 for (auto deviceTypeAttr : deviceTypeAttrs)
1607 waitOnlyDeviceTypes.push_back(Elt: deviceTypeAttr);
1608 }
1609}
1610
1611mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
1612 const Fortran::semantics::Symbol &ivSym) {
1613 std::size_t ivTypeSize = ivSym.size();
1614 if (ivTypeSize == 0)
1615 llvm::report_fatal_error(reason: "unexpected induction variable size");
1616 // ivTypeSize is in bytes and IntegerType needs to be in bits.
1617 return builder.getIntegerType(ivTypeSize * 8);
1618}
1619
1620static void privatizeIv(Fortran::lower::AbstractConverter &converter,
1621 const Fortran::semantics::Symbol &sym,
1622 mlir::Location loc,
1623 llvm::SmallVector<mlir::Type> &ivTypes,
1624 llvm::SmallVector<mlir::Location> &ivLocs,
1625 llvm::SmallVector<mlir::Value> &privateOperands,
1626 llvm::SmallVector<mlir::Value> &ivPrivate,
1627 llvm::SmallVector<mlir::Attribute> &privatizations,
1628 bool isDoConcurrent = false) {
1629 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1630
1631 mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym);
1632 ivTypes.push_back(Elt: ivTy);
1633 ivLocs.push_back(Elt: loc);
1634 mlir::Value ivValue = converter.getSymbolAddress(sym);
1635 if (!ivValue && isDoConcurrent) {
1636 // DO CONCURRENT induction variables are not mapped yet since they are local
1637 // to the DO CONCURRENT scope.
1638 mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
1639 builder.setInsertionPointToStart(builder.getAllocaBlock());
1640 ivValue = builder.createTemporaryAlloc(loc, ivTy, toStringRef(sym.name()));
1641 builder.restoreInsertionPoint(insPt);
1642 }
1643
1644 std::string recipeName =
1645 fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
1646 Fortran::lower::privatizationRecipePrefix);
1647 auto recipe = Fortran::lower::createOrGetPrivateRecipe(
1648 builder, recipeName, loc, ivValue.getType());
1649
1650 std::stringstream asFortran;
1651 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
1652 builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true,
1653 mlir::acc::DataClause::acc_private, ivValue.getType());
1654
1655 privateOperands.push_back(Elt: op.getAccPtr());
1656 privatizations.push_back(mlir::SymbolRefAttr::get(builder.getContext(),
1657 recipe.getSymName().str()));
1658
1659 // Map the new private iv to its symbol for the scope of the loop. bindSymbol
1660 // might create a hlfir.declare op, if so, we map its result in order to
1661 // use the sym value in the scope.
1662 converter.bindSymbol(sym, op.getAccPtr());
1663 auto privateValue = converter.getSymbolAddress(sym);
1664 if (auto declareOp =
1665 mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp()))
1666 privateValue = declareOp.getResults()[0];
1667 ivPrivate.push_back(Elt: privateValue);
1668}
1669
/// Create an acc.loop operation for a `loop` construct, or for the loop part
/// of a combined construct, whose associated DO construct is
/// `outerDoConstruct`.
///
/// Overview:
///  1. Lower the bounds and steps of the associated loop nest.
///     - For DO CONCURRENT, every concurrent control contributes one
///       dimension.
///     - Otherwise the nest depth comes from
///       Fortran::lower::getCollapseValue(accClauseList).
///     - Each induction variable is privatized through privatizeIv.
///     - All upper bounds are recorded as inclusive.
///  2. Lower the loop clauses. Clauses that can depend on the device type are
///     replicated once per entry of `crtDeviceTypes`. That list starts as
///     [none], and each `device_type` clause replaces it.
///  3. Build the op with createRegionOp, using one region block argument per
///     induction variable. Store those arguments into the private iv copies,
///     then attach the collected attributes. Each attribute is set only when
///     its list is non-empty.
///
/// \param combinedConstructs if set, recorded as the op's `combined` attribute.
/// \param needEarlyReturnHandling if true, the op gets one i1 result and its
///        region yields an i1 `false` constant.
/// \returns the created op. On return, the builder's insertion point is in
///          the loop region's entry block, after the induction-variable
///          stores.
static mlir::acc::LoopOp createLoopOp(
    Fortran::lower::AbstractConverter &converter,
    mlir::Location currentLocation,
    Fortran::semantics::SemanticsContext &semanticsContext,
    Fortran::lower::StatementContext &stmtCtx,
    const Fortran::parser::DoConstruct &outerDoConstruct,
    Fortran::lower::pft::Evaluation &eval,
    const Fortran::parser::AccClauseList &accClauseList,
    std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
        std::nullopt,
    bool needEarlyReturnHandling = false) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
      reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
      gangOperands, lowerbounds, upperbounds, steps;
  llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
  llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
  llvm::SmallVector<int64_t> collapseValues;

  // Per-device-type bookkeeping. "*DeviceTypes" lists record device types for
  // which a clause appeared without a value. "*OperandsDeviceTypes" lists run
  // parallel to the corresponding operand groups.
  llvm::SmallVector<mlir::Attribute> gangArgTypes;
  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
      autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
      vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
      collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;

  // device_type attribute is set to `none` until a device_type clause is
  // encountered.
  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
      builder.getContext(), mlir::acc::DeviceType::None));

  // Types/locations of the loop region's block arguments (one per iv).
  llvm::SmallVector<mlir::Type> ivTypes;
  llvm::SmallVector<mlir::Location> ivLocs;
  llvm::SmallVector<bool> inclusiveBounds;

  // Locations fused into the acc.loop location: the directive, followed by
  // each associated DO construct.
  llvm::SmallVector<mlir::Location> locs;
  locs.push_back(Elt: currentLocation); // Location of the directive
  Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
  bool isDoConcurrent = outerDoConstruct.IsDoConcurrent();
  if (isDoConcurrent) {
    locs.push_back(converter.genLocation(
        Fortran::parser::FindSourceLocation(outerDoConstruct)));
    const Fortran::parser::LoopControl *loopControl =
        &*outerDoConstruct.GetLoopControl();
    const auto &concurrent =
        std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
    if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
             .empty())
      TODO(currentLocation, "DO CONCURRENT with locality spec");

    const auto &concurrentHeader =
        std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
    const auto &controls =
        std::get<std::list<Fortran::parser::ConcurrentControl>>(
            concurrentHeader.t);
    // One loop dimension per concurrent control (index = lb : ub [: step]).
    for (const auto &control : controls) {
      lowerbounds.push_back(fir::getBase(converter.genExprValue(
          *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx)));
      upperbounds.push_back(fir::getBase(converter.genExprValue(
          *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx)));
      if (const auto &expr =
              std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
                  control.t))
        steps.push_back(fir::getBase(converter.genExprValue(
            *Fortran::semantics::GetExpr(*expr), stmtCtx)));
      else // If `step` is not present, assume it is `1`.
        steps.push_back(builder.createIntegerConstant(
            currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));

      const auto &name = std::get<Fortran::parser::Name>(control.t);
      privatizeIv(converter, *name.symbol, currentLocation, ivTypes, ivLocs,
                  privateOperands, ivPrivate, privatizations, isDoConcurrent);

      inclusiveBounds.push_back(true);
    }
  } else {
    // Walk `collapseValue` perfectly nested DO constructs, starting with the
    // outer one.
    int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
    for (unsigned i = 0; i < collapseValue; ++i) {
      const Fortran::parser::LoopControl *loopControl;
      if (i == 0) {
        loopControl = &*outerDoConstruct.GetLoopControl();
        locs.push_back(converter.genLocation(
            Fortran::parser::FindSourceLocation(outerDoConstruct)));
      } else {
        auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
        assert(doCons && "expect do construct");
        loopControl = &*doCons->GetLoopControl();
        locs.push_back(converter.genLocation(
            Fortran::parser::FindSourceLocation(*doCons)));
      }

      const Fortran::parser::LoopControl::Bounds *bounds =
          std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
      assert(bounds && "Expected bounds on the loop construct");
      lowerbounds.push_back(fir::getBase(converter.genExprValue(
          *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
      upperbounds.push_back(fir::getBase(converter.genExprValue(
          *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
      if (bounds->step)
        steps.push_back(fir::getBase(converter.genExprValue(
            *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
      else // If `step` is not present, assume it is `1`.
        steps.push_back(Elt: builder.createIntegerConstant(
            currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));

      Fortran::semantics::Symbol &ivSym =
          bounds->name.thing.symbol->GetUltimate();
      privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
                  privateOperands, ivPrivate, privatizations);

      inclusiveBounds.push_back(Elt: true);

      // Advance to the next nested loop. The second nested evaluation is
      // taken; presumably the first is the enclosing loop's DO statement and
      // the second the inner DoConstruct (checked by the assert above on the
      // next iteration) — confirm against the PFT layout.
      if (i < collapseValue - 1)
        crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
    }
  }

  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
    mlir::Location clauseLocation = converter.genLocation(clause.source);
    if (const auto *gangClause =
            std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
      if (gangClause->v) {
        // gang(...) with arguments: each argument becomes a (value, kind)
        // pair. The whole list forms one segment per current device type.
        const Fortran::parser::AccGangArgList &x = *gangClause->v;
        mlir::SmallVector<mlir::Value> gangValues;
        mlir::SmallVector<mlir::Attribute> gangArgs;
        for (const Fortran::parser::AccGangArg &gangArg : x.v) {
          if (const auto *num =
                  std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
            gangValues.push_back(fir::getBase(converter.genExprValue(
                *Fortran::semantics::GetExpr(num->v), stmtCtx)));
            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                builder.getContext(), mlir::acc::GangArgType::Num));
          } else if (const auto *staticArg =
                         std::get_if<Fortran::parser::AccGangArg::Static>(
                             &gangArg.u)) {
            const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
            if (sizeExpr.v) {
              gangValues.push_back(fir::getBase(converter.genExprValue(
                  *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
            } else {
              // * was passed as value and will be represented as a special
              // constant.
              gangValues.push_back(builder.createIntegerConstant(
                  clauseLocation, builder.getIndexType(), starCst));
            }
            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                builder.getContext(), mlir::acc::GangArgType::Static));
          } else if (const auto *dim =
                         std::get_if<Fortran::parser::AccGangArg::Dim>(
                             &gangArg.u)) {
            gangValues.push_back(fir::getBase(converter.genExprValue(
                *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                builder.getContext(), mlir::acc::GangArgType::Dim));
          }
        }
        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
          for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
            gangOperands.push_back(std::get<0>(pair));
            gangArgTypes.push_back(std::get<1>(pair));
          }
          gangOperandsSegments.push_back(gangValues.size());
          gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
        }
      } else {
        // Bare `gang`: only record the device types.
        for (auto crtDeviceTypeAttr : crtDeviceTypes)
          gangDeviceTypes.push_back(crtDeviceTypeAttr);
      }
    } else if (const auto *workerClause =
                   std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
      if (workerClause->v) {
        mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
          workerNumOperands.push_back(workerNumValue);
          workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
        }
      } else {
        for (auto crtDeviceTypeAttr : crtDeviceTypes)
          workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
      }
    } else if (const auto *vectorClause =
                   std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
      if (vectorClause->v) {
        mlir::Value vectorValue = fir::getBase(converter.genExprValue(
            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
          vectorOperands.push_back(vectorValue);
          vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
        }
      } else {
        for (auto crtDeviceTypeAttr : crtDeviceTypes)
          vectorDeviceTypes.push_back(crtDeviceTypeAttr);
      }
    } else if (const auto *tileClause =
                   std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
      const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
      llvm::SmallVector<mlir::Value> tileValues;
      for (const auto &accTileExpr : accTileExprList.v) {
        const auto &expr =
            std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
                accTileExpr.t);
        if (expr) {
          tileValues.push_back(fir::getBase(converter.genExprValue(
              *Fortran::semantics::GetExpr(*expr), stmtCtx)));
        } else {
          // * was passed as value and will be represented as a special
          // constant.
          mlir::Value tileStar = builder.createIntegerConstant(
              clauseLocation, builder.getIntegerType(32), starCst);
          tileValues.push_back(tileStar);
        }
      }
      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
        for (auto value : tileValues)
          tileOperands.push_back(value);
        tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
        tileOperandsSegments.push_back(tileValues.size());
      }
    } else if (const auto *privateClause =
                   std::get_if<Fortran::parser::AccClause::Private>(
                       &clause.u)) {
      // private/reduction are not replicated per device type.
      genPrivatizations<mlir::acc::PrivateRecipeOp>(
          privateClause->v, converter, semanticsContext, stmtCtx,
          privateOperands, privatizations);
    } else if (const auto *reductionClause =
                   std::get_if<Fortran::parser::AccClause::Reduction>(
                       &clause.u)) {
      genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
                    reductionOperands, reductionRecipes);
    } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
      for (auto crtDeviceTypeAttr : crtDeviceTypes)
        seqDeviceTypes.push_back(crtDeviceTypeAttr);
    } else if (std::get_if<Fortran::parser::AccClause::Independent>(
                   &clause.u)) {
      for (auto crtDeviceTypeAttr : crtDeviceTypes)
        independentDeviceTypes.push_back(crtDeviceTypeAttr);
    } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
      for (auto crtDeviceTypeAttr : crtDeviceTypes)
        autoDeviceTypes.push_back(crtDeviceTypeAttr);
    } else if (const auto *deviceTypeClause =
                   std::get_if<Fortran::parser::AccClause::DeviceType>(
                       &clause.u)) {
      // Subsequent clauses apply to the device types listed here.
      crtDeviceTypes.clear();
      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
    } else if (const auto *collapseClause =
                   std::get_if<Fortran::parser::AccClause::Collapse>(
                       &clause.u)) {
      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
      const auto &force = std::get<bool>(arg.t);
      if (force)
        TODO(clauseLocation, "OpenACC collapse force modifier");

      const auto &intExpr =
          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
      const auto *expr = Fortran::semantics::GetExpr(intExpr);
      const std::optional<int64_t> collapseValue =
          Fortran::evaluate::ToInt64(*expr);
      assert(collapseValue && "expect integer value for the collapse clause");

      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
        collapseValues.push_back(*collapseValue);
        collapseDeviceTypes.push_back(crtDeviceTypeAttr);
      }
    }
  }

  // Prepare the operand segment size attribute and the operands value range.
  // NOTE(review): the order and number of groups presumably must match the
  // acc.loop operand segments in its ODS definition — confirm before adding
  // or reordering groups. cacheOperands is never filled in this function, so
  // it always contributes an empty segment.
  llvm::SmallVector<mlir::Value> operands;
  llvm::SmallVector<int32_t> operandSegments;
  addOperands(operands, operandSegments, clauseOperands: lowerbounds);
  addOperands(operands, operandSegments, clauseOperands: upperbounds);
  addOperands(operands, operandSegments, clauseOperands: steps);
  addOperands(operands, operandSegments, clauseOperands: gangOperands);
  addOperands(operands, operandSegments, clauseOperands: workerNumOperands);
  addOperands(operands, operandSegments, clauseOperands: vectorOperands);
  addOperands(operands, operandSegments, clauseOperands: tileOperands);
  addOperands(operands, operandSegments, clauseOperands: cacheOperands);
  addOperands(operands, operandSegments, clauseOperands: privateOperands);
  addOperands(operands, operandSegments, clauseOperands: reductionOperands);

  // With early-return handling, the loop yields an i1 (initially `false`).
  // The constant is created here, outside the op; createRegionOp moves it
  // into the region ahead of the yield.
  llvm::SmallVector<mlir::Type> retTy;
  mlir::Value yieldValue;
  if (needEarlyReturnHandling) {
    mlir::Type i1Ty = builder.getI1Type();
    yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
    retTy.push_back(Elt: i1Ty);
  }

  auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
      builder, builder.getFusedLoc(locs), currentLocation, eval, operands,
      operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes,
      ivLocs);

  // Store each region block argument (the iv value) into the matching
  // private iv storage so uses of the iv symbol inside the loop see it.
  for (auto [arg, value] : llvm::zip(
           loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
    builder.create<fir::StoreOp>(currentLocation, arg, value);

  loopOp.setInclusiveUpperbound(inclusiveBounds);

  if (!gangDeviceTypes.empty())
    loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
  if (!gangArgTypes.empty())
    loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
  if (!gangOperandsSegments.empty())
    loopOp.setGangOperandsSegmentsAttr(
        builder.getDenseI32ArrayAttr(gangOperandsSegments));
  if (!gangOperandsDeviceTypes.empty())
    loopOp.setGangOperandsDeviceTypeAttr(
        builder.getArrayAttr(gangOperandsDeviceTypes));

  if (!workerNumDeviceTypes.empty())
    loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
  if (!workerNumOperandsDeviceTypes.empty())
    loopOp.setWorkerNumOperandsDeviceTypeAttr(
        builder.getArrayAttr(workerNumOperandsDeviceTypes));

  if (!vectorDeviceTypes.empty())
    loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
  if (!vectorOperandsDeviceTypes.empty())
    loopOp.setVectorOperandsDeviceTypeAttr(
        builder.getArrayAttr(vectorOperandsDeviceTypes));

  if (!tileOperandsDeviceTypes.empty())
    loopOp.setTileOperandsDeviceTypeAttr(
        builder.getArrayAttr(tileOperandsDeviceTypes));
  if (!tileOperandsSegments.empty())
    loopOp.setTileOperandsSegmentsAttr(
        builder.getDenseI32ArrayAttr(tileOperandsSegments));

  if (!seqDeviceTypes.empty())
    loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
  if (!independentDeviceTypes.empty())
    loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
  if (!autoDeviceTypes.empty())
    loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));

  if (!privatizations.empty())
    loopOp.setPrivatizationsAttr(
        mlir::ArrayAttr::get(builder.getContext(), privatizations));

  if (!reductionRecipes.empty())
    loopOp.setReductionRecipesAttr(
        mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));

  if (!collapseValues.empty())
    loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
  if (!collapseDeviceTypes.empty())
    loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));

  if (combinedConstructs)
    loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
        builder.getContext(), *combinedConstructs));

  return loopOp;
}
2026
2027static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
2028 bool hasReturnStmt = false;
2029 for (auto &e : eval.getNestedEvaluations()) {
2030 e.visit(Fortran::common::visitors{
2031 [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
2032 [&](const auto &s) {},
2033 });
2034 if (e.hasNestedEvaluations())
2035 hasReturnStmt = hasEarlyReturn(e);
2036 }
2037 return hasReturnStmt;
2038}
2039
2040static mlir::Value
2041genACC(Fortran::lower::AbstractConverter &converter,
2042 Fortran::semantics::SemanticsContext &semanticsContext,
2043 Fortran::lower::pft::Evaluation &eval,
2044 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
2045
2046 const auto &beginLoopDirective =
2047 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
2048 const auto &loopDirective =
2049 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
2050
2051 bool needEarlyExitHandling = false;
2052 if (eval.lowerAsUnstructured())
2053 needEarlyExitHandling = hasEarlyReturn(eval);
2054
2055 mlir::Location currentLocation =
2056 converter.genLocation(beginLoopDirective.source);
2057 Fortran::lower::StatementContext stmtCtx;
2058
2059 assert(loopDirective.v == llvm::acc::ACCD_loop &&
2060 "Unsupported OpenACC loop construct");
2061 (void)loopDirective;
2062
2063 const auto &accClauseList =
2064 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
2065 const auto &outerDoConstruct =
2066 std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
2067 auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
2068 stmtCtx, *outerDoConstruct, eval, accClauseList,
2069 /*combinedConstructs=*/{}, needEarlyExitHandling);
2070 if (needEarlyExitHandling)
2071 return loopOp.getResult(0);
2072
2073 return mlir::Value{};
2074}
2075
2076template <typename Op, typename Clause>
2077static void genDataOperandOperationsWithModifier(
2078 const Clause *x, Fortran::lower::AbstractConverter &converter,
2079 Fortran::semantics::SemanticsContext &semanticsContext,
2080 Fortran::lower::StatementContext &stmtCtx,
2081 Fortran::parser::AccDataModifier::Modifier mod,
2082 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
2083 const mlir::acc::DataClause clause,
2084 const mlir::acc::DataClause clauseWithModifier,
2085 bool setDeclareAttr = false) {
2086 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
2087 const auto &accObjectList =
2088 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2089 const auto &modifier =
2090 std::get<std::optional<Fortran::parser::AccDataModifier>>(
2091 listWithModifier.t);
2092 mlir::acc::DataClause dataClause =
2093 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
2094 genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
2095 stmtCtx, dataClauseOperands, dataClause,
2096 /*structured=*/true, /*implicit=*/false,
2097 setDeclareAttr);
2098}
2099
2100template <typename Op>
2101static Op createComputeOp(
2102 Fortran::lower::AbstractConverter &converter,
2103 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
2104 Fortran::semantics::SemanticsContext &semanticsContext,
2105 Fortran::lower::StatementContext &stmtCtx,
2106 const Fortran::parser::AccClauseList &accClauseList,
2107 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2108 std::nullopt) {
2109
2110 // Parallel operation operands
2111 mlir::Value ifCond;
2112 mlir::Value selfCond;
2113 llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
2114 copyEntryOperands, copyoutEntryOperands, createEntryOperands,
2115 dataClauseOperands, numGangs, numWorkers, vectorLength, async;
2116 llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
2117 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
2118 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2119 llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
2120 llvm::SmallVector<bool> hasWaitDevnums;
2121
2122 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
2123 firstprivateOperands;
2124 llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
2125 reductionRecipes;
2126
2127 // Self clause has optional values but can be present with
2128 // no value as well. When there is no value, the op has an attribute to
2129 // represent the clause.
2130 bool addSelfAttr = false;
2131
2132 bool hasDefaultNone = false;
2133 bool hasDefaultPresent = false;
2134
2135 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2136
2137 // device_type attribute is set to `none` until a device_type clause is
2138 // encountered.
2139 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2140 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2141 builder.getContext(), mlir::acc::DeviceType::None);
2142 crtDeviceTypes.push_back(Elt: crtDeviceTypeAttr);
2143
2144 // Lower clauses values mapped to operands and array attributes.
2145 // Keep track of each group of operands separately as clauses can appear
2146 // more than once.
2147 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2148 mlir::Location clauseLocation = converter.genLocation(clause.source);
2149 if (const auto *asyncClause =
2150 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2151 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2152 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2153 } else if (const auto *waitClause =
2154 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2155 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2156 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2157 hasWaitDevnums, waitOperandsSegments,
2158 crtDeviceTypes, stmtCtx);
2159 } else if (const auto *numGangsClause =
2160 std::get_if<Fortran::parser::AccClause::NumGangs>(
2161 &clause.u)) {
2162 llvm::SmallVector<mlir::Value> numGangValues;
2163 for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
2164 numGangValues.push_back(fir::getBase(converter.genExprValue(
2165 *Fortran::semantics::GetExpr(expr), stmtCtx)));
2166 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2167 for (auto value : numGangValues)
2168 numGangs.push_back(value);
2169 numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
2170 numGangsSegments.push_back(numGangValues.size());
2171 }
2172 } else if (const auto *numWorkersClause =
2173 std::get_if<Fortran::parser::AccClause::NumWorkers>(
2174 &clause.u)) {
2175 mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
2176 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
2177 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2178 numWorkers.push_back(numWorkerValue);
2179 numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
2180 }
2181 } else if (const auto *vectorLengthClause =
2182 std::get_if<Fortran::parser::AccClause::VectorLength>(
2183 &clause.u)) {
2184 mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
2185 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
2186 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2187 vectorLength.push_back(vectorLengthValue);
2188 vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
2189 }
2190 } else if (const auto *ifClause =
2191 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2192 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2193 } else if (const auto *selfClause =
2194 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
2195 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
2196 selfClause->v;
2197 if (accSelfClause) {
2198 if (const auto *optCondition =
2199 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>(
2200 &(*accSelfClause).u)) {
2201 if (*optCondition) {
2202 mlir::Value cond = fir::getBase(converter.genExprValue(
2203 *Fortran::semantics::GetExpr(*optCondition), stmtCtx));
2204 selfCond = builder.createConvert(clauseLocation,
2205 builder.getI1Type(), cond);
2206 }
2207 } else if (const auto *accClauseList =
2208 std::get_if<Fortran::parser::AccObjectList>(
2209 &(*accSelfClause).u)) {
2210 // TODO This would be nicer to be done in canonicalization step.
2211 if (accClauseList->v.size() == 1) {
2212 const auto &accObject = accClauseList->v.front();
2213 if (const auto *designator =
2214 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
2215 if (const auto *name =
2216 Fortran::semantics::getDesignatorNameIfDataRef(
2217 *designator)) {
2218 auto cond = converter.getSymbolAddress(*name->symbol);
2219 selfCond = builder.createConvert(clauseLocation,
2220 builder.getI1Type(), cond);
2221 }
2222 }
2223 }
2224 }
2225 } else {
2226 addSelfAttr = true;
2227 }
2228 } else if (const auto *copyClause =
2229 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2230 auto crtDataStart = dataClauseOperands.size();
2231 genDataOperandOperations<mlir::acc::CopyinOp>(
2232 copyClause->v, converter, semanticsContext, stmtCtx,
2233 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2234 /*structured=*/true, /*implicit=*/false);
2235 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2236 dataClauseOperands.end());
2237 } else if (const auto *copyinClause =
2238 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2239 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2240 Fortran::parser::AccClause::Copyin>(
2241 copyinClause, converter, semanticsContext, stmtCtx,
2242 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2243 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2244 mlir::acc::DataClause::acc_copyin_readonly);
2245 } else if (const auto *copyoutClause =
2246 std::get_if<Fortran::parser::AccClause::Copyout>(
2247 &clause.u)) {
2248 auto crtDataStart = dataClauseOperands.size();
2249 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2250 Fortran::parser::AccClause::Copyout>(
2251 copyoutClause, converter, semanticsContext, stmtCtx,
2252 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2253 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2254 mlir::acc::DataClause::acc_copyout_zero);
2255 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2256 dataClauseOperands.end());
2257 } else if (const auto *createClause =
2258 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2259 auto crtDataStart = dataClauseOperands.size();
2260 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2261 Fortran::parser::AccClause::Create>(
2262 createClause, converter, semanticsContext, stmtCtx,
2263 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2264 mlir::acc::DataClause::acc_create,
2265 mlir::acc::DataClause::acc_create_zero);
2266 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2267 dataClauseOperands.end());
2268 } else if (const auto *noCreateClause =
2269 std::get_if<Fortran::parser::AccClause::NoCreate>(
2270 &clause.u)) {
2271 genDataOperandOperations<mlir::acc::NoCreateOp>(
2272 noCreateClause->v, converter, semanticsContext, stmtCtx,
2273 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2274 /*structured=*/true, /*implicit=*/false);
2275 } else if (const auto *presentClause =
2276 std::get_if<Fortran::parser::AccClause::Present>(
2277 &clause.u)) {
2278 genDataOperandOperations<mlir::acc::PresentOp>(
2279 presentClause->v, converter, semanticsContext, stmtCtx,
2280 dataClauseOperands, mlir::acc::DataClause::acc_present,
2281 /*structured=*/true, /*implicit=*/false);
2282 } else if (const auto *devicePtrClause =
2283 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2284 &clause.u)) {
2285 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2286 devicePtrClause->v, converter, semanticsContext, stmtCtx,
2287 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2288 /*structured=*/true, /*implicit=*/false);
2289 } else if (const auto *attachClause =
2290 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2291 auto crtDataStart = dataClauseOperands.size();
2292 genDataOperandOperations<mlir::acc::AttachOp>(
2293 attachClause->v, converter, semanticsContext, stmtCtx,
2294 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2295 /*structured=*/true, /*implicit=*/false);
2296 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2297 dataClauseOperands.end());
2298 } else if (const auto *privateClause =
2299 std::get_if<Fortran::parser::AccClause::Private>(
2300 &clause.u)) {
2301 if (!combinedConstructs)
2302 genPrivatizations<mlir::acc::PrivateRecipeOp>(
2303 privateClause->v, converter, semanticsContext, stmtCtx,
2304 privateOperands, privatizations);
2305 } else if (const auto *firstprivateClause =
2306 std::get_if<Fortran::parser::AccClause::Firstprivate>(
2307 &clause.u)) {
2308 genPrivatizations<mlir::acc::FirstprivateRecipeOp>(
2309 firstprivateClause->v, converter, semanticsContext, stmtCtx,
2310 firstprivateOperands, firstPrivatizations);
2311 } else if (const auto *reductionClause =
2312 std::get_if<Fortran::parser::AccClause::Reduction>(
2313 &clause.u)) {
2314 // A reduction clause on a combined construct is treated as if it appeared
2315 // on the loop construct. So don't generate a reduction clause when it is
2316 // combined - delay it to the loop. However, a reduction clause on a
2317 // combined construct implies a copy clause so issue an implicit copy
2318 // instead.
2319 if (!combinedConstructs) {
2320 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2321 reductionOperands, reductionRecipes);
2322 } else {
2323 auto crtDataStart = dataClauseOperands.size();
2324 genDataOperandOperations<mlir::acc::CopyinOp>(
2325 std::get<Fortran::parser::AccObjectList>(reductionClause->v.t),
2326 converter, semanticsContext, stmtCtx, dataClauseOperands,
2327 mlir::acc::DataClause::acc_reduction,
2328 /*structured=*/true, /*implicit=*/true);
2329 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2330 dataClauseOperands.end());
2331 }
2332 } else if (const auto *defaultClause =
2333 std::get_if<Fortran::parser::AccClause::Default>(
2334 &clause.u)) {
2335 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2336 hasDefaultNone = true;
2337 else if ((defaultClause->v).v ==
2338 llvm::acc::DefaultValue::ACC_Default_present)
2339 hasDefaultPresent = true;
2340 } else if (const auto *deviceTypeClause =
2341 std::get_if<Fortran::parser::AccClause::DeviceType>(
2342 &clause.u)) {
2343 crtDeviceTypes.clear();
2344 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2345 }
2346 }
2347
2348 // Prepare the operand segment size attribute and the operands value range.
2349 llvm::SmallVector<mlir::Value, 8> operands;
2350 llvm::SmallVector<int32_t, 8> operandSegments;
2351 addOperands(operands, operandSegments, clauseOperands: async);
2352 addOperands(operands, operandSegments, clauseOperands: waitOperands);
2353 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2354 addOperands(operands, operandSegments, clauseOperands: numGangs);
2355 addOperands(operands, operandSegments, clauseOperands: numWorkers);
2356 addOperands(operands, operandSegments, clauseOperands: vectorLength);
2357 }
2358 addOperand(operands, operandSegments, clauseOperand: ifCond);
2359 addOperand(operands, operandSegments, clauseOperand: selfCond);
2360 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2361 addOperands(operands, operandSegments, clauseOperands: reductionOperands);
2362 addOperands(operands, operandSegments, clauseOperands: privateOperands);
2363 addOperands(operands, operandSegments, clauseOperands: firstprivateOperands);
2364 }
2365 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
2366
2367 Op computeOp;
2368 if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
2369 computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
2370 builder, currentLocation, currentLocation, eval, operands,
2371 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2372 else
2373 computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
2374 builder, currentLocation, currentLocation, eval, operands,
2375 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2376
2377 if (addSelfAttr)
2378 computeOp.setSelfAttrAttr(builder.getUnitAttr());
2379
2380 if (hasDefaultNone)
2381 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2382 if (hasDefaultPresent)
2383 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2384
2385 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2386 if (!numWorkersDeviceTypes.empty())
2387 computeOp.setNumWorkersDeviceTypeAttr(
2388 mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2389 if (!vectorLengthDeviceTypes.empty())
2390 computeOp.setVectorLengthDeviceTypeAttr(
2391 mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2392 if (!numGangsDeviceTypes.empty())
2393 computeOp.setNumGangsDeviceTypeAttr(
2394 mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2395 if (!numGangsSegments.empty())
2396 computeOp.setNumGangsSegmentsAttr(
2397 builder.getDenseI32ArrayAttr(numGangsSegments));
2398 }
2399 if (!asyncDeviceTypes.empty())
2400 computeOp.setAsyncOperandsDeviceTypeAttr(
2401 builder.getArrayAttr(asyncDeviceTypes));
2402 if (!asyncOnlyDeviceTypes.empty())
2403 computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2404
2405 if (!waitOperandsDeviceTypes.empty())
2406 computeOp.setWaitOperandsDeviceTypeAttr(
2407 builder.getArrayAttr(waitOperandsDeviceTypes));
2408 if (!waitOperandsSegments.empty())
2409 computeOp.setWaitOperandsSegmentsAttr(
2410 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2411 if (!hasWaitDevnums.empty())
2412 computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
2413 if (!waitOnlyDeviceTypes.empty())
2414 computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2415
2416 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2417 if (!privatizations.empty())
2418 computeOp.setPrivatizationsAttr(
2419 mlir::ArrayAttr::get(builder.getContext(), privatizations));
2420 if (!reductionRecipes.empty())
2421 computeOp.setReductionRecipesAttr(
2422 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
2423 if (!firstPrivatizations.empty())
2424 computeOp.setFirstprivatizationsAttr(
2425 mlir::ArrayAttr::get(builder.getContext(), firstPrivatizations));
2426 }
2427
2428 if (combinedConstructs)
2429 computeOp.setCombinedAttr(builder.getUnitAttr());
2430
2431 auto insPt = builder.saveInsertionPoint();
2432 builder.setInsertionPointAfter(computeOp);
2433
2434 // Create the exit operations after the region.
2435 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2436 builder, copyEntryOperands, /*structured=*/true);
2437 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2438 builder, copyoutEntryOperands, /*structured=*/true);
2439 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2440 builder, attachEntryOperands, /*structured=*/true);
2441 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2442 builder, createEntryOperands, /*structured=*/true);
2443
2444 builder.restoreInsertionPoint(insPt);
2445 return computeOp;
2446}
2447
2448static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2449 mlir::Location currentLocation,
2450 Fortran::lower::pft::Evaluation &eval,
2451 Fortran::semantics::SemanticsContext &semanticsContext,
2452 Fortran::lower::StatementContext &stmtCtx,
2453 const Fortran::parser::AccClauseList &accClauseList) {
2454 mlir::Value ifCond;
2455 llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2456 copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
2457 async;
2458 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2459 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2460 llvm::SmallVector<int32_t> waitOperandsSegments;
2461 llvm::SmallVector<bool> hasWaitDevnums;
2462
2463 bool hasDefaultNone = false;
2464 bool hasDefaultPresent = false;
2465
2466 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2467
2468 // device_type attribute is set to `none` until a device_type clause is
2469 // encountered.
2470 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2471 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2472 builder.getContext(), mlir::acc::DeviceType::None));
2473
2474 // Lower clauses values mapped to operands and array attributes.
2475 // Keep track of each group of operands separately as clauses can appear
2476 // more than once.
2477 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2478 mlir::Location clauseLocation = converter.genLocation(clause.source);
2479 if (const auto *ifClause =
2480 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2481 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2482 } else if (const auto *copyClause =
2483 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2484 auto crtDataStart = dataClauseOperands.size();
2485 genDataOperandOperations<mlir::acc::CopyinOp>(
2486 copyClause->v, converter, semanticsContext, stmtCtx,
2487 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2488 /*structured=*/true, /*implicit=*/false);
2489 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2490 dataClauseOperands.end());
2491 } else if (const auto *copyinClause =
2492 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2493 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2494 Fortran::parser::AccClause::Copyin>(
2495 copyinClause, converter, semanticsContext, stmtCtx,
2496 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2497 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2498 mlir::acc::DataClause::acc_copyin_readonly);
2499 } else if (const auto *copyoutClause =
2500 std::get_if<Fortran::parser::AccClause::Copyout>(
2501 &clause.u)) {
2502 auto crtDataStart = dataClauseOperands.size();
2503 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2504 Fortran::parser::AccClause::Copyout>(
2505 copyoutClause, converter, semanticsContext, stmtCtx,
2506 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2507 mlir::acc::DataClause::acc_copyout,
2508 mlir::acc::DataClause::acc_copyout_zero);
2509 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2510 dataClauseOperands.end());
2511 } else if (const auto *createClause =
2512 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2513 auto crtDataStart = dataClauseOperands.size();
2514 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2515 Fortran::parser::AccClause::Create>(
2516 createClause, converter, semanticsContext, stmtCtx,
2517 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2518 mlir::acc::DataClause::acc_create,
2519 mlir::acc::DataClause::acc_create_zero);
2520 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2521 dataClauseOperands.end());
2522 } else if (const auto *noCreateClause =
2523 std::get_if<Fortran::parser::AccClause::NoCreate>(
2524 &clause.u)) {
2525 genDataOperandOperations<mlir::acc::NoCreateOp>(
2526 noCreateClause->v, converter, semanticsContext, stmtCtx,
2527 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2528 /*structured=*/true, /*implicit=*/false);
2529 } else if (const auto *presentClause =
2530 std::get_if<Fortran::parser::AccClause::Present>(
2531 &clause.u)) {
2532 genDataOperandOperations<mlir::acc::PresentOp>(
2533 presentClause->v, converter, semanticsContext, stmtCtx,
2534 dataClauseOperands, mlir::acc::DataClause::acc_present,
2535 /*structured=*/true, /*implicit=*/false);
2536 } else if (const auto *deviceptrClause =
2537 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2538 &clause.u)) {
2539 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2540 deviceptrClause->v, converter, semanticsContext, stmtCtx,
2541 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2542 /*structured=*/true, /*implicit=*/false);
2543 } else if (const auto *attachClause =
2544 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2545 auto crtDataStart = dataClauseOperands.size();
2546 genDataOperandOperations<mlir::acc::AttachOp>(
2547 attachClause->v, converter, semanticsContext, stmtCtx,
2548 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2549 /*structured=*/true, /*implicit=*/false);
2550 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2551 dataClauseOperands.end());
2552 } else if (const auto *asyncClause =
2553 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2554 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2555 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2556 } else if (const auto *waitClause =
2557 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2558 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2559 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2560 hasWaitDevnums, waitOperandsSegments,
2561 crtDeviceTypes, stmtCtx);
2562 } else if(const auto *defaultClause =
2563 std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
2564 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2565 hasDefaultNone = true;
2566 else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present)
2567 hasDefaultPresent = true;
2568 } else if (const auto *deviceTypeClause =
2569 std::get_if<Fortran::parser::AccClause::DeviceType>(
2570 &clause.u)) {
2571 crtDeviceTypes.clear();
2572 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2573 }
2574 }
2575
2576 // Prepare the operand segment size attribute and the operands value range.
2577 llvm::SmallVector<mlir::Value> operands;
2578 llvm::SmallVector<int32_t> operandSegments;
2579 addOperand(operands, operandSegments, clauseOperand: ifCond);
2580 addOperands(operands, operandSegments, clauseOperands: async);
2581 addOperands(operands, operandSegments, clauseOperands: waitOperands);
2582 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
2583
2584 if (dataClauseOperands.empty() && !hasDefaultNone && !hasDefaultPresent)
2585 return;
2586
2587 auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
2588 builder, currentLocation, currentLocation, eval, operands,
2589 operandSegments);
2590
2591 if (!asyncDeviceTypes.empty())
2592 dataOp.setAsyncOperandsDeviceTypeAttr(
2593 builder.getArrayAttr(asyncDeviceTypes));
2594 if (!asyncOnlyDeviceTypes.empty())
2595 dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2596 if (!waitOperandsDeviceTypes.empty())
2597 dataOp.setWaitOperandsDeviceTypeAttr(
2598 builder.getArrayAttr(waitOperandsDeviceTypes));
2599 if (!waitOperandsSegments.empty())
2600 dataOp.setWaitOperandsSegmentsAttr(
2601 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2602 if (!hasWaitDevnums.empty())
2603 dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
2604 if (!waitOnlyDeviceTypes.empty())
2605 dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2606
2607 if (hasDefaultNone)
2608 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2609 if (hasDefaultPresent)
2610 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2611
2612 auto insPt = builder.saveInsertionPoint();
2613 builder.setInsertionPointAfter(dataOp);
2614
2615 // Create the exit operations after the region.
2616 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2617 builder, copyEntryOperands, /*structured=*/true);
2618 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2619 builder, copyoutEntryOperands, /*structured=*/true);
2620 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2621 builder, attachEntryOperands, /*structured=*/true);
2622 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2623 builder, createEntryOperands, /*structured=*/true);
2624
2625 builder.restoreInsertionPoint(insPt);
2626}
2627
2628static void
2629genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
2630 mlir::Location currentLocation,
2631 Fortran::lower::pft::Evaluation &eval,
2632 Fortran::semantics::SemanticsContext &semanticsContext,
2633 Fortran::lower::StatementContext &stmtCtx,
2634 const Fortran::parser::AccClauseList &accClauseList) {
2635 mlir::Value ifCond;
2636 llvm::SmallVector<mlir::Value> dataOperands;
2637 bool addIfPresentAttr = false;
2638
2639 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2640
2641 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2642 mlir::Location clauseLocation = converter.genLocation(clause.source);
2643 if (const auto *ifClause =
2644 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2645 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2646 } else if (const auto *useDevice =
2647 std::get_if<Fortran::parser::AccClause::UseDevice>(
2648 &clause.u)) {
2649 genDataOperandOperations<mlir::acc::UseDeviceOp>(
2650 useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
2651 mlir::acc::DataClause::acc_use_device,
2652 /*structured=*/true, /*implicit=*/false);
2653 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
2654 addIfPresentAttr = true;
2655 }
2656 }
2657
2658 if (ifCond) {
2659 if (auto cst =
2660 mlir::dyn_cast<mlir::arith::ConstantOp>(ifCond.getDefiningOp()))
2661 if (auto boolAttr = cst.getValue().dyn_cast<mlir::BoolAttr>()) {
2662 if (boolAttr.getValue()) {
2663 // get rid of the if condition if it is always true.
2664 ifCond = mlir::Value();
2665 } else {
2666 // Do not generate the acc.host_data op if the if condition is always
2667 // false.
2668 return;
2669 }
2670 }
2671 }
2672
2673 // Prepare the operand segment size attribute and the operands value range.
2674 llvm::SmallVector<mlir::Value> operands;
2675 llvm::SmallVector<int32_t> operandSegments;
2676 addOperand(operands, operandSegments, clauseOperand: ifCond);
2677 addOperands(operands, operandSegments, clauseOperands: dataOperands);
2678
2679 auto hostDataOp =
2680 createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
2681 builder, currentLocation, currentLocation, eval, operands,
2682 operandSegments);
2683
2684 if (addIfPresentAttr)
2685 hostDataOp.setIfPresentAttr(builder.getUnitAttr());
2686}
2687
2688static void
2689genACC(Fortran::lower::AbstractConverter &converter,
2690 Fortran::semantics::SemanticsContext &semanticsContext,
2691 Fortran::lower::pft::Evaluation &eval,
2692 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
2693 const auto &beginBlockDirective =
2694 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
2695 const auto &blockDirective =
2696 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
2697 const auto &accClauseList =
2698 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t);
2699
2700 mlir::Location currentLocation = converter.genLocation(blockDirective.source);
2701 Fortran::lower::StatementContext stmtCtx;
2702
2703 if (blockDirective.v == llvm::acc::ACCD_parallel) {
2704 createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
2705 semanticsContext, stmtCtx,
2706 accClauseList);
2707 } else if (blockDirective.v == llvm::acc::ACCD_data) {
2708 genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2709 accClauseList);
2710 } else if (blockDirective.v == llvm::acc::ACCD_serial) {
2711 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
2712 semanticsContext, stmtCtx,
2713 accClauseList);
2714 } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
2715 createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
2716 semanticsContext, stmtCtx,
2717 accClauseList);
2718 } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
2719 genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
2720 stmtCtx, accClauseList);
2721 }
2722}
2723
2724static void
2725genACC(Fortran::lower::AbstractConverter &converter,
2726 Fortran::semantics::SemanticsContext &semanticsContext,
2727 Fortran::lower::pft::Evaluation &eval,
2728 const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
2729 const auto &beginCombinedDirective =
2730 std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
2731 const auto &combinedDirective =
2732 std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
2733 const auto &accClauseList =
2734 std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
2735 const auto &outerDoConstruct =
2736 std::get<std::optional<Fortran::parser::DoConstruct>>(
2737 combinedConstruct.t);
2738
2739 mlir::Location currentLocation =
2740 converter.genLocation(beginCombinedDirective.source);
2741 Fortran::lower::StatementContext stmtCtx;
2742
2743 if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
2744 createComputeOp<mlir::acc::KernelsOp>(
2745 converter, currentLocation, eval, semanticsContext, stmtCtx,
2746 accClauseList, mlir::acc::CombinedConstructsType::KernelsLoop);
2747 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2748 *outerDoConstruct, eval, accClauseList,
2749 mlir::acc::CombinedConstructsType::KernelsLoop);
2750 } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
2751 createComputeOp<mlir::acc::ParallelOp>(
2752 converter, currentLocation, eval, semanticsContext, stmtCtx,
2753 accClauseList, mlir::acc::CombinedConstructsType::ParallelLoop);
2754 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2755 *outerDoConstruct, eval, accClauseList,
2756 mlir::acc::CombinedConstructsType::ParallelLoop);
2757 } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
2758 createComputeOp<mlir::acc::SerialOp>(
2759 converter, currentLocation, eval, semanticsContext, stmtCtx,
2760 accClauseList, mlir::acc::CombinedConstructsType::SerialLoop);
2761 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2762 *outerDoConstruct, eval, accClauseList,
2763 mlir::acc::CombinedConstructsType::SerialLoop);
2764 } else {
2765 llvm::report_fatal_error(reason: "Unknown combined construct encountered");
2766 }
2767}
2768
2769static void
2770genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
2771 mlir::Location currentLocation,
2772 Fortran::semantics::SemanticsContext &semanticsContext,
2773 Fortran::lower::StatementContext &stmtCtx,
2774 const Fortran::parser::AccClauseList &accClauseList) {
2775 mlir::Value ifCond, async, waitDevnum;
2776 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands;
2777
2778 // Async, wait and self clause have optional values but can be present with
2779 // no value as well. When there is no value, the op has an attribute to
2780 // represent the clause.
2781 bool addAsyncAttr = false;
2782 bool addWaitAttr = false;
2783
2784 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2785
2786 // Lower clauses values mapped to operands.
2787 // Keep track of each group of operands separately as clauses can appear
2788 // more than once.
2789 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2790 mlir::Location clauseLocation = converter.genLocation(clause.source);
2791 if (const auto *ifClause =
2792 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2793 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2794 } else if (const auto *asyncClause =
2795 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2796 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2797 } else if (const auto *waitClause =
2798 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2799 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2800 addWaitAttr, stmtCtx);
2801 } else if (const auto *copyinClause =
2802 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2803 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2804 copyinClause->v;
2805 const auto &accObjectList =
2806 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2807 genDataOperandOperations<mlir::acc::CopyinOp>(
2808 accObjectList, converter, semanticsContext, stmtCtx,
2809 dataClauseOperands, mlir::acc::DataClause::acc_copyin, false,
2810 /*implicit=*/false);
2811 } else if (const auto *createClause =
2812 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2813 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2814 createClause->v;
2815 const auto &accObjectList =
2816 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2817 const auto &modifier =
2818 std::get<std::optional<Fortran::parser::AccDataModifier>>(
2819 listWithModifier.t);
2820 mlir::acc::DataClause clause = mlir::acc::DataClause::acc_create;
2821 if (modifier &&
2822 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero)
2823 clause = mlir::acc::DataClause::acc_create_zero;
2824 genDataOperandOperations<mlir::acc::CreateOp>(
2825 accObjectList, converter, semanticsContext, stmtCtx,
2826 dataClauseOperands, clause, false, /*implicit=*/false);
2827 } else if (const auto *attachClause =
2828 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2829 genDataOperandOperations<mlir::acc::AttachOp>(
2830 attachClause->v, converter, semanticsContext, stmtCtx,
2831 dataClauseOperands, mlir::acc::DataClause::acc_attach, false,
2832 /*implicit=*/false);
2833 } else {
2834 llvm::report_fatal_error(
2835 "Unknown clause in ENTER DATA directive lowering");
2836 }
2837 }
2838
2839 // Prepare the operand segment size attribute and the operands value range.
2840 llvm::SmallVector<mlir::Value, 16> operands;
2841 llvm::SmallVector<int32_t, 8> operandSegments;
2842 addOperand(operands, operandSegments, clauseOperand: ifCond);
2843 addOperand(operands, operandSegments, clauseOperand: async);
2844 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
2845 addOperands(operands, operandSegments, clauseOperands: waitOperands);
2846 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
2847
2848 mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
2849 firOpBuilder, currentLocation, operands, operandSegments);
2850
2851 if (addAsyncAttr)
2852 enterDataOp.setAsyncAttr(firOpBuilder.getUnitAttr());
2853 if (addWaitAttr)
2854 enterDataOp.setWaitAttr(firOpBuilder.getUnitAttr());
2855}
2856
2857static void
2858genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
2859 mlir::Location currentLocation,
2860 Fortran::semantics::SemanticsContext &semanticsContext,
2861 Fortran::lower::StatementContext &stmtCtx,
2862 const Fortran::parser::AccClauseList &accClauseList) {
2863 mlir::Value ifCond, async, waitDevnum;
2864 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands,
2865 copyoutOperands, deleteOperands, detachOperands;
2866
2867 // Async and wait clause have optional values but can be present with
2868 // no value as well. When there is no value, the op has an attribute to
2869 // represent the clause.
2870 bool addAsyncAttr = false;
2871 bool addWaitAttr = false;
2872 bool addFinalizeAttr = false;
2873
2874 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2875
2876 // Lower clauses values mapped to operands.
2877 // Keep track of each group of operands separately as clauses can appear
2878 // more than once.
2879 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2880 mlir::Location clauseLocation = converter.genLocation(clause.source);
2881 if (const auto *ifClause =
2882 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2883 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2884 } else if (const auto *asyncClause =
2885 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2886 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2887 } else if (const auto *waitClause =
2888 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2889 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2890 addWaitAttr, stmtCtx);
2891 } else if (const auto *copyoutClause =
2892 std::get_if<Fortran::parser::AccClause::Copyout>(
2893 &clause.u)) {
2894 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2895 copyoutClause->v;
2896 const auto &accObjectList =
2897 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2898 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2899 accObjectList, converter, semanticsContext, stmtCtx, copyoutOperands,
2900 mlir::acc::DataClause::acc_copyout, false, /*implicit=*/false);
2901 } else if (const auto *deleteClause =
2902 std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) {
2903 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2904 deleteClause->v, converter, semanticsContext, stmtCtx, deleteOperands,
2905 mlir::acc::DataClause::acc_delete, false, /*implicit=*/false);
2906 } else if (const auto *detachClause =
2907 std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) {
2908 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2909 detachClause->v, converter, semanticsContext, stmtCtx, detachOperands,
2910 mlir::acc::DataClause::acc_detach, false, /*implicit=*/false);
2911 } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) {
2912 addFinalizeAttr = true;
2913 }
2914 }
2915
2916 dataClauseOperands.append(RHS: copyoutOperands);
2917 dataClauseOperands.append(RHS: deleteOperands);
2918 dataClauseOperands.append(RHS: detachOperands);
2919
2920 // Prepare the operand segment size attribute and the operands value range.
2921 llvm::SmallVector<mlir::Value, 14> operands;
2922 llvm::SmallVector<int32_t, 7> operandSegments;
2923 addOperand(operands, operandSegments, clauseOperand: ifCond);
2924 addOperand(operands, operandSegments, clauseOperand: async);
2925 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
2926 addOperands(operands, operandSegments, clauseOperands: waitOperands);
2927 addOperands(operands, operandSegments, clauseOperands: dataClauseOperands);
2928
2929 mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>(
2930 builder, currentLocation, operands, operandSegments);
2931
2932 if (addAsyncAttr)
2933 exitDataOp.setAsyncAttr(builder.getUnitAttr());
2934 if (addWaitAttr)
2935 exitDataOp.setWaitAttr(builder.getUnitAttr());
2936 if (addFinalizeAttr)
2937 exitDataOp.setFinalizeAttr(builder.getUnitAttr());
2938
2939 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::CopyoutOp>(
2940 builder, copyoutOperands, /*structured=*/false);
2941 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DeleteOp>(
2942 builder, deleteOperands, /*structured=*/false);
2943 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DetachOp>(
2944 builder, detachOperands, /*structured=*/false);
2945}
2946
2947template <typename Op>
2948static void
2949genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
2950 mlir::Location currentLocation,
2951 const Fortran::parser::AccClauseList &accClauseList) {
2952 mlir::Value ifCond, deviceNum;
2953
2954 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2955 Fortran::lower::StatementContext stmtCtx;
2956 llvm::SmallVector<mlir::Attribute> deviceTypes;
2957
2958 // Lower clauses values mapped to operands.
2959 // Keep track of each group of operands separately as clauses can appear
2960 // more than once.
2961 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2962 mlir::Location clauseLocation = converter.genLocation(clause.source);
2963 if (const auto *ifClause =
2964 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2965 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2966 } else if (const auto *deviceNumClause =
2967 std::get_if<Fortran::parser::AccClause::DeviceNum>(
2968 &clause.u)) {
2969 deviceNum = fir::getBase(converter.genExprValue(
2970 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
2971 } else if (const auto *deviceTypeClause =
2972 std::get_if<Fortran::parser::AccClause::DeviceType>(
2973 &clause.u)) {
2974 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
2975 }
2976 }
2977
2978 // Prepare the operand segment size attribute and the operands value range.
2979 llvm::SmallVector<mlir::Value, 6> operands;
2980 llvm::SmallVector<int32_t, 2> operandSegments;
2981
2982 addOperand(operands, operandSegments, clauseOperand: deviceNum);
2983 addOperand(operands, operandSegments, clauseOperand: ifCond);
2984
2985 Op op =
2986 createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
2987 if (!deviceTypes.empty())
2988 op.setDeviceTypesAttr(
2989 mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
2990}
2991
2992void genACCSetOp(Fortran::lower::AbstractConverter &converter,
2993 mlir::Location currentLocation,
2994 const Fortran::parser::AccClauseList &accClauseList) {
2995 mlir::Value ifCond, deviceNum, defaultAsync;
2996 llvm::SmallVector<mlir::Value> deviceTypeOperands;
2997
2998 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2999 Fortran::lower::StatementContext stmtCtx;
3000 llvm::SmallVector<mlir::Attribute> deviceTypes;
3001
3002 // Lower clauses values mapped to operands.
3003 // Keep track of each group of operands separately as clauses can appear
3004 // more than once.
3005 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3006 mlir::Location clauseLocation = converter.genLocation(clause.source);
3007 if (const auto *ifClause =
3008 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3009 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3010 } else if (const auto *defaultAsyncClause =
3011 std::get_if<Fortran::parser::AccClause::DefaultAsync>(
3012 &clause.u)) {
3013 defaultAsync = fir::getBase(converter.genExprValue(
3014 *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx));
3015 } else if (const auto *deviceNumClause =
3016 std::get_if<Fortran::parser::AccClause::DeviceNum>(
3017 &clause.u)) {
3018 deviceNum = fir::getBase(converter.genExprValue(
3019 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
3020 } else if (const auto *deviceTypeClause =
3021 std::get_if<Fortran::parser::AccClause::DeviceType>(
3022 &clause.u)) {
3023 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3024 }
3025 }
3026
3027 // Prepare the operand segment size attribute and the operands value range.
3028 llvm::SmallVector<mlir::Value> operands;
3029 llvm::SmallVector<int32_t, 3> operandSegments;
3030 addOperand(operands, operandSegments, clauseOperand: defaultAsync);
3031 addOperand(operands, operandSegments, clauseOperand: deviceNum);
3032 addOperand(operands, operandSegments, clauseOperand: ifCond);
3033
3034 auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
3035 operandSegments);
3036 if (!deviceTypes.empty()) {
3037 assert(deviceTypes.size() == 1 && "expect only one value for acc.set");
3038 op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
3039 }
3040}
3041
3042static inline mlir::ArrayAttr
3043getArrayAttr(fir::FirOpBuilder &b,
3044 llvm::SmallVector<mlir::Attribute> &attributes) {
3045 return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
3046}
3047
3048static inline mlir::ArrayAttr
3049getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
3050 return values.empty() ? nullptr : b.getBoolArrayAttr(values);
3051}
3052
3053static inline mlir::DenseI32ArrayAttr
3054getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
3055 llvm::SmallVector<int32_t> &values) {
3056 return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
3057}
3058
3059static void
3060genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3061 mlir::Location currentLocation,
3062 Fortran::semantics::SemanticsContext &semanticsContext,
3063 Fortran::lower::StatementContext &stmtCtx,
3064 const Fortran::parser::AccClauseList &accClauseList) {
3065 mlir::Value ifCond;
3066 llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
3067 waitOperands, deviceTypeOperands, asyncOperands;
3068 llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
3069 asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3070 llvm::SmallVector<bool> hasWaitDevnums;
3071 llvm::SmallVector<int32_t> waitOperandsSegments;
3072
3073 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3074
3075 // device_type attribute is set to `none` until a device_type clause is
3076 // encountered.
3077 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
3078 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
3079 builder.getContext(), mlir::acc::DeviceType::None));
3080
3081 bool ifPresent = false;
3082
3083 // Lower clauses values mapped to operands and array attributes.
3084 // Keep track of each group of operands separately as clauses can appear
3085 // more than once.
3086 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3087 mlir::Location clauseLocation = converter.genLocation(clause.source);
3088 if (const auto *ifClause =
3089 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3090 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3091 } else if (const auto *asyncClause =
3092 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3093 genAsyncClause(converter, asyncClause, asyncOperands,
3094 asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
3095 crtDeviceTypes, stmtCtx);
3096 } else if (const auto *waitClause =
3097 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3098 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
3099 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3100 hasWaitDevnums, waitOperandsSegments,
3101 crtDeviceTypes, stmtCtx);
3102 } else if (const auto *deviceTypeClause =
3103 std::get_if<Fortran::parser::AccClause::DeviceType>(
3104 &clause.u)) {
3105 crtDeviceTypes.clear();
3106 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
3107 } else if (const auto *hostClause =
3108 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
3109 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3110 hostClause->v, converter, semanticsContext, stmtCtx,
3111 updateHostOperands, mlir::acc::DataClause::acc_update_host, false,
3112 /*implicit=*/false);
3113 } else if (const auto *deviceClause =
3114 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
3115 genDataOperandOperations<mlir::acc::UpdateDeviceOp>(
3116 deviceClause->v, converter, semanticsContext, stmtCtx,
3117 dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
3118 /*implicit=*/false);
3119 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
3120 ifPresent = true;
3121 } else if (const auto *selfClause =
3122 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
3123 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
3124 selfClause->v;
3125 const auto *accObjectList =
3126 std::get_if<Fortran::parser::AccObjectList>(&(*accSelfClause).u);
3127 assert(accObjectList && "expect AccObjectList");
3128 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3129 *accObjectList, converter, semanticsContext, stmtCtx,
3130 updateHostOperands, mlir::acc::DataClause::acc_update_self, false,
3131 /*implicit=*/false);
3132 }
3133 }
3134
3135 dataClauseOperands.append(RHS: updateHostOperands);
3136
3137 builder.create<mlir::acc::UpdateOp>(
3138 currentLocation, ifCond, asyncOperands,
3139 getArrayAttr(builder, asyncOperandsDeviceTypes),
3140 getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
3141 getDenseI32ArrayAttr(builder, waitOperandsSegments),
3142 getArrayAttr(builder, waitOperandsDeviceTypes),
3143 getBoolArrayAttr(builder, hasWaitDevnums),
3144 getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
3145 ifPresent);
3146
3147 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
3148 builder, updateHostOperands, /*structured=*/false);
3149}
3150
3151static void
3152genACC(Fortran::lower::AbstractConverter &converter,
3153 Fortran::semantics::SemanticsContext &semanticsContext,
3154 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) {
3155 const auto &standaloneDirective =
3156 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t);
3157 const auto &accClauseList =
3158 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
3159
3160 mlir::Location currentLocation =
3161 converter.genLocation(standaloneDirective.source);
3162 Fortran::lower::StatementContext stmtCtx;
3163
3164 if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
3165 genACCEnterDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3166 accClauseList);
3167 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
3168 genACCExitDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3169 accClauseList);
3170 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
3171 genACCInitShutdownOp<mlir::acc::InitOp>(converter, currentLocation,
3172 accClauseList);
3173 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) {
3174 genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
3175 accClauseList);
3176 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
3177 genACCSetOp(converter, currentLocation, accClauseList);
3178 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
3179 genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
3180 accClauseList);
3181 }
3182}
3183
3184static void genACC(Fortran::lower::AbstractConverter &converter,
3185 const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
3186
3187 const auto &waitArgument =
3188 std::get<std::optional<Fortran::parser::AccWaitArgument>>(
3189 waitConstruct.t);
3190 const auto &accClauseList =
3191 std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
3192
3193 mlir::Value ifCond, waitDevnum, async;
3194 llvm::SmallVector<mlir::Value> waitOperands;
3195
3196 // Async clause have optional values but can be present with
3197 // no value as well. When there is no value, the op has an attribute to
3198 // represent the clause.
3199 bool addAsyncAttr = false;
3200
3201 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3202 mlir::Location currentLocation = converter.genLocation(waitConstruct.source);
3203 Fortran::lower::StatementContext stmtCtx;
3204
3205 if (waitArgument) { // wait has a value.
3206 const Fortran::parser::AccWaitArgument &waitArg = *waitArgument;
3207 const auto &waitList =
3208 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3209 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
3210 mlir::Value v = fir::getBase(
3211 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
3212 waitOperands.push_back(v);
3213 }
3214
3215 const auto &waitDevnumValue =
3216 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3217 if (waitDevnumValue)
3218 waitDevnum = fir::getBase(converter.genExprValue(
3219 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
3220 }
3221
3222 // Lower clauses values mapped to operands.
3223 // Keep track of each group of operands separately as clauses can appear
3224 // more than once.
3225 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3226 mlir::Location clauseLocation = converter.genLocation(clause.source);
3227 if (const auto *ifClause =
3228 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3229 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3230 } else if (const auto *asyncClause =
3231 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3232 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3233 }
3234 }
3235
3236 // Prepare the operand segment size attribute and the operands value range.
3237 llvm::SmallVector<mlir::Value> operands;
3238 llvm::SmallVector<int32_t> operandSegments;
3239 addOperands(operands, operandSegments, clauseOperands: waitOperands);
3240 addOperand(operands, operandSegments, clauseOperand: async);
3241 addOperand(operands, operandSegments, clauseOperand: waitDevnum);
3242 addOperand(operands, operandSegments, clauseOperand: ifCond);
3243
3244 mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
3245 firOpBuilder, currentLocation, operands, operandSegments);
3246
3247 if (addAsyncAttr)
3248 waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
3249}
3250
3251template <typename GlobalOp, typename EntryOp, typename DeclareOp,
3252 typename ExitOp>
3253static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
3254 fir::FirOpBuilder &builder,
3255 mlir::Location loc, fir::GlobalOp globalOp,
3256 mlir::acc::DataClause clause,
3257 const std::string &declareGlobalName,
3258 bool implicit, std::stringstream &asFortran) {
3259 GlobalOp declareGlobalOp =
3260 modBuilder.create<GlobalOp>(loc, declareGlobalName);
3261 builder.createBlock(&declareGlobalOp.getRegion(),
3262 declareGlobalOp.getRegion().end(), {}, {});
3263 builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
3264
3265 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3266 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3267 addDeclareAttr(builder, addrOp, clause);
3268
3269 llvm::SmallVector<mlir::Value> bounds;
3270 EntryOp entryOp = createDataEntryOp<EntryOp>(
3271 builder, loc, addrOp.getResTy(), asFortran, bounds,
3272 /*structured=*/false, implicit, clause, addrOp.getResTy().getType());
3273 if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
3274 builder.create<DeclareOp>(
3275 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3276 mlir::ValueRange(entryOp.getAccPtr()));
3277 else
3278 builder.create<DeclareOp>(loc, mlir::Value{},
3279 mlir::ValueRange(entryOp.getAccPtr()));
3280 if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
3281 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
3282 entryOp.getBounds(), entryOp.getDataClause(),
3283 /*structured=*/false, /*implicit=*/false,
3284 builder.getStringAttr(*entryOp.getName()));
3285 }
3286 builder.create<mlir::acc::TerminatorOp>(loc);
3287 modBuilder.setInsertionPointAfter(declareGlobalOp);
3288}
3289
3290template <typename EntryOp>
3291static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
3292 fir::FirOpBuilder &builder,
3293 mlir::Location loc, fir::GlobalOp &globalOp,
3294 mlir::acc::DataClause clause) {
3295 std::stringstream registerFuncName;
3296 registerFuncName << globalOp.getSymName().str()
3297 << Fortran::lower::declarePostAllocSuffix.str();
3298 auto registerFuncOp =
3299 createDeclareFunc(modBuilder, builder, loc, registerFuncName.str());
3300
3301 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3302 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3303
3304 std::stringstream asFortran;
3305 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3306 std::stringstream asFortranDesc;
3307 asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
3308 llvm::SmallVector<mlir::Value> bounds;
3309
3310 // Updating descriptor must occur before the mapping of the data so that
3311 // attached data pointer is not overwritten.
3312 mlir::acc::UpdateDeviceOp updateDeviceOp =
3313 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3314 builder, loc, addrOp, asFortranDesc, bounds,
3315 /*structured=*/false, /*implicit=*/true,
3316 mlir::acc::DataClause::acc_update_device, addrOp.getType());
3317 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
3318 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3319 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3320
3321 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3322 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3323 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
3324 EntryOp entryOp = createDataEntryOp<EntryOp>(
3325 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
3326 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
3327 builder.create<mlir::acc::DeclareEnterOp>(
3328 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3329 mlir::ValueRange(entryOp.getAccPtr()));
3330
3331 modBuilder.setInsertionPointAfter(registerFuncOp);
3332}
3333
/// Generate the deallocation hooks for the descriptor global \p globalOp.
/// Actions to be performed on deallocation are split into two distinct
/// functions:
/// - The pre-deallocation function (named after the global plus
///   declarePreDeallocSuffix) includes all the actions to be performed
///   before the actual deallocation is done on the host side. It loads the
///   descriptor, takes its data address, and emits an acc::GetDevicePtrOp,
///   an acc::DeclareExitOp and the \p ExitOp data exit operation.
/// - The post-deallocation function (named after the global plus
///   declarePostDeallocSuffix) updates the device copy of the descriptor
///   (acc::UpdateDeviceOp wrapped in an acc::UpdateOp).
/// On return, \p modBuilder is positioned after the post-deallocation
/// function.
template <typename ExitOp>
static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
                                     fir::FirOpBuilder &builder,
                                     mlir::Location loc,
                                     fir::GlobalOp &globalOp,
                                     mlir::acc::DataClause clause) {

  // Generate the pre dealloc function.
  std::stringstream preDeallocFuncName;
  preDeallocFuncName << globalOp.getSymName().str()
                     << Fortran::lower::declarePreDeallocSuffix.str();
  auto preDeallocOp =
      createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
  // Load the descriptor from the global and take the data address it holds.
  fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
  auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);

  std::stringstream asFortran;
  asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
  llvm::SmallVector<mlir::Value> bounds;
  mlir::acc::GetDevicePtrOp entryOp =
      createDataEntryOp<mlir::acc::GetDevicePtrOp>(
          builder, loc, boxAddrOp.getResult(), asFortran, bounds,
          /*structured=*/false, /*implicit=*/false, clause,
          boxAddrOp.getType());

  builder.create<mlir::acc::DeclareExitOp>(
      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));

  // CopyoutOp and UpdateHostOp are built with the host variable pointer
  // (getVarPtr) in addition to the device pointer; the other exit ops only
  // take the device pointer.
  if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
                           entryOp.getVarPtr(), entryOp.getBounds(),
                           entryOp.getDataClause(),
                           /*structured=*/false, /*implicit=*/false,
                           builder.getStringAttr(*entryOp.getName()));
  else
    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
                           entryOp.getBounds(), entryOp.getDataClause(),
                           /*structured=*/false, /*implicit=*/false,
                           builder.getStringAttr(*entryOp.getName()));

  // Generate the post dealloc function.
  modBuilder.setInsertionPointAfter(preDeallocOp);
  std::stringstream postDeallocFuncName;
  postDeallocFuncName << globalOp.getSymName().str()
                      << Fortran::lower::declarePostDeallocSuffix.str();
  auto postDeallocOp =
      createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str());

  // Re-take the global's address inside the new function; the descriptor
  // update is named after the variable with the descriptor postfix appended
  // to the same stream.
  addrOp = builder.create<fir::AddrOfOp>(
      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
  asFortran << accFirDescriptorPostfix.str();
  mlir::acc::UpdateDeviceOp updateDeviceOp =
      createDataEntryOp<mlir::acc::UpdateDeviceOp>(
          builder, loc, addrOp, asFortran, bounds,
          /*structured=*/false, /*implicit=*/true,
          mlir::acc::DataClause::acc_update_device, addrOp.getType());
  // NOTE(review): the {0, 0, 0, 1} segments presumably mean no if/async/wait
  // operands and a single data operand for acc.update -- confirm against the
  // op's operand list.
  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
  llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
  createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
  modBuilder.setInsertionPointAfter(postDeallocOp);
}
3403
3404template <typename EntryOp, typename ExitOp>
3405static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
3406 mlir::OpBuilder &modBuilder,
3407 const Fortran::parser::AccObjectList &accObjectList,
3408 mlir::acc::DataClause clause) {
3409 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3410 for (const auto &accObject : accObjectList.v) {
3411 mlir::Location operandLocation = genOperandLocation(converter, accObject);
3412 std::visit(
3413 Fortran::common::visitors{
3414 [&](const Fortran::parser::Designator &designator) {
3415 if (const auto *name =
3416 Fortran::semantics::getDesignatorNameIfDataRef(
3417 designator)) {
3418 std::string globalName = converter.mangleName(*name->symbol);
3419 fir::GlobalOp globalOp = builder.getNamedGlobal(globalName);
3420 std::stringstream declareGlobalCtorName;
3421 declareGlobalCtorName << globalName << "_acc_ctor";
3422 std::stringstream declareGlobalDtorName;
3423 declareGlobalDtorName << globalName << "_acc_dtor";
3424 std::stringstream asFortran;
3425 asFortran << name->symbol->name().ToString();
3426
3427 if (builder.getModule()
3428 .lookupSymbol<mlir::acc::GlobalConstructorOp>(
3429 declareGlobalCtorName.str()))
3430 return;
3431
3432 if (!globalOp) {
3433 if (Fortran::semantics::FindEquivalenceSet(*name->symbol)) {
3434 for (Fortran::semantics::EquivalenceObject eqObj :
3435 *Fortran::semantics::FindEquivalenceSet(
3436 *name->symbol)) {
3437 std::string eqName = converter.mangleName(eqObj.symbol);
3438 globalOp = builder.getNamedGlobal(eqName);
3439 if (globalOp)
3440 break;
3441 }
3442
3443 if (!globalOp)
3444 llvm::report_fatal_error(
3445 "could not retrieve global symbol");
3446 } else {
3447 llvm::report_fatal_error(
3448 "could not retrieve global symbol");
3449 }
3450 }
3451
3452 addDeclareAttr(builder, globalOp.getOperation(), clause);
3453 auto crtPos = builder.saveInsertionPoint();
3454 modBuilder.setInsertionPointAfter(globalOp);
3455 if (mlir::isa<fir::BaseBoxType>(
3456 fir::unwrapRefType(globalOp.getType()))) {
3457 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp,
3458 mlir::acc::CopyinOp,
3459 mlir::acc::DeclareEnterOp, ExitOp>(
3460 modBuilder, builder, operandLocation, globalOp, clause,
3461 declareGlobalCtorName.str(), /*implicit=*/true,
3462 asFortran);
3463 createDeclareAllocFunc<EntryOp>(
3464 modBuilder, builder, operandLocation, globalOp, clause);
3465 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
3466 createDeclareDeallocFunc<ExitOp>(
3467 modBuilder, builder, operandLocation, globalOp, clause);
3468 } else {
3469 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
3470 mlir::acc::DeclareEnterOp, ExitOp>(
3471 modBuilder, builder, operandLocation, globalOp, clause,
3472 declareGlobalCtorName.str(), /*implicit=*/false,
3473 asFortran);
3474 }
3475 if constexpr (!std::is_same_v<EntryOp, ExitOp>) {
3476 createDeclareGlobalOp<mlir::acc::GlobalDestructorOp,
3477 mlir::acc::GetDevicePtrOp,
3478 mlir::acc::DeclareExitOp, ExitOp>(
3479 modBuilder, builder, operandLocation, globalOp, clause,
3480 declareGlobalDtorName.str(), /*implicit=*/false,
3481 asFortran);
3482 }
3483 builder.restoreInsertionPoint(crtPos);
3484 }
3485 },
3486 [&](const Fortran::parser::Name &name) {
3487 TODO(operandLocation, "OpenACC Global Ctor from parser::Name");
3488 }},
3489 accObject.u);
3490 }
3491}
3492
3493template <typename Clause, typename EntryOp, typename ExitOp>
3494static void
3495genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
3496 mlir::OpBuilder &modBuilder, const Clause *x,
3497 Fortran::parser::AccDataModifier::Modifier mod,
3498 const mlir::acc::DataClause clause,
3499 const mlir::acc::DataClause clauseWithModifier) {
3500 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
3501 const auto &accObjectList =
3502 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3503 const auto &modifier =
3504 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3505 listWithModifier.t);
3506 mlir::acc::DataClause dataClause =
3507 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
3508 genGlobalCtors<EntryOp, ExitOp>(converter, modBuilder, accObjectList,
3509 dataClause);
3510}
3511
3512static void
3513genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
3514 Fortran::semantics::SemanticsContext &semanticsContext,
3515 Fortran::lower::StatementContext &openAccCtx,
3516 mlir::Location loc,
3517 const Fortran::parser::AccClauseList &accClauseList) {
3518 llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands,
3519 createEntryOperands, copyoutEntryOperands, deviceResidentEntryOperands;
3520 Fortran::lower::StatementContext stmtCtx;
3521 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3522
3523 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3524 if (const auto *copyClause =
3525 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
3526 auto crtDataStart = dataClauseOperands.size();
3527 genDeclareDataOperandOperations<mlir::acc::CopyinOp,
3528 mlir::acc::CopyoutOp>(
3529 copyClause->v, converter, semanticsContext, stmtCtx,
3530 dataClauseOperands, mlir::acc::DataClause::acc_copy,
3531 /*structured=*/true, /*implicit=*/false);
3532 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3533 dataClauseOperands.end());
3534 } else if (const auto *createClause =
3535 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3536 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3537 createClause->v;
3538 const auto &accObjectList =
3539 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3540 auto crtDataStart = dataClauseOperands.size();
3541 genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3542 accObjectList, converter, semanticsContext, stmtCtx,
3543 dataClauseOperands, mlir::acc::DataClause::acc_create,
3544 /*structured=*/true, /*implicit=*/false);
3545 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3546 dataClauseOperands.end());
3547 } else if (const auto *presentClause =
3548 std::get_if<Fortran::parser::AccClause::Present>(
3549 &clause.u)) {
3550 genDeclareDataOperandOperations<mlir::acc::PresentOp,
3551 mlir::acc::PresentOp>(
3552 presentClause->v, converter, semanticsContext, stmtCtx,
3553 dataClauseOperands, mlir::acc::DataClause::acc_present,
3554 /*structured=*/true, /*implicit=*/false);
3555 } else if (const auto *copyinClause =
3556 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3557 genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
3558 mlir::acc::DeleteOp>(
3559 copyinClause, converter, semanticsContext, stmtCtx,
3560 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3561 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
3562 mlir::acc::DataClause::acc_copyin_readonly);
3563 } else if (const auto *copyoutClause =
3564 std::get_if<Fortran::parser::AccClause::Copyout>(
3565 &clause.u)) {
3566 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3567 copyoutClause->v;
3568 const auto &accObjectList =
3569 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3570 auto crtDataStart = dataClauseOperands.size();
3571 genDeclareDataOperandOperations<mlir::acc::CreateOp,
3572 mlir::acc::CopyoutOp>(
3573 accObjectList, converter, semanticsContext, stmtCtx,
3574 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
3575 /*structured=*/true, /*implicit=*/false);
3576 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3577 dataClauseOperands.end());
3578 } else if (const auto *devicePtrClause =
3579 std::get_if<Fortran::parser::AccClause::Deviceptr>(
3580 &clause.u)) {
3581 genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
3582 mlir::acc::DevicePtrOp>(
3583 devicePtrClause->v, converter, semanticsContext, stmtCtx,
3584 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
3585 /*structured=*/true, /*implicit=*/false);
3586 } else if (const auto *linkClause =
3587 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3588 genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
3589 mlir::acc::DeclareLinkOp>(
3590 linkClause->v, converter, semanticsContext, stmtCtx,
3591 dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
3592 /*structured=*/true, /*implicit=*/false);
3593 } else if (const auto *deviceResidentClause =
3594 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3595 &clause.u)) {
3596 auto crtDataStart = dataClauseOperands.size();
3597 genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
3598 mlir::acc::DeleteOp>(
3599 deviceResidentClause->v, converter, semanticsContext, stmtCtx,
3600 dataClauseOperands,
3601 mlir::acc::DataClause::acc_declare_device_resident,
3602 /*structured=*/true, /*implicit=*/false);
3603 deviceResidentEntryOperands.append(
3604 dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
3605 } else {
3606 mlir::Location clauseLocation = converter.genLocation(clause.source);
3607 TODO(clauseLocation, "clause on declare directive");
3608 }
3609 }
3610
3611 mlir::func::FuncOp funcOp = builder.getFunction();
3612 auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
3613 mlir::Value declareToken;
3614 if (ops.empty()) {
3615 declareToken = builder.create<mlir::acc::DeclareEnterOp>(
3616 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3617 dataClauseOperands);
3618 } else {
3619 auto declareOp = *ops.begin();
3620 auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
3621 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3622 declareOp.getDataClauseOperands());
3623 newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
3624 declareToken = newDeclareOp.getToken();
3625 declareOp.erase();
3626 }
3627
3628 openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
3629 copyEntryOperands, copyoutEntryOperands,
3630 deviceResidentEntryOperands, declareToken]() {
3631 llvm::SmallVector<mlir::Value> operands;
3632 operands.append(RHS: createEntryOperands);
3633 operands.append(RHS: deviceResidentEntryOperands);
3634 operands.append(RHS: copyEntryOperands);
3635 operands.append(RHS: copyoutEntryOperands);
3636
3637 mlir::func::FuncOp funcOp = builder.getFunction();
3638 auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
3639 if (ops.empty()) {
3640 builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
3641 } else {
3642 auto declareOp = *ops.begin();
3643 declareOp.getDataClauseOperandsMutable().append(operands);
3644 }
3645
3646 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3647 builder, createEntryOperands, /*structured=*/true);
3648 genDataExitOperations<mlir::acc::DeclareDeviceResidentOp,
3649 mlir::acc::DeleteOp>(
3650 builder, deviceResidentEntryOperands, /*structured=*/true);
3651 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
3652 builder, copyEntryOperands, /*structured=*/true);
3653 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
3654 builder, copyoutEntryOperands, /*structured=*/true);
3655 });
3656}
3657
3658static void
3659genDeclareInModule(Fortran::lower::AbstractConverter &converter,
3660 mlir::ModuleOp &moduleOp,
3661 const Fortran::parser::AccClauseList &accClauseList) {
3662 mlir::OpBuilder modBuilder(moduleOp.getBodyRegion());
3663 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3664 if (const auto *createClause =
3665 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3666 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3667 createClause->v;
3668 const auto &accObjectList =
3669 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3670 genGlobalCtors<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3671 converter, modBuilder, accObjectList,
3672 mlir::acc::DataClause::acc_create);
3673 } else if (const auto *copyinClause =
3674 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3675 genGlobalCtorsWithModifier<Fortran::parser::AccClause::Copyin,
3676 mlir::acc::CopyinOp, mlir::acc::CopyinOp>(
3677 converter, modBuilder, copyinClause,
3678 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3679 mlir::acc::DataClause::acc_copyin,
3680 mlir::acc::DataClause::acc_copyin_readonly);
3681 } else if (const auto *deviceResidentClause =
3682 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3683 &clause.u)) {
3684 genGlobalCtors<mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeleteOp>(
3685 converter, modBuilder, deviceResidentClause->v,
3686 mlir::acc::DataClause::acc_declare_device_resident);
3687 } else if (const auto *linkClause =
3688 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3689 genGlobalCtors<mlir::acc::DeclareLinkOp, mlir::acc::DeclareLinkOp>(
3690 converter, modBuilder, linkClause->v,
3691 mlir::acc::DataClause::acc_declare_link);
3692 } else {
3693 llvm::report_fatal_error("unsupported clause on DECLARE directive");
3694 }
3695 }
3696}
3697
3698static void genACC(Fortran::lower::AbstractConverter &converter,
3699 Fortran::semantics::SemanticsContext &semanticsContext,
3700 Fortran::lower::StatementContext &openAccCtx,
3701 const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
3702 &declareConstruct) {
3703
3704 const auto &declarativeDir =
3705 std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
3706 mlir::Location directiveLocation =
3707 converter.genLocation(declarativeDir.source);
3708 const auto &accClauseList =
3709 std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
3710
3711 if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) {
3712 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3713 auto moduleOp =
3714 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
3715 auto funcOp =
3716 builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
3717 if (funcOp)
3718 genDeclareInFunction(converter, semanticsContext, openAccCtx,
3719 directiveLocation, accClauseList);
3720 else if (moduleOp)
3721 genDeclareInModule(converter, moduleOp, accClauseList);
3722 return;
3723 }
3724 llvm_unreachable("unsupported declarative directive");
3725}
3726
3727static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
3728 mlir::acc::DeviceType deviceType) {
3729 for (auto attr : arrayAttr) {
3730 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3731 if (deviceTypeAttr.getValue() == deviceType)
3732 return true;
3733 }
3734 return false;
3735}
3736
3737template <typename RetTy, typename AttrTy>
3738static std::optional<RetTy>
3739getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
3740 llvm::SmallVector<mlir::Attribute> &deviceTypes,
3741 mlir::acc::DeviceType deviceType) {
3742 assert(attributes.size() == deviceTypes.size() &&
3743 "expect same number of attributes");
3744 for (auto it : llvm::enumerate(First&: deviceTypes)) {
3745 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
3746 if (deviceTypeAttr.getValue() == deviceType) {
3747 if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
3748 auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
3749 return strAttr.getValue();
3750 } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
3751 auto intAttr =
3752 mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
3753 return intAttr.getInt();
3754 }
3755 }
3756 }
3757 return std::nullopt;
3758}
3759
3760static bool compareDeviceTypeInfo(
3761 mlir::acc::RoutineOp op,
3762 llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
3763 llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
3764 llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
3765 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
3766 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
3767 llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
3768 llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
3769 llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
3770 for (uint32_t dtypeInt = 0;
3771 dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
3772 auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
3773 if (op.getBindNameValue(dtype) !=
3774 getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
3775 bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
3776 return false;
3777 if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
3778 return false;
3779 if (op.getGangDimValue(dtype) !=
3780 getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
3781 gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
3782 return false;
3783 if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
3784 return false;
3785 if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
3786 return false;
3787 if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
3788 return false;
3789 }
3790 return true;
3791}
3792
3793static void attachRoutineInfo(mlir::func::FuncOp func,
3794 mlir::SymbolRefAttr routineAttr) {
3795 llvm::SmallVector<mlir::SymbolRefAttr> routines;
3796 if (func.getOperation()->hasAttr(mlir::acc::getRoutineInfoAttrName())) {
3797 auto routineInfo =
3798 func.getOperation()->getAttrOfType<mlir::acc::RoutineInfoAttr>(
3799 mlir::acc::getRoutineInfoAttrName());
3800 routines.append(routineInfo.getAccRoutines().begin(),
3801 routineInfo.getAccRoutines().end());
3802 }
3803 routines.push_back(routineAttr);
3804 func.getOperation()->setAttr(
3805 mlir::acc::getRoutineInfoAttrName(),
3806 mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
3807}
3808
3809void Fortran::lower::genOpenACCRoutineConstruct(
3810 Fortran::lower::AbstractConverter &converter,
3811 Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
3812 const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
3813 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3814 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3815 mlir::Location loc = converter.genLocation(routineConstruct.source);
3816 std::optional<Fortran::parser::Name> name =
3817 std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
3818 const auto &clauses =
3819 std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
3820 mlir::func::FuncOp funcOp;
3821 std::string funcName;
3822 if (name) {
3823 funcName = converter.mangleName(*name->symbol);
3824 funcOp =
3825 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
3826 } else {
3827 Fortran::semantics::Scope &scope =
3828 semanticsContext.FindScope(routineConstruct.source);
3829 const Fortran::semantics::Scope &progUnit{GetProgramUnitContaining(scope)};
3830 const auto *subpDetails{
3831 progUnit.symbol()
3832 ? progUnit.symbol()
3833 ->detailsIf<Fortran::semantics::SubprogramDetails>()
3834 : nullptr};
3835 if (subpDetails && subpDetails->isInterface()) {
3836 funcName = converter.mangleName(*progUnit.symbol());
3837 funcOp =
3838 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
3839 } else {
3840 funcOp = builder.getFunction();
3841 funcName = funcOp.getName();
3842 }
3843 }
3844 bool hasNohost = false;
3845
3846 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
3847 workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
3848 gangDimDeviceTypes, gangDimValues;
3849
3850 // device_type attribute is set to `none` until a device_type clause is
3851 // encountered.
3852 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
3853 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
3854 builder.getContext(), mlir::acc::DeviceType::None));
3855
3856 for (const Fortran::parser::AccClause &clause : clauses.v) {
3857 if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
3858 for (auto crtDeviceTypeAttr : crtDeviceTypes)
3859 seqDeviceTypes.push_back(crtDeviceTypeAttr);
3860 } else if (const auto *gangClause =
3861 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
3862 if (gangClause->v) {
3863 const Fortran::parser::AccGangArgList &x = *gangClause->v;
3864 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
3865 if (const auto *dim =
3866 std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
3867 const std::optional<int64_t> dimValue = Fortran::evaluate::ToInt64(
3868 *Fortran::semantics::GetExpr(dim->v));
3869 if (!dimValue)
3870 mlir::emitError(loc,
3871 "dim value must be a constant positive integer");
3872 mlir::Attribute gangDimAttr =
3873 builder.getIntegerAttr(builder.getI64Type(), *dimValue);
3874 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
3875 gangDimValues.push_back(gangDimAttr);
3876 gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
3877 }
3878 }
3879 }
3880 } else {
3881 for (auto crtDeviceTypeAttr : crtDeviceTypes)
3882 gangDeviceTypes.push_back(crtDeviceTypeAttr);
3883 }
3884 } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
3885 for (auto crtDeviceTypeAttr : crtDeviceTypes)
3886 vectorDeviceTypes.push_back(crtDeviceTypeAttr);
3887 } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
3888 for (auto crtDeviceTypeAttr : crtDeviceTypes)
3889 workerDeviceTypes.push_back(crtDeviceTypeAttr);
3890 } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
3891 hasNohost = true;
3892 } else if (const auto *bindClause =
3893 std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
3894 if (const auto *name =
3895 std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
3896 mlir::Attribute bindNameAttr =
3897 builder.getStringAttr(converter.mangleName(*name->symbol));
3898 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
3899 bindNames.push_back(bindNameAttr);
3900 bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
3901 }
3902 } else if (const auto charExpr =
3903 std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
3904 &bindClause->v.u)) {
3905 const std::optional<std::string> name =
3906 Fortran::semantics::GetConstExpr<std::string>(semanticsContext,
3907 *charExpr);
3908 if (!name)
3909 mlir::emitError(loc, "Could not retrieve the bind name");
3910
3911 mlir::Attribute bindNameAttr = builder.getStringAttr(*name);
3912 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
3913 bindNames.push_back(bindNameAttr);
3914 bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
3915 }
3916 }
3917 } else if (const auto *deviceTypeClause =
3918 std::get_if<Fortran::parser::AccClause::DeviceType>(
3919 &clause.u)) {
3920 crtDeviceTypes.clear();
3921 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
3922 }
3923 }
3924
3925 mlir::OpBuilder modBuilder(mod.getBodyRegion());
3926 std::stringstream routineOpName;
3927 routineOpName << accRoutinePrefix.str() << routineCounter++;
3928
3929 for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
3930 if (routineOp.getFuncName().str().compare(funcName) == 0) {
3931 // If the routine is already specified with the same clauses, just skip
3932 // the operation creation.
3933 if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
3934 gangDeviceTypes, gangDimValues,
3935 gangDimDeviceTypes, seqDeviceTypes,
3936 workerDeviceTypes, vectorDeviceTypes) &&
3937 routineOp.getNohost() == hasNohost)
3938 return;
3939 mlir::emitError(loc, "Routine already specified with different clauses");
3940 }
3941 }
3942
3943 modBuilder.create<mlir::acc::RoutineOp>(
3944 loc, routineOpName.str(), funcName,
3945 bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
3946 bindNameDeviceTypes.empty() ? nullptr
3947 : builder.getArrayAttr(bindNameDeviceTypes),
3948 workerDeviceTypes.empty() ? nullptr
3949 : builder.getArrayAttr(workerDeviceTypes),
3950 vectorDeviceTypes.empty() ? nullptr
3951 : builder.getArrayAttr(vectorDeviceTypes),
3952 seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
3953 hasNohost, /*implicit=*/false,
3954 gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
3955 gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
3956 gangDimDeviceTypes.empty() ? nullptr
3957 : builder.getArrayAttr(gangDimDeviceTypes));
3958
3959 if (funcOp)
3960 attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
3961 else
3962 // FuncOp is not lowered yet. Keep the information so the routine info
3963 // can be attached later to the funcOp.
3964 accRoutineInfos.push_back(std::make_pair(
3965 funcName, builder.getSymbolRefAttr(routineOpName.str())));
3966}
3967
3968void Fortran::lower::finalizeOpenACCRoutineAttachment(
3969 mlir::ModuleOp &mod,
3970 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3971 for (auto &mapping : accRoutineInfos) {
3972 mlir::func::FuncOp funcOp =
3973 mod.lookupSymbol<mlir::func::FuncOp>(mapping.first);
3974 if (!funcOp)
3975 mlir::emitWarning(mod.getLoc(),
3976 llvm::Twine("function '") + llvm::Twine(mapping.first) +
3977 llvm::Twine("' in acc routine directive is not "
3978 "found in this translation unit."));
3979 else
3980 attachRoutineInfo(funcOp, mapping.second);
3981 }
3982 accRoutineInfos.clear();
3983}
3984
3985static void
3986genACC(Fortran::lower::AbstractConverter &converter,
3987 Fortran::lower::pft::Evaluation &eval,
3988 const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
3989
3990 mlir::Location loc = converter.genLocation(atomicConstruct.source);
3991 std::visit(
3992 Fortran::common::visitors{
3993 [&](const Fortran::parser::AccAtomicRead &atomicRead) {
3994 Fortran::lower::genOmpAccAtomicRead<Fortran::parser::AccAtomicRead,
3995 void>(converter, atomicRead,
3996 loc);
3997 },
3998 [&](const Fortran::parser::AccAtomicWrite &atomicWrite) {
3999 Fortran::lower::genOmpAccAtomicWrite<
4000 Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite,
4001 loc);
4002 },
4003 [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) {
4004 Fortran::lower::genOmpAccAtomicUpdate<
4005 Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate,
4006 loc);
4007 },
4008 [&](const Fortran::parser::AccAtomicCapture &atomicCapture) {
4009 Fortran::lower::genOmpAccAtomicCapture<
4010 Fortran::parser::AccAtomicCapture, void>(converter,
4011 atomicCapture, loc);
4012 },
4013 },
4014 atomicConstruct.u);
4015}
4016
4017static void
4018genACC(Fortran::lower::AbstractConverter &converter,
4019 Fortran::semantics::SemanticsContext &semanticsContext,
4020 const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4021 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4022 auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>();
4023 auto crtPos = builder.saveInsertionPoint();
4024 if (loopOp) {
4025 builder.setInsertionPoint(loopOp);
4026 Fortran::lower::StatementContext stmtCtx;
4027 llvm::SmallVector<mlir::Value> cacheOperands;
4028 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4029 std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t);
4030 const auto &accObjectList =
4031 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4032 const auto &modifier =
4033 std::get<std::optional<Fortran::parser::AccDataModifier>>(
4034 listWithModifier.t);
4035
4036 mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache;
4037 if (modifier &&
4038 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly)
4039 dataClause = mlir::acc::DataClause::acc_cache_readonly;
4040 genDataOperandOperations<mlir::acc::CacheOp>(
4041 accObjectList, converter, semanticsContext, stmtCtx, cacheOperands,
4042 dataClause,
4043 /*structured=*/true, /*implicit=*/false, /*setDeclareAttr*/ false);
4044 loopOp.getCacheOperandsMutable().append(cacheOperands);
4045 } else {
4046 llvm::report_fatal_error(
4047 reason: "could not find loop to attach OpenACC cache information.");
4048 }
4049 builder.restoreInsertionPoint(crtPos);
4050}
4051
4052mlir::Value Fortran::lower::genOpenACCConstruct(
4053 Fortran::lower::AbstractConverter &converter,
4054 Fortran::semantics::SemanticsContext &semanticsContext,
4055 Fortran::lower::pft::Evaluation &eval,
4056 const Fortran::parser::OpenACCConstruct &accConstruct) {
4057
4058 mlir::Value exitCond;
4059 std::visit(
4060 common::visitors{
4061 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
4062 genACC(converter, semanticsContext, eval, blockConstruct);
4063 },
4064 [&](const Fortran::parser::OpenACCCombinedConstruct
4065 &combinedConstruct) {
4066 genACC(converter, semanticsContext, eval, combinedConstruct);
4067 },
4068 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
4069 exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
4070 },
4071 [&](const Fortran::parser::OpenACCStandaloneConstruct
4072 &standaloneConstruct) {
4073 genACC(converter, semanticsContext, standaloneConstruct);
4074 },
4075 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4076 genACC(converter, semanticsContext, cacheConstruct);
4077 },
4078 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
4079 genACC(converter, waitConstruct);
4080 },
4081 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
4082 genACC(converter, eval, atomicConstruct);
4083 },
4084 [&](const Fortran::parser::OpenACCEndConstruct &) {
4085 // No op
4086 },
4087 },
4088 accConstruct.u);
4089 return exitCond;
4090}
4091
4092void Fortran::lower::genOpenACCDeclarativeConstruct(
4093 Fortran::lower::AbstractConverter &converter,
4094 Fortran::semantics::SemanticsContext &semanticsContext,
4095 Fortran::lower::StatementContext &openAccCtx,
4096 const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct,
4097 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
4098
4099 std::visit(
4100 common::visitors{
4101 [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
4102 &standaloneDeclarativeConstruct) {
4103 genACC(converter, semanticsContext, openAccCtx,
4104 standaloneDeclarativeConstruct);
4105 },
4106 [&](const Fortran::parser::OpenACCRoutineConstruct
4107 &routineConstruct) {
4108 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4109 mlir::ModuleOp mod = builder.getModule();
4110 Fortran::lower::genOpenACCRoutineConstruct(
4111 converter, semanticsContext, mod, routineConstruct,
4112 accRoutineInfos);
4113 },
4114 },
4115 accDeclConstruct.u);
4116}
4117
4118void Fortran::lower::attachDeclarePostAllocAction(
4119 AbstractConverter &converter, fir::FirOpBuilder &builder,
4120 const Fortran::semantics::Symbol &sym) {
4121 std::stringstream fctName;
4122 fctName << converter.mangleName(sym) << declarePostAllocSuffix.str();
4123 mlir::Operation &op = builder.getInsertionBlock()->back();
4124
4125 if (op.hasAttr(name: mlir::acc::getDeclareActionAttrName())) {
4126 auto attr = op.getAttrOfType<mlir::acc::DeclareActionAttr>(
4127 mlir::acc::getDeclareActionAttrName());
4128 op.setAttr(mlir::acc::getDeclareActionAttrName(),
4129 mlir::acc::DeclareActionAttr::get(
4130 builder.getContext(), attr.getPreAlloc(),
4131 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4132 attr.getPreDealloc(), attr.getPostDealloc()));
4133 } else {
4134 op.setAttr(mlir::acc::getDeclareActionAttrName(),
4135 mlir::acc::DeclareActionAttr::get(
4136 builder.getContext(),
4137 /*preAlloc=*/{},
4138 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4139 /*preDealloc=*/{}, /*postDealloc=*/{}));
4140 }
4141}
4142
4143void Fortran::lower::attachDeclarePreDeallocAction(
4144 AbstractConverter &converter, fir::FirOpBuilder &builder,
4145 mlir::Value beginOpValue, const Fortran::semantics::Symbol &sym) {
4146 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4147 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4148 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4149 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4150 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4151 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4152 return;
4153
4154 std::stringstream fctName;
4155 fctName << converter.mangleName(sym) << declarePreDeallocSuffix.str();
4156
4157 auto *op = beginOpValue.getDefiningOp();
4158 if (op->hasAttr(name: mlir::acc::getDeclareActionAttrName())) {
4159 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4160 mlir::acc::getDeclareActionAttrName());
4161 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4162 mlir::acc::DeclareActionAttr::get(
4163 builder.getContext(), attr.getPreAlloc(),
4164 attr.getPostAlloc(),
4165 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4166 attr.getPostDealloc()));
4167 } else {
4168 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4169 mlir::acc::DeclareActionAttr::get(
4170 builder.getContext(),
4171 /*preAlloc=*/{}, /*postAlloc=*/{},
4172 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4173 /*postDealloc=*/{}));
4174 }
4175}
4176
4177void Fortran::lower::attachDeclarePostDeallocAction(
4178 AbstractConverter &converter, fir::FirOpBuilder &builder,
4179 const Fortran::semantics::Symbol &sym) {
4180 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4181 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4182 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4183 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4184 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4185 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4186 return;
4187
4188 std::stringstream fctName;
4189 fctName << converter.mangleName(sym) << declarePostDeallocSuffix.str();
4190 mlir::Operation *op = &builder.getInsertionBlock()->back();
4191 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) {
4192 assert(resOp.getOperands().size() == 0 &&
4193 "expect only fir.result op with no operand");
4194 op = op->getPrevNode();
4195 }
4196 assert(op && "expect operation to attach the post deallocation action");
4197 if (op->hasAttr(name: mlir::acc::getDeclareActionAttrName())) {
4198 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4199 mlir::acc::getDeclareActionAttrName());
4200 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4201 mlir::acc::DeclareActionAttr::get(
4202 builder.getContext(), attr.getPreAlloc(),
4203 attr.getPostAlloc(), attr.getPreDealloc(),
4204 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4205 } else {
4206 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4207 mlir::acc::DeclareActionAttr::get(
4208 builder.getContext(),
4209 /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
4210 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4211 }
4212}
4213
4214void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
4215 mlir::Operation *op,
4216 mlir::Location loc) {
4217 if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
4218 builder.create<mlir::acc::YieldOp>(loc);
4219 else
4220 builder.create<mlir::acc::TerminatorOp>(loc);
4221}
4222
4223bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
4224 if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4225 return true;
4226 return false;
4227}
4228
4229void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
4230 fir::FirOpBuilder &builder) {
4231 if (auto loopOp =
4232 builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4233 builder.setInsertionPointAfter(loopOp);
4234}
4235
4236void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
4237 mlir::Location loc) {
4238 mlir::Value yieldValue =
4239 builder.createIntegerConstant(loc, builder.getI1Type(), 1);
4240 builder.create<mlir::acc::YieldOp>(loc, yieldValue);
4241}
4242
4243int64_t Fortran::lower::getCollapseValue(
4244 const Fortran::parser::AccClauseList &clauseList) {
4245 for (const Fortran::parser::AccClause &clause : clauseList.v) {
4246 if (const auto *collapseClause =
4247 std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
4248 const parser::AccCollapseArg &arg = collapseClause->v;
4249 const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
4250 return *Fortran::semantics::GetIntValue(collapseValue);
4251 }
4252 }
4253 return 1;
4254}
4255

// source code of flang/lib/Lower/OpenACC.cpp