1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
13#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
18#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Support/LLVM.h"
22#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
23#include "mlir/Target/LLVMIR/ModuleTranslation.h"
24#include "mlir/Transforms/RegionUtils.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Frontend/OpenMP/OMPConstants.h"
31#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32#include "llvm/IR/Constants.h"
33#include "llvm/IR/DebugInfoMetadata.h"
34#include "llvm/IR/DerivedTypes.h"
35#include "llvm/IR/IRBuilder.h"
36#include "llvm/IR/MDBuilder.h"
37#include "llvm/IR/ReplaceConstant.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/TargetParser/Triple.h"
40#include "llvm/Transforms/Utils/ModuleUtils.h"
41
42#include <any>
43#include <cstdint>
44#include <iterator>
45#include <numeric>
46#include <optional>
47#include <utility>
48
49using namespace mlir;
50
51namespace {
52static llvm::omp::ScheduleKind
53convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
54 if (!schedKind.has_value())
55 return llvm::omp::OMP_SCHEDULE_Default;
56 switch (schedKind.value()) {
57 case omp::ClauseScheduleKind::Static:
58 return llvm::omp::OMP_SCHEDULE_Static;
59 case omp::ClauseScheduleKind::Dynamic:
60 return llvm::omp::OMP_SCHEDULE_Dynamic;
61 case omp::ClauseScheduleKind::Guided:
62 return llvm::omp::OMP_SCHEDULE_Guided;
63 case omp::ClauseScheduleKind::Auto:
64 return llvm::omp::OMP_SCHEDULE_Auto;
65 case omp::ClauseScheduleKind::Runtime:
66 return llvm::omp::OMP_SCHEDULE_Runtime;
67 }
68 llvm_unreachable("unhandled schedule clause argument");
69}
70
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)

  explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
      : allocaInsertPoint(allocaIP) {}

  /// Point at which allocas for the innermost enclosing OpenMP region should
  /// be created; consumed by \c findAllocaInsertPoint during a stack walk.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
82
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
  /// Filled in by the loop-nest translation and read back by loop wrapper
  /// translation via \c findCurrentLoopInfo; null until the body of the
  /// wrapped loop has been translated successfully.
  llvm::CanonicalLoopInfo *loopInfo = nullptr;
};
92
/// Custom error class to signal translation errors that don't need reporting,
/// since encountering them will have already triggered relevant error messages.
///
/// Its purpose is to serve as the glue between MLIR failures represented as
/// \see LogicalResult instances and \see llvm::Error instances used to
/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
/// error of the first type is raised, a message is emitted directly (the \see
/// LogicalResult itself does not hold any information). If we need to forward
/// this error condition as an \see llvm::Error while avoiding triggering some
/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
/// class to just signal this situation has happened.
///
/// For example, this class should be used to trigger errors from within
/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
/// translation of their own regions. This unclutters the error log from
/// redundant messages.
class PreviouslyReportedError
    : public llvm::ErrorInfo<PreviouslyReportedError> {
public:
  /// Intentionally empty: the diagnostic was already emitted through MLIR's
  /// error reporting when the failure first occurred.
  void log(raw_ostream &) const override {
    // Do not log anything.
  }

  /// This error never crosses an API boundary that requires a
  /// std::error_code, so conversion is a programming error.
  std::error_code convertToErrorCode() const override {
    llvm_unreachable(
        "PreviouslyReportedError doesn't support ECError conversion");
  }

  // Used by ErrorInfo::classID.
  static char ID;
};

char PreviouslyReportedError::ID = 0;
126
/*
 * Custom class for processing linear clause for omp.wsloop
 * and omp.simd. Linear clause translation requires setup,
 * initialization, update, and finalization at varying
 * basic blocks in the IR. This class helps maintain
 * internal state to allow consistent translation in
 * each of these stages.
 */

class LinearClauseProcessor {

private:
  // Allocas holding the thread-private start value of each linear variable
  // (".linear_var").
  SmallVector<llvm::Value *> linearPreconditionVars;
  // Allocas holding the per-iteration updated value of each linear variable
  // (".linear_result"); loop-body uses of the original variable are redirected
  // here by rewriteInPlace.
  SmallVector<llvm::Value *> linearLoopBodyTemps;
  // Original allocas backing the linear variables, as translated from MLIR.
  SmallVector<llvm::AllocaInst *> linearOrigVars;
  // The LLVM values the linear MLIR variables map to (same objects as
  // linearOrigVars, kept as llvm::Value* for use-replacement).
  SmallVector<llvm::Value *> linearOrigVal;
  // LLVM values of the step of each linear variable.
  SmallVector<llvm::Value *> linearSteps;
  // Blocks created by outlineLinearFinalizationBB to conditionally write the
  // final linear values back on the last logical iteration.
  llvm::BasicBlock *linearFinalizationBB;
  llvm::BasicBlock *linearExitBB;
  llvm::BasicBlock *linearLastIterExitBB;

public:
  // Allocate space for linear variabes. Only alloca-backed linear variables
  // are recorded; a non-alloca translated value is silently skipped.
  void createLinearVar(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       mlir::Value &linearVar) {
    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
            moduleTranslation.lookupValue(linearVar))) {
      linearPreconditionVars.push_back(builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
      linearOrigVars.push_back(linearVarAlloca);
    }
  }

  // Initialize linear step
  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
                             mlir::Value &linearStep) {
    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
  }

  // Emit IR for initialization of linear variables: in the loop preheader,
  // copy each original variable's current value into its ".linear_var"
  // alloca, then emit a barrier so all threads observe initialized values.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  initLinearVar(llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation,
                llvm::BasicBlock *loopPreHeader) {
    builder.SetInsertPoint(loopPreHeader->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
    }
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        moduleTranslation.getOpenMPBuilder()->createBarrier(
            builder.saveIP(), llvm::omp::OMPD_barrier);
    return afterBarrierIP;
  }

  // Emit IR for updating Linear variables: inside the loop body compute
  // "start + inductionVar * step" and store it into the ".linear_result"
  // temporary for each linear variable.
  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
                       llvm::Value *loopInductionVar) {
    builder.SetInsertPoint(loopBody->getTerminator());
    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
      // Emit increments for linear vars
      llvm::LoadInst *linearVarStart =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),

                             linearPreconditionVars[index]);
      // NOTE(review): assumes the step and induction variable types are
      // compatible with the linear variable's integer type -- confirm for
      // non-i32/i64 linear variables.
      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
    }
  }

  // Linear variable finalization is conditional on the last logical iteration.
  // Create BB splits to manage the same.
  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
                                   llvm::BasicBlock *loopExit) {
    // After the two splits, control flows
    // linearFinalizationBB -> linearLastIterExitBB -> linearExitBB; the
    // branch out of linearFinalizationBB is made conditional later in
    // finalizeLinearVar.
    linearFinalizationBB = loopExit->splitBasicBlock(
        loopExit->getTerminator(), "omp_loop.linear_finalization");
    linearExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
  }

  // Finalize the linear vars
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
  finalizeLinearVar(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::Value *lastIter) {
    // Emit condition to check whether last logical iteration is being executed
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    llvm::Value *loopLastIterLoad = builder.CreateLoad(
        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
    llvm::Value *isLast =
        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
                          llvm::ConstantInt::get(
                              llvm::Type::getInt32Ty(builder.getContext()), 0));
    // Store the linear variable values to original variables.
    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
    for (size_t index = 0; index < linearOrigVars.size(); index++) {
      llvm::LoadInst *linearVarTemp =
          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
                             linearLoopBodyTemps[index]);
      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
    }

    // Create conditional branch such that the linear variable
    // values are stored to original variables only at the
    // last logical iteration
    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
    // Drop the unconditional branch left behind by splitBasicBlock, now
    // replaced by the conditional branch above.
    linearFinalizationBB->getTerminator()->eraseFromParent();
    // Emit barrier
    builder.SetInsertPoint(linearExitBB->getTerminator());
    return moduleTranslation.getOpenMPBuilder()->createBarrier(
        builder.saveIP(), llvm::omp::OMPD_barrier);
  }

  // Rewrite all uses of the original variable in `BBName`
  // with the linear variable in-place
  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
                      size_t varIndex) {
    // Snapshot the users first: replaceUsesOfWith mutates the use list, so
    // iterating users() directly while replacing would be unsafe.
    llvm::SmallVector<llvm::User *> users;
    for (llvm::User *user : linearOrigVal[varIndex]->users())
      users.push_back(user);
    for (auto *user : users) {
      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
        // Only uses inside the block with the requested name are redirected.
        if (userInst->getParent()->getName().str() == BBName)
          user->replaceUsesOfWith(linearOrigVal[varIndex],
                                  linearLoopBodyTemps[varIndex]);
      }
    }
  }
};
266
267} // namespace
268
269/// Looks up from the operation from and returns the PrivateClauseOp with
270/// name symbolName
271static omp::PrivateClauseOp findPrivatizer(Operation *from,
272 SymbolRefAttr symbolName) {
273 omp::PrivateClauseOp privatizer =
274 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
275 symbol: symbolName);
276 assert(privatizer && "privatizer not found in the symbol table");
277 return privatizer;
278}
279
/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// \returns success if no unimplemented features are needed to translate the
/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits a uniform "not yet implemented" diagnostic for an unsupported
  // clause and returns the failed result to store into `result` below.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // Each lambda below checks one clause on a generic `op` and overwrites
  // `result` with a todo-error when the clause is present but unsupported.
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    // Cancelling a taskloop is not yet supported because we don't yet have LLVM
    // IR conversion for taskloop
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      Operation *parent = op->getParentOp();
      // Walk up to the nearest ancestor from the same (OpenMP) dialect to see
      // whether this cancellation targets a taskloop's implicit taskgroup.
      while (parent) {
        if (parent->getDialect() == op->getDialect())
          break;
        parent = parent->getParentOp();
      }
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
    }
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
    if (op.getDevice())
      result = todo("device");
  };
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  };
  // Note: hint is merely discarded with a warning, not a hard error.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      // Privatization is supported only for included target tasks.
      if (!op.getPrivateVars().empty() && op.getNowait())
        result = todo("privatization for deferred target tasks");
    } else {
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
    }
  };
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    // The reduction clause itself is only rejected on omp.teams; the modifier
    // check below applies to every op this lambda is called on.
    if (isa<omp::TeamsOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
    if (op.getUntied())
      result = todo("untied");
  };

  // Dispatch on the concrete op type and run the relevant clause checks.
  // Note that a later failing check overwrites `result`, so only the last
  // failing clause's diagnostic determines the returned LogicalResult (all
  // diagnostics are still emitted).
  LogicalResult result = success();
  llvm::TypeSwitch<Operation &>(op)
      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      })
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkDistSchedule(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        // TODO: Add other clauses check
        checkUntied(op, result);
        checkPriority(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkReduction(op, result);
      })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
465
466static LogicalResult handleError(llvm::Error error, Operation &op) {
467 LogicalResult result = success();
468 if (error) {
469 llvm::handleAllErrors(
470 E: std::move(error),
471 Handlers: [&](const PreviouslyReportedError &) { result = failure(); },
472 Handlers: [&](const llvm::ErrorInfoBase &err) {
473 result = op.emitError(message: err.message());
474 });
475 }
476 return result;
477}
478
479template <typename T>
480static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
481 if (!result)
482 return handleError(result.takeError(), op);
483
484 return success();
485}
486
/// Find the insertion point for allocas given the current insertion point for
/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // If there is an alloca insertion point on stack, i.e. we are in a nested
  // operation and a specific point was provided by some surrounding operation,
  // use it.
  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
  WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
      [&](OpenMPAllocaStackFrame &frame) {
        // Stop at the innermost frame: the first one the walk visits.
        allocaInsertPoint = frame.allocaInsertPoint;
        return WalkResult::interrupt();
      });
  // In cases with multiple levels of outlining, the tree walk might find an
  // alloca insertion point that is inside the original function while the
  // builder insertion point is inside the outlined function. We need to make
  // sure that we do not use it in those cases.
  if (walkResult.wasInterrupted() &&
      allocaInsertPoint.getBlock()->getParent() ==
          builder.GetInsertBlock()->getParent())
    return allocaInsertPoint;

  // Otherwise, insert to the entry block of the surrounding function.
  // If the current IRBuilder InsertPoint is the function's entry, it cannot
  // also be used for alloca insertion which would result in insertion order
  // confusion. Create a new BasicBlock for the Builder and use the entry block
  // for the allocs.
  // TODO: Create a dedicated alloca BasicBlock at function creation such that
  // we do not need to move the current InertPoint here.
  if (builder.GetInsertBlock() ==
      &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
    assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
           "Assuming end of basic block");
    // Split off a fresh "entry" block for normal code generation and leave
    // the original function entry block exclusively for allocas.
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
        builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    builder.CreateBr(entryBB);
    builder.SetInsertPoint(entryBB);
  }

  llvm::BasicBlock &funcEntryBlock =
      builder.GetInsertBlock()->getParent()->getEntryBlock();
  return llvm::OpenMPIRBuilder::InsertPointTy(
      &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
533
534/// Find the loop information structure for the loop nest being translated. It
535/// will return a `null` value unless called from the translation function for
536/// a loop wrapper operation after successfully translating its body.
537static llvm::CanonicalLoopInfo *
538findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
539 llvm::CanonicalLoopInfo *loopInfo = nullptr;
540 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
541 callback: [&](OpenMPLoopInfoStackFrame &frame) {
542 loopInfo = frame.loopInfo;
543 return WalkResult::interrupt();
544 });
545 return loopInfo;
546}
547
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with an successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  // Loop wrapper regions have no terminators; they get special handling below.
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Create one LLVM block per MLIR block up front so branches between them
  // can be emitted before the blocks' contents are translated.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // First yield seen: record the operand types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Subsequent yields must agree in count and types with the first.
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  SetVector<Block *> blocks = getBlocksSortedByDominance(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      // The block conversion already emitted a diagnostic; avoid duplicates.
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      // If the terminator has operands, the PHI-types loop above guarantees
      // continuationBlockPHIs was provided and populated.
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
675
676/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
677static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
678 switch (kind) {
679 case omp::ClauseProcBindKind::Close:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
681 case omp::ClauseProcBindKind::Master:
682 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
683 case omp::ClauseProcBindKind::Primary:
684 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
685 case omp::ClauseProcBindKind::Spread:
686 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
687 }
688 llvm_unreachable("Unknown ClauseProcBindKind kind");
689}
690
691/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
692/// corresponding LLVM values of \p the interface's operands. This is useful
693/// when an OpenMP region with entry block arguments is converted to LLVM. In
694/// this case the block arguments are (part of) of the OpenMP region's entry
695/// arguments and the operands are (part of) of the operands to the OpenMP op
696/// containing the region.
697static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
698 omp::BlockArgOpenMPOpInterface blockArgIface) {
699 llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
700 blockArgIface.getBlockArgsPairs(pairs&: blockArgsPairs);
701 for (auto [var, arg] : blockArgsPairs)
702 moduleTranslation.mapValue(mlir: arg, llvm: moduleTranslation.lookupValue(value: var));
703}
704
705/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
706static LogicalResult
707convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
708 LLVM::ModuleTranslation &moduleTranslation) {
709 auto maskedOp = cast<omp::MaskedOp>(Val&: opInst);
710 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
711
712 if (failed(Result: checkImplementationStatus(op&: opInst)))
713 return failure();
714
715 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
716 // MaskedOp has only one region associated with it.
717 auto &region = maskedOp.getRegion();
718 builder.restoreIP(IP: codeGenIP);
719 return convertOmpOpRegions(region, blockName: "omp.masked.region", builder,
720 moduleTranslation)
721 .takeError();
722 };
723
724 // TODO: Perform finalization actions for variables. This has to be
725 // called for variables which have destructors/finalizers.
726 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
727
728 llvm::Value *filterVal = nullptr;
729 if (auto filterVar = maskedOp.getFilteredThreadId()) {
730 filterVal = moduleTranslation.lookupValue(value: filterVar);
731 } else {
732 llvm::LLVMContext &llvmContext = builder.getContext();
733 filterVal =
734 llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext), /*V=*/0);
735 }
736 assert(filterVal != nullptr);
737 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
738 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
739 moduleTranslation.getOpenMPBuilder()->createMasked(Loc: ompLoc, BodyGenCB: bodyGenCB,
740 FiniCB: finiCB, Filter: filterVal);
741
742 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
743 return failure();
744
745 builder.restoreIP(IP: *afterIP);
746 return success();
747}
748
749/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
750static LogicalResult
751convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
752 LLVM::ModuleTranslation &moduleTranslation) {
753 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
754 auto masterOp = cast<omp::MasterOp>(Val&: opInst);
755
756 if (failed(Result: checkImplementationStatus(op&: opInst)))
757 return failure();
758
759 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
760 // MasterOp has only one region associated with it.
761 auto &region = masterOp.getRegion();
762 builder.restoreIP(IP: codeGenIP);
763 return convertOmpOpRegions(region, blockName: "omp.master.region", builder,
764 moduleTranslation)
765 .takeError();
766 };
767
768 // TODO: Perform finalization actions for variables. This has to be
769 // called for variables which have destructors/finalizers.
770 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
771
772 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
773 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
774 moduleTranslation.getOpenMPBuilder()->createMaster(Loc: ompLoc, BodyGenCB: bodyGenCB,
775 FiniCB: finiCB);
776
777 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
778 return failure();
779
780 builder.restoreIP(IP: *afterIP);
781 return success();
782}
783
784/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
785static LogicalResult
786convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
787 LLVM::ModuleTranslation &moduleTranslation) {
788 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
789 auto criticalOp = cast<omp::CriticalOp>(Val&: opInst);
790
791 if (failed(Result: checkImplementationStatus(op&: opInst)))
792 return failure();
793
794 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
795 // CriticalOp has only one region associated with it.
796 auto &region = cast<omp::CriticalOp>(Val&: opInst).getRegion();
797 builder.restoreIP(IP: codeGenIP);
798 return convertOmpOpRegions(region, blockName: "omp.critical.region", builder,
799 moduleTranslation)
800 .takeError();
801 };
802
803 // TODO: Perform finalization actions for variables. This has to be
804 // called for variables which have destructors/finalizers.
805 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
806
807 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
808 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
809 llvm::Constant *hint = nullptr;
810
811 // If it has a name, it probably has a hint too.
812 if (criticalOp.getNameAttr()) {
813 // The verifiers in OpenMP Dialect guarentee that all the pointers are
814 // non-null
815 auto symbolRef = cast<SymbolRefAttr>(Val: criticalOp.getNameAttr());
816 auto criticalDeclareOp =
817 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(from: criticalOp,
818 symbol: symbolRef);
819 hint =
820 llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext),
821 V: static_cast<int>(criticalDeclareOp.getHint()));
822 }
823 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
824 moduleTranslation.getOpenMPBuilder()->createCritical(
825 Loc: ompLoc, BodyGenCB: bodyGenCB, FiniCB: finiCB, CriticalName: criticalOp.getName().value_or(u: ""), HintInst: hint);
826
827 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
828 return failure();
829
830 builder.restoreIP(IP: *afterIP);
831 return success();
832}
833
834/// A util to collect info needed to convert delayed privatizers from MLIR to
835/// LLVM.
836struct PrivateVarsInfo {
837 template <typename OP>
838 PrivateVarsInfo(OP op)
839 : blockArgs(
840 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
841 mlirVars.reserve(N: blockArgs.size());
842 llvmVars.reserve(N: blockArgs.size());
843 collectPrivatizationDecls<OP>(op);
844
845 for (mlir::Value privateVar : op.getPrivateVars())
846 mlirVars.push_back(Elt: privateVar);
847 }
848
849 MutableArrayRef<BlockArgument> blockArgs;
850 SmallVector<mlir::Value> mlirVars;
851 SmallVector<llvm::Value *> llvmVars;
852 SmallVector<omp::PrivateClauseOp> privatizers;
853
854private:
855 /// Populates `privatizations` with privatization declarations used for the
856 /// given op.
857 template <class OP>
858 void collectPrivatizationDecls(OP op) {
859 std::optional<ArrayAttr> attr = op.getPrivateSyms();
860 if (!attr)
861 return;
862
863 privatizers.reserve(N: privatizers.size() + attr->size());
864 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
865 privatizers.push_back(Elt: findPrivatizer(op, symbolRef));
866 }
867 }
868};
869
870/// Populates `reductions` with reduction declarations used in the given op.
871template <typename T>
872static void
873collectReductionDecls(T op,
874 SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
875 std::optional<ArrayAttr> attr = op.getReductionSyms();
876 if (!attr)
877 return;
878
879 reductions.reserve(N: reductions.size() + op.getNumReductionVars());
880 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
881 reductions.push_back(
882 Elt: SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
883 op, symbolRef));
884 }
885}
886
/// Translates the blocks contained in the given region and appends them to at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
  // An empty region requires no translation.
  if (region.empty())
    return success();

  // Special case for single-block regions that don't create additional blocks:
  // insert operations without creating additional blocks.
  if (llvm::hasSingleElement(C&: region)) {
    // If the current insertion block already ends in a terminator, detach it
    // so the region's operations can be appended; it is re-attached after the
    // converted code below.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();

    if (potentialTerminator && potentialTerminator->isTerminator())
      potentialTerminator->removeFromParent();
    moduleTranslation.mapBlock(mlir: &region.front(), llvm: builder.GetInsertBlock());

    if (failed(Result: moduleTranslation.convertBlock(
            bb&: region.front(), /*ignoreArguments=*/true, builder)))
      return failure();

    // The continuation arguments are simply the translated terminator operands.
    if (continuationBlockArgs)
      llvm::append_range(
          C&: *continuationBlockArgs,
          R: moduleTranslation.lookupValues(values: region.front().back().getOperands()));

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region);

    // Re-attach the terminator that was detached above, placing it after any
    // newly inserted instructions.
    if (potentialTerminator && potentialTerminator->isTerminator()) {
      llvm::BasicBlock *block = builder.GetInsertBlock();
      if (block->empty()) {
        // this can happen for really simple reduction init regions e.g.
        // %0 = llvm.mlir.constant(0 : i32) : i32
        // omp.yield(%0 : i32)
        // because the llvm.mlir.constant (MLIR op) isn't converted into any
        // llvm op
        potentialTerminator->insertInto(ParentBB: block, It: block->begin());
      } else {
        potentialTerminator->insertAfter(InsertPos: &block->back());
      }
    }

    return success();
  }

  // General (multi-block) case: convert the region into new blocks and
  // collect the yielded values as PHI nodes in the continuation block.
  SmallVector<llvm::PHINode *> phis;
  llvm::Expected<llvm::BasicBlock *> continuationBlock =
      convertOmpOpRegions(region, blockName, builder, moduleTranslation, continuationBlockPHIs: &phis);

  if (failed(Result: handleError(result&: continuationBlock, op&: *region.getParentOp())))
    return failure();

  if (continuationBlockArgs)
    llvm::append_range(C&: *continuationBlockArgs, R&: phis);
  // Leave the builder positioned at the start of the continuation block so
  // callers can continue code-gen directly afterwards.
  builder.SetInsertPoint(TheBB: *continuationBlock,
                         IP: (*continuationBlock)->getFirstInsertionPt());
  return success();
}
954
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture. Using std::function here lets the generators
/// own their captured state so they can be stored and invoked later.
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
} // namespace
967
968/// Create an OpenMPIRBuilder-compatible reduction generator for the given
969/// reduction declaration. The generator uses `builder` but ignores its
970/// insertion point.
971static OwningReductionGen
972makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
973 LLVM::ModuleTranslation &moduleTranslation) {
974 // The lambda is mutable because we need access to non-const methods of decl
975 // (which aren't actually mutating it), and we must capture decl by-value to
976 // avoid the dangling reference after the parent function returns.
977 OwningReductionGen gen =
978 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
979 llvm::Value *lhs, llvm::Value *rhs,
980 llvm::Value *&result) mutable
981 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
982 moduleTranslation.mapValue(mlir: decl.getReductionLhsArg(), llvm: lhs);
983 moduleTranslation.mapValue(mlir: decl.getReductionRhsArg(), llvm: rhs);
984 builder.restoreIP(IP: insertPoint);
985 SmallVector<llvm::Value *> phis;
986 if (failed(Result: inlineConvertOmpRegions(region&: decl.getReductionRegion(),
987 blockName: "omp.reduction.nonatomic.body", builder,
988 moduleTranslation, continuationBlockArgs: &phis)))
989 return llvm::createStringError(
990 Fmt: "failed to inline `combiner` region of `omp.declare_reduction`");
991 result = llvm::getSingleElement(C&: phis);
992 return builder.saveIP();
993 };
994 return gen;
995}
996
997/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
998/// given reduction declaration. The generator uses `builder` but ignores its
999/// insertion point. Returns null if there is no atomic region available in the
1000/// reduction declaration.
1001static OwningAtomicReductionGen
1002makeAtomicReductionGen(omp::DeclareReductionOp decl,
1003 llvm::IRBuilderBase &builder,
1004 LLVM::ModuleTranslation &moduleTranslation) {
1005 if (decl.getAtomicReductionRegion().empty())
1006 return OwningAtomicReductionGen();
1007
1008 // The lambda is mutable because we need access to non-const methods of decl
1009 // (which aren't actually mutating it), and we must capture decl by-value to
1010 // avoid the dangling reference after the parent function returns.
1011 OwningAtomicReductionGen atomicGen =
1012 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1013 llvm::Value *lhs, llvm::Value *rhs) mutable
1014 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1015 moduleTranslation.mapValue(mlir: decl.getAtomicReductionLhsArg(), llvm: lhs);
1016 moduleTranslation.mapValue(mlir: decl.getAtomicReductionRhsArg(), llvm: rhs);
1017 builder.restoreIP(IP: insertPoint);
1018 SmallVector<llvm::Value *> phis;
1019 if (failed(Result: inlineConvertOmpRegions(region&: decl.getAtomicReductionRegion(),
1020 blockName: "omp.reduction.atomic.body", builder,
1021 moduleTranslation, continuationBlockArgs: &phis)))
1022 return llvm::createStringError(
1023 Fmt: "failed to inline `atomic` region of `omp.declare_reduction`");
1024 assert(phis.empty());
1025 return builder.saveIP();
1026 };
1027 return atomicGen;
1028}
1029
1030/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1031static LogicalResult
1032convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1033 LLVM::ModuleTranslation &moduleTranslation) {
1034 auto orderedOp = cast<omp::OrderedOp>(Val&: opInst);
1035
1036 if (failed(Result: checkImplementationStatus(op&: opInst)))
1037 return failure();
1038
1039 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1040 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1041 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1042 SmallVector<llvm::Value *> vecValues =
1043 moduleTranslation.lookupValues(values: orderedOp.getDoacrossDependVars());
1044
1045 size_t indexVecValues = 0;
1046 while (indexVecValues < vecValues.size()) {
1047 SmallVector<llvm::Value *> storeValues;
1048 storeValues.reserve(N: numLoops);
1049 for (unsigned i = 0; i < numLoops; i++) {
1050 storeValues.push_back(Elt: vecValues[indexVecValues]);
1051 indexVecValues++;
1052 }
1053 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1054 findAllocaInsertPoint(builder, moduleTranslation);
1055 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1056 builder.restoreIP(IP: moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1057 Loc: ompLoc, AllocaIP: allocaIP, NumLoops: numLoops, StoreValues: storeValues, Name: ".cnt.addr", IsDependSource: isDependSource));
1058 }
1059 return success();
1060}
1061
1062/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1063/// OpenMPIRBuilder.
1064static LogicalResult
1065convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1066 LLVM::ModuleTranslation &moduleTranslation) {
1067 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1068 auto orderedRegionOp = cast<omp::OrderedRegionOp>(Val&: opInst);
1069
1070 if (failed(Result: checkImplementationStatus(op&: opInst)))
1071 return failure();
1072
1073 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1074 // OrderedOp has only one region associated with it.
1075 auto &region = cast<omp::OrderedRegionOp>(Val&: opInst).getRegion();
1076 builder.restoreIP(IP: codeGenIP);
1077 return convertOmpOpRegions(region, blockName: "omp.ordered.region", builder,
1078 moduleTranslation)
1079 .takeError();
1080 };
1081
1082 // TODO: Perform finalization actions for variables. This has to be
1083 // called for variables which have destructors/finalizers.
1084 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1085
1086 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1087 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1088 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1089 Loc: ompLoc, BodyGenCB: bodyGenCB, FiniCB: finiCB, IsThreads: !orderedRegionOp.getParLevelSimd());
1090
1091 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
1092 return failure();
1093
1094 builder.restoreIP(IP: *afterIP);
1095 return success();
1096}
1097
namespace {
/// Contains the arguments for an LLVM store operation
/// (used to delay emitting stores until after all allocas are created).
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *value;   // value to be stored
  llvm::Value *address; // pointer the value will be stored to
};
} // namespace
1108
/// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas
///
/// For by-ref reductions with an `alloc` region, the region is inlined to
/// produce the real storage; a pointer-sized alloca is created to hold its
/// address and the store of that address is deferred. For by-val reductions a
/// plain alloca of the reduction type is created directly.
template <typename T>
static LogicalResult
allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // Restore the caller's insertion point when this function returns.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // delay creating stores until after all allocas
  deferredStores.reserve(N: loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      if (allocRegion.empty())
        continue;

      // Inline the `alloc` region; it yields the address of the real
      // reduction storage.
      SmallVector<llvm::Value *, 1> phis;
      if (failed(Result: inlineConvertOmpRegions(region&: allocRegion, blockName: "omp.reduction.alloc",
                                          builder, moduleTranslation, continuationBlockArgs: &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      // Inlining may have moved the insertion point; return to the alloca
      // block before creating the alloca below.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Value *var = builder.CreateAlloca(
          Ty: moduleTranslation.convertType(type: reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(V: var, DestTy: ptrTy);
      llvm::Value *castPhi =
          builder.CreatePointerBitCastOrAddrSpaceCast(V: phis[0], DestTy: ptrTy);

      // The store of the yielded address is deferred until after all allocas.
      deferredStores.emplace_back(Args&: castPhi, Args&: castVar);

      privateReductionVariables[i] = castVar;
      moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: castPhi);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");
      llvm::Value *var = builder.CreateAlloca(
          Ty: moduleTranslation.convertType(type: reductionDecls[i].getType()));

      llvm::Type *ptrTy = builder.getPtrTy();
      llvm::Value *castVar =
          builder.CreatePointerBitCastOrAddrSpaceCast(V: var, DestTy: ptrTy);

      moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: castVar);
      privateReductionVariables[i] = castVar;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
    }
  }

  return success();
}
1178
1179/// Map input arguments to reduction initialization region
1180template <typename T>
1181static void
1182mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
1183 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1184 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1185 unsigned i) {
1186 // map input argument to the initialization region
1187 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1188 Region &initializerRegion = reduction.getInitializerRegion();
1189 Block &entry = initializerRegion.front();
1190
1191 mlir::Value mlirSource = loop.getReductionVars()[i];
1192 llvm::Value *llvmSource = moduleTranslation.lookupValue(value: mlirSource);
1193 assert(llvmSource && "lookup reduction var");
1194 moduleTranslation.mapValue(mlir: reduction.getInitializerMoldArg(), llvm: llvmSource);
1195
1196 if (entry.getNumArguments() > 1) {
1197 llvm::Value *allocation =
1198 reductionVariableMap.lookup(Val: loop.getReductionVars()[i]);
1199 moduleTranslation.mapValue(mlir: reduction.getInitializerAllocArg(), llvm: allocation);
1200 }
1201}
1202
1203static void
1204setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1205 llvm::BasicBlock *block = nullptr) {
1206 if (block == nullptr)
1207 block = builder.GetInsertBlock();
1208
1209 if (block->empty() || block->getTerminator() == nullptr)
1210 builder.SetInsertPoint(block);
1211 else
1212 builder.SetInsertPoint(block->getTerminator());
1213}
1214
/// Inline reductions' `init` regions. This functions assumes that the
/// `builder`'s insertion point is where the user wants the `init` regions to be
/// inlined; i.e. it does not try to find a proper insertion location for the
/// `init` regions. It also leaves the `builder's insertions point in a state
/// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  if (op.getNumReductionVars() == 0)
    return success();

  // Split off a dedicated block for the initialization code, then move the
  // builder back to the alloca block to emit any remaining allocas first.
  llvm::BasicBlock *initBlock = splitBB(Builder&: builder, CreateBranch: true, Name: "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(IP: allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      // By-ref reductions with an alloc region were already allocated in
      // allocReductionVars.
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduciton variable allocated in the inlined region)
      byRefVars[i] = builder.CreateAlloca(
          Ty: moduleTranslation.convertType(type: reductionDecls[i].getType()));
    }
  }

  setInsertPointForPossiblyEmptyBlock(builder, block: initBlock);

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(Val: data, Ptr: addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    SmallVector<llvm::Value *, 1> phis;

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, reductionDecls,
                          reductionVariableMap, i);

    // TODO In some cases (specially on the GPU), the init regions may
    // contains stack alloctaions. If the region is inlined in a loop, this is
    // problematic. Instead of just inlining the region, handle allocations by
    // hoisting fixed length allocations to the function entry and using
    // stacksave and restore for variable length ones.
    if (failed(Result: inlineConvertOmpRegions(region&: reductionDecls[i].getInitializerRegion(),
                                        blockName: "omp.reduction.neutral", builder,
                                        moduleTranslation, continuationBlockArgs: &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    // Inlining may have left the builder at the start of a fresh, possibly
    // empty continuation block.
    setInsertPointForPossiblyEmptyBlock(builder);

    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(Val: phis[0], Ptr: byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(Val: phis[0], Ptr: privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(region&: reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1314
1315/// Collect reduction info
1316template <typename T>
1317static void collectReductionInfo(
1318 T loop, llvm::IRBuilderBase &builder,
1319 LLVM::ModuleTranslation &moduleTranslation,
1320 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1321 SmallVectorImpl<OwningReductionGen> &owningReductionGens,
1322 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1323 const ArrayRef<llvm::Value *> privateReductionVariables,
1324 SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
1325 unsigned numReductions = loop.getNumReductionVars();
1326
1327 for (unsigned i = 0; i < numReductions; ++i) {
1328 owningReductionGens.push_back(
1329 Elt: makeReductionGen(decl: reductionDecls[i], builder, moduleTranslation));
1330 owningAtomicReductionGens.push_back(
1331 Elt: makeAtomicReductionGen(decl: reductionDecls[i], builder, moduleTranslation));
1332 }
1333
1334 // Collect the reduction information.
1335 reductionInfos.reserve(N: numReductions);
1336 for (unsigned i = 0; i < numReductions; ++i) {
1337 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1338 if (owningAtomicReductionGens[i])
1339 atomicGen = owningAtomicReductionGens[i];
1340 llvm::Value *variable =
1341 moduleTranslation.lookupValue(value: loop.getReductionVars()[i]);
1342 reductionInfos.push_back(
1343 Elt: {moduleTranslation.convertType(type: reductionDecls[i].getType()), variable,
1344 privateReductionVariables[i],
1345 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1346 owningReductionGens[i],
1347 /*ReductionGenClang=*/nullptr, atomicGen});
1348 }
1349}
1350
/// handling of DeclareReductionOp's cleanup region
///
/// For every non-empty region in `cleanupRegions`, inlines it at the current
/// insertion point, mapping its single block argument to the corresponding
/// entry of `privateVariables` (loaded first when
/// `shouldLoadCleanupRegionArg` is set).
static LogicalResult
inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
                       llvm::ArrayRef<llvm::Value *> privateVariables,
                       LLVM::ModuleTranslation &moduleTranslation,
                       llvm::IRBuilderBase &builder, StringRef regionName,
                       bool shouldLoadCleanupRegionArg = true) {
  for (auto [i, cleanupRegion] : llvm::enumerate(First&: cleanupRegions)) {
    if (cleanupRegion->empty())
      continue;

    // map the argument to the cleanup region
    Block &entry = cleanupRegion->front();

    // If the insertion block already ends in a terminator, emit the cleanup
    // code before it rather than after it.
    llvm::Instruction *potentialTerminator =
        builder.GetInsertBlock()->empty() ? nullptr
                                          : &builder.GetInsertBlock()->back();
    if (potentialTerminator && potentialTerminator->isTerminator())
      builder.SetInsertPoint(potentialTerminator);
    llvm::Value *privateVarValue =
        shouldLoadCleanupRegionArg
            ? builder.CreateLoad(
                  Ty: moduleTranslation.convertType(type: entry.getArgument(i: 0).getType()),
                  Ptr: privateVariables[i])
            : privateVariables[i];

    moduleTranslation.mapValue(mlir: entry.getArgument(i: 0), llvm: privateVarValue);

    if (failed(Result: inlineConvertOmpRegions(region&: *cleanupRegion, blockName: regionName, builder,
                                        moduleTranslation)))
      return failure();

    // clear block argument mapping in case it needs to be re-created with a
    // different source for another use of the same reduction decl
    moduleTranslation.forgetMapping(region&: *cleanupRegion);
  }
  return success();
}
1389
1390// TODO: not used by ParallelOp
1391template <class OP>
1392static LogicalResult createReductionsAndCleanup(
1393 OP op, llvm::IRBuilderBase &builder,
1394 LLVM::ModuleTranslation &moduleTranslation,
1395 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1396 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1397 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1398 bool isNowait = false, bool isTeamsReduction = false) {
1399 // Process the reductions if required.
1400 if (op.getNumReductionVars() == 0)
1401 return success();
1402
1403 SmallVector<OwningReductionGen> owningReductionGens;
1404 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1405 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1406
1407 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1408
1409 // Create the reduction generators. We need to own them here because
1410 // ReductionInfo only accepts references to the generators.
1411 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1412 owningReductionGens, owningAtomicReductionGens,
1413 privateReductionVariables, reductionInfos);
1414
1415 // The call to createReductions below expects the block to have a
1416 // terminator. Create an unreachable instruction to serve as terminator
1417 // and remove it later.
1418 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1419 builder.SetInsertPoint(tempTerminator);
1420 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1421 ompBuilder->createReductions(Loc: builder.saveIP(), AllocaIP: allocaIP, ReductionInfos: reductionInfos,
1422 IsByRef: isByRef, IsNoWait: isNowait, IsTeamsReduction: isTeamsReduction);
1423
1424 if (failed(handleError(contInsertPoint, *op)))
1425 return failure();
1426
1427 if (!contInsertPoint->getBlock())
1428 return op->emitOpError() << "failed to convert reductions";
1429
1430 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1431 ompBuilder->createBarrier(Loc: *contInsertPoint, Kind: llvm::omp::OMPD_for);
1432
1433 if (failed(handleError(afterIP, *op)))
1434 return failure();
1435
1436 tempTerminator->eraseFromParent();
1437 builder.restoreIP(IP: *afterIP);
1438
1439 // after the construct, deallocate private reduction variables
1440 SmallVector<Region *> reductionRegions;
1441 llvm::transform(reductionDecls, std::back_inserter(x&: reductionRegions),
1442 [](omp::DeclareReductionOp reductionDecl) {
1443 return &reductionDecl.getCleanupRegion();
1444 });
1445 return inlineOmpRegionCleanup(cleanupRegions&: reductionRegions, privateVariables: privateReductionVariables,
1446 moduleTranslation, builder,
1447 regionName: "omp.reduction.cleanup");
1448 return success();
1449}
1450
1451static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1452 if (!attr)
1453 return {};
1454 return *attr;
1455}
1456
1457// TODO: not used by omp.parallel
1458template <typename OP>
1459static LogicalResult allocAndInitializeReductionVars(
1460 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1461 LLVM::ModuleTranslation &moduleTranslation,
1462 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1463 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1464 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1465 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1466 llvm::ArrayRef<bool> isByRef) {
1467 if (op.getNumReductionVars() == 0)
1468 return success();
1469
1470 SmallVector<DeferredStore> deferredStores;
1471
1472 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1473 allocaIP, reductionDecls,
1474 privateReductionVariables, reductionVariableMap,
1475 deferredStores, isByRef)))
1476 return failure();
1477
1478 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1479 allocaIP.getBlock(), reductionDecls,
1480 privateReductionVariables, reductionVariableMap,
1481 isByRef, deferredStores);
1482}
1483
/// Return the llvm::Value * corresponding to the `privateVar` that
/// is being privatized. It isn't always as simple as looking up
/// moduleTranslation with privateVar. For instance, in case of
/// an allocatable, the descriptor for the allocatable is privatized.
/// This descriptor is mapped using an MapInfoOp. So, this function
/// will return a pointer to the llvm::Value corresponding to the
/// block argument for the mapped descriptor.
static llvm::Value *
findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Fast path: the variable was not remapped; use the direct translation.
  if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(Val: privateVar))
    return moduleTranslation.lookupValue(value: privateVar);

  Value blockArg = (*mappedPrivateVars)[privateVar];
  Type privVarType = privateVar.getType();
  Type blockArgType = blockArg.getType();
  assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
         "A block argument corresponding to a mapped var should have "
         "!llvm.ptr type");

  if (privVarType == blockArgType)
    return moduleTranslation.lookupValue(value: blockArg);

  // This typically happens when the privatized type is lowered from
  // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
  // struct/pair is passed by value. But, mapped values are passed only as
  // pointers, so before we privatize, we must load the pointer.
  if (!isa<LLVM::LLVMPointerType>(Val: privVarType))
    return builder.CreateLoad(Ty: moduleTranslation.convertType(type: privVarType),
                              Ptr: moduleTranslation.lookupValue(value: blockArg));

  // NOTE(review): at this point privVar has a pointer type different from
  // blockArg's (e.g. a different address space) and we fall back to looking
  // up privateVar itself rather than blockArg — confirm this is intended.
  return moduleTranslation.lookupValue(value: privateVar);
}
1518
1519/// Initialize a single (first)private variable. You probably want to use
1520/// allocateAndInitPrivateVars instead of this.
1521/// This returns the private variable which has been initialized. This
1522/// variable should be mapped before constructing the body of the Op.
1523static llvm::Expected<llvm::Value *> initPrivateVar(
1524 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1525 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1526 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1527 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1528 Region &initRegion = privDecl.getInitRegion();
1529 if (initRegion.empty())
1530 return llvmPrivateVar;
1531
1532 // map initialization region block arguments
1533 llvm::Value *nonPrivateVar = findAssociatedValue(
1534 privateVar: mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1535 assert(nonPrivateVar);
1536 moduleTranslation.mapValue(mlir: privDecl.getInitMoldArg(), llvm: nonPrivateVar);
1537 moduleTranslation.mapValue(mlir: privDecl.getInitPrivateArg(), llvm: llvmPrivateVar);
1538
1539 // in-place convert the private initialization region
1540 SmallVector<llvm::Value *, 1> phis;
1541 if (failed(Result: inlineConvertOmpRegions(region&: initRegion, blockName: "omp.private.init", builder,
1542 moduleTranslation, continuationBlockArgs: &phis)))
1543 return llvm::createStringError(
1544 Fmt: "failed to inline `init` region of `omp.private`");
1545
1546 assert(phis.size() == 1 && "expected one allocation to be yielded");
1547
1548 // clear init region block argument mapping in case it needs to be
1549 // re-created with a different source for another use of the same
1550 // reduction decl
1551 moduleTranslation.forgetMapping(region&: initRegion);
1552
1553 // Prefer the value yielded from the init region to the allocated private
1554 // variable in case the region is operating on arguments by-value (e.g.
1555 // Fortran character boxes).
1556 return phis[0];
1557}
1558
1559static llvm::Error
1560initPrivateVars(llvm::IRBuilderBase &builder,
1561 LLVM::ModuleTranslation &moduleTranslation,
1562 PrivateVarsInfo &privateVarsInfo,
1563 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1564 if (privateVarsInfo.blockArgs.empty())
1565 return llvm::Error::success();
1566
1567 llvm::BasicBlock *privInitBlock = splitBB(Builder&: builder, CreateBranch: true, Name: "omp.private.init");
1568 setInsertPointForPossiblyEmptyBlock(builder, block: privInitBlock);
1569
1570 for (auto [idx, zip] : llvm::enumerate(First: llvm::zip_equal(
1571 t&: privateVarsInfo.privatizers, u&: privateVarsInfo.mlirVars,
1572 args&: privateVarsInfo.blockArgs, args&: privateVarsInfo.llvmVars))) {
1573 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1574 llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
1575 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1576 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1577
1578 if (!privVarOrErr)
1579 return privVarOrErr.takeError();
1580
1581 llvmPrivateVar = privVarOrErr.get();
1582 moduleTranslation.mapValue(mlir: blockArg, llvm: llvmPrivateVar);
1583
1584 setInsertPointForPossiblyEmptyBlock(builder);
1585 }
1586
1587 return llvm::Error::success();
1588}
1589
1590/// Allocate and initialize delayed private variables. Returns the basic block
1591/// which comes after all of these allocations. llvm::Value * for each of these
1592/// private variables are populated in llvmPrivateVars.
1593static llvm::Expected<llvm::BasicBlock *>
1594allocatePrivateVars(llvm::IRBuilderBase &builder,
1595 LLVM::ModuleTranslation &moduleTranslation,
1596 PrivateVarsInfo &privateVarsInfo,
1597 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1598 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1599 // Allocate private vars
1600 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1601 splitBB(IP: llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1602 allocaTerminator->getIterator()),
1603 CreateBranch: true, DL: allocaTerminator->getStableDebugLoc(),
1604 Name: "omp.region.after_alloca");
1605
1606 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1607 // Update the allocaTerminator since the alloca block was split above.
1608 allocaTerminator = allocaIP.getBlock()->getTerminator();
1609 builder.SetInsertPoint(allocaTerminator);
1610 // The new terminator is an uncondition branch created by the splitBB above.
1611 assert(allocaTerminator->getNumSuccessors() == 1 &&
1612 "This is an unconditional branch created by splitBB");
1613
1614 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1615 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(Idx: 0);
1616
1617 unsigned int allocaAS =
1618 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1619 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1620 ->getDataLayout()
1621 .getProgramAddressSpace();
1622
1623 for (auto [privDecl, mlirPrivVar, blockArg] :
1624 llvm::zip_equal(t&: privateVarsInfo.privatizers, u&: privateVarsInfo.mlirVars,
1625 args&: privateVarsInfo.blockArgs)) {
1626 llvm::Type *llvmAllocType =
1627 moduleTranslation.convertType(type: privDecl.getType());
1628 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1629 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1630 Ty: llvmAllocType, /*ArraySize=*/nullptr, Name: "omp.private.alloc");
1631 if (allocaAS != defaultAS)
1632 llvmPrivateVar = builder.CreateAddrSpaceCast(V: llvmPrivateVar,
1633 DestTy: builder.getPtrTy(AddrSpace: defaultAS));
1634
1635 privateVarsInfo.llvmVars.push_back(Elt: llvmPrivateVar);
1636 }
1637
1638 return afterAllocas;
1639}
1640
/// Inline the `copy` region of every firstprivate privatizer in
/// \p privateDecls so the private copies in \p llvmPrivateVars are
/// initialized from the corresponding original values in \p mlirPrivateVars.
/// When \p insertBarrier is set, a barrier is emitted after the copies.
/// Returns failure if inlining any `copy` region fails.
static LogicalResult copyFirstPrivateVars(
    mlir::Operation *op, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<mlir::Value> &mlirPrivateVars,
    ArrayRef<llvm::Value *> llvmPrivateVars,
    SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(Range&: privateDecls, P: [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  // Fast path: no firstprivate clauses, nothing to copy.
  if (!needsFirstprivate)
    return success();

  // Emit all copies into a dedicated "omp.private.copy" block.
  llvm::BasicBlock *copyBlock =
      splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.copy");
  setInsertPointForPossiblyEmptyBlock(builder, block: copyBlock);

  for (auto [decl, mlirVar, llvmVar] :
       llvm::zip_equal(t&: privateDecls, u&: mlirPrivateVars, args&: llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg
    llvm::Value *nonPrivateVar = findAssociatedValue(
        privateVar: mlirVar, builder, moduleTranslation, mappedPrivateVars);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(mlir: decl.getCopyMoldArg(), llvm: nonPrivateVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(mlir: decl.getCopyPrivateArg(), llvm: llvmVar);

    // in-place convert copy region
    if (failed(Result: inlineConvertOmpRegions(region&: copyRegion, blockName: "omp.private.copy", builder,
                                      moduleTranslation)))
      return decl.emitError(message: "failed to inline `copy` region of `omp.private`");

    setInsertPointForPossiblyEmptyBlock(builder);

    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(region&: copyRegion);
  }

  if (insertBarrier) {
    // Emit the barrier requested by the caller after all copies are done.
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
        ompBuilder->createBarrier(Loc: builder.saveIP(), Kind: llvm::omp::OMPD_barrier);
    if (failed(Result: handleError(result&: res, op&: *op)))
      return failure();
  }

  return success();
}
1704
1705static LogicalResult
1706cleanupPrivateVars(llvm::IRBuilderBase &builder,
1707 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1708 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1709 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
1710 // private variable deallocation
1711 SmallVector<Region *> privateCleanupRegions;
1712 llvm::transform(Range&: privateDecls, d_first: std::back_inserter(x&: privateCleanupRegions),
1713 F: [](omp::PrivateClauseOp privatizer) {
1714 return &privatizer.getDeallocRegion();
1715 });
1716
1717 if (failed(Result: inlineOmpRegionCleanup(
1718 cleanupRegions&: privateCleanupRegions, privateVariables: llvmPrivateVars, moduleTranslation, builder,
1719 regionName: "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1720 return mlir::emitError(loc, message: "failed to inline `dealloc` region of an "
1721 "`omp.private` op in");
1722
1723 return success();
1724}
1725
1726/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1727static bool constructIsCancellable(Operation *op) {
1728 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1729 // be visible and not inside of function calls. This is enforced by the
1730 // verifier.
1731 return op
1732 ->walk(callback: [](Operation *child) {
1733 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(Val: child))
1734 return WalkResult::interrupt();
1735 return WalkResult::advance();
1736 })
1737 .wasInterrupted();
1738}
1739
/// Converts an OpenMP sections construct into LLVM IR using OpenMPIRBuilder.
/// Allocates and initializes any reduction variables, builds one body-gen
/// callback per contained omp.section, and finally emits the reductions and
/// cleanup code.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(Val&: opInst);

  // Bail out early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(attr: sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(op: sectionsOp, reductions&: reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(Val&: opInst).getReductionBlockArgs();

  // Reduction variables must be allocated and initialized before the body
  // callbacks run.
  if (failed(Result: allocAndInitializeReductionVars(
          op: sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  SmallVector<StorableBodyGenCallbackTy> sectionCBs;

  // Build one callback per omp.section; each one translates that section's
  // region when OpenMPIRBuilder invokes it.
  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(Val&: op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(IP: codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               t: sectionsOp.getRegion().getArguments(), u: region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(value: sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(mlir: sectionArg, llvm: llvmVal);
      }

      return convertOmpOpRegions(region, blockName: "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(Elt: sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  bool isCancellable = constructIsCancellable(op: sectionsOp);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          Loc: ompLoc, AllocaIP: allocaIP, SectionCBs: sectionCBs, PrivCB: privCB, FiniCB: finiCB, IsCancellable: isCancellable,
          IsNowait: sectionsOp.getNowait());

  if (failed(Result: handleError(result&: afterIP, op&: opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);

  // Process the reductions if required.
  return createReductionsAndCleanup(
      op: sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
      privateReductionVariables, isByRef, isNowait: sectionsOp.getNowait());
}
1845
/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
/// Translates the region body and, when present, lowers the copyprivate
/// variables/functions to the form OpenMPIRBuilder's createSingle expects.
static LogicalResult
convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Bail out early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: *singleOp)))
    return failure();

  // Body callback: translate the single region at the given insertion point.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    builder.restoreIP(IP: codegenIP);
    return convertOmpOpRegions(region&: singleOp.getRegion(), blockName: "omp.single.region",
                               builder, moduleTranslation)
        .takeError();
  };
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  // Handle copyprivate
  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
  llvm::SmallVector<llvm::Value *> llvmCPVars;
  llvm::SmallVector<llvm::Function *> llvmCPFuncs;
  // cpVars and cpFuncs are parallel lists: variable i is copied using
  // function i.
  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
    llvmCPVars.push_back(Elt: moduleTranslation.lookupValue(value: cpVars[i]));
    auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
        from: singleOp, symbol: cast<SymbolRefAttr>(Val: (*cpFuncs)[i]));
    llvmCPFuncs.push_back(
        Elt: moduleTranslation.lookupFunction(name: llvmFuncOp.getName()));
  }

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSingle(
          Loc: ompLoc, BodyGenCB: bodyCB, FiniCB: finiCB, IsNowait: singleOp.getNowait(), CPVars: llvmCPVars,
          CPFuncs: llvmCPFuncs);

  if (failed(Result: handleError(result&: afterIP, op&: *singleOp)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
1888
1889static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
1890 auto iface =
1891 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(Val: teamsOp.getOperation());
1892 // Check that all uses of the reduction block arg has the same distribute op
1893 // parent.
1894 llvm::SmallVector<mlir::Operation *> debugUses;
1895 Operation *distOp = nullptr;
1896 for (auto ra : iface.getReductionBlockArgs())
1897 for (auto &use : ra.getUses()) {
1898 auto *useOp = use.getOwner();
1899 // Ignore debug uses.
1900 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(Val: useOp)) {
1901 debugUses.push_back(Elt: useOp);
1902 continue;
1903 }
1904
1905 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1906 // Use is not inside a distribute op - return false
1907 if (!currentDistOp)
1908 return false;
1909 // Multiple distribute operations - return false
1910 Operation *currentOp = currentDistOp.getOperation();
1911 if (distOp && (distOp != currentOp))
1912 return false;
1913
1914 distOp = currentOp;
1915 }
1916
1917 // If we are going to use distribute reduction then remove any debug uses of
1918 // the reduction parameters in teamsOp. Otherwise they will be left without
1919 // any mapped value in moduleTranslation and will eventually error out.
1920 for (auto use : debugUses)
1921 use->erase();
1922 return true;
1923}
1924
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder.
// Reduction handling is skipped when the reduction is fully captured by a
// nested omp.distribute op (see teamsReductionContainedInDistribute).
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Bail out early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: *op)))
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = op.getNumReductionVars();
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Only do teams reduction if there is no distribute op that captures the
  // reduction instead.
  bool doTeamsReduction = !teamsReductionContainedInDistribute(teamsOp: op);
  if (doTeamsReduction) {
    isByRef = getIsByRef(attr: op.getReductionByref());

    assert(isByRef.size() == op.getNumReductionVars());

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(Val&: *op).getReductionBlockArgs();

    collectReductionDecls(op, reductions&: reductionDecls);

    // Reduction variables must be allocated and initialized before the body
    // callback runs.
    if (failed(Result: allocAndInitializeReductionVars(
            op, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            isByRef)))
      return failure();
  }

  // Body callback: save the alloca insertion point for nested regions and
  // translate the teams region.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);
    builder.restoreIP(IP: codegenIP);
    return convertOmpOpRegions(region&: op.getRegion(), blockName: "omp.teams.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Translate the optional clause operands; nullptr means "not present".
  llvm::Value *numTeamsLower = nullptr;
  if (Value numTeamsLowerVar = op.getNumTeamsLower())
    numTeamsLower = moduleTranslation.lookupValue(value: numTeamsLowerVar);

  llvm::Value *numTeamsUpper = nullptr;
  if (Value numTeamsUpperVar = op.getNumTeamsUpper())
    numTeamsUpper = moduleTranslation.lookupValue(value: numTeamsUpperVar);

  llvm::Value *threadLimit = nullptr;
  if (Value threadLimitVar = op.getThreadLimit())
    threadLimit = moduleTranslation.lookupValue(value: threadLimitVar);

  llvm::Value *ifExpr = nullptr;
  if (Value ifVar = op.getIfExpr())
    ifExpr = moduleTranslation.lookupValue(value: ifVar);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTeams(
          Loc: ompLoc, BodyGenCB: bodyCB, NumTeamsLower: numTeamsLower, NumTeamsUpper: numTeamsUpper, ThreadLimit: threadLimit, IfExpr: ifExpr);

  if (failed(Result: handleError(result&: afterIP, op&: *op)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  if (doTeamsReduction) {
    // Process the reductions if required.
    return createReductionsAndCleanup(
        op, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ isNowait: false, /*isTeamsReduction*/ true);
  }
  return success();
}
2004
2005static void
2006buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2007 LLVM::ModuleTranslation &moduleTranslation,
2008 SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
2009 if (dependVars.empty())
2010 return;
2011 for (auto dep : llvm::zip(t&: dependVars, u: dependKinds->getValue())) {
2012 llvm::omp::RTLDependenceKindTy type;
2013 switch (
2014 cast<mlir::omp::ClauseTaskDependAttr>(Val: std::get<1>(t&: dep)).getValue()) {
2015 case mlir::omp::ClauseTaskDepend::taskdependin:
2016 type = llvm::omp::RTLDependenceKindTy::DepIn;
2017 break;
2018 // The OpenMP runtime requires that the codegen for 'depend' clause for
2019 // 'out' dependency kind must be the same as codegen for 'depend' clause
2020 // with 'inout' dependency.
2021 case mlir::omp::ClauseTaskDepend::taskdependout:
2022 case mlir::omp::ClauseTaskDepend::taskdependinout:
2023 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2024 break;
2025 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2026 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2027 break;
2028 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2029 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2030 break;
2031 };
2032 llvm::Value *depVal = moduleTranslation.lookupValue(value: std::get<0>(t&: dep));
2033 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2034 dds.emplace_back(Args&: dd);
2035 }
2036}
2037
/// Shared implementation of a callback which adds a terminator for the new
/// block created for the branch taken when an openmp construct is cancelled.
/// The terminator is saved in \p cancelTerminators. This callback is invoked
/// only if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct (see popCancelFinalizationCB).
static void
pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
                         llvm::IRBuilderBase &llvmBuilder,
                         llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
                         llvm::omp::Directive cancelDirective) {
  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

    // ip is currently in the block branched to if cancellation occurred.
    // We need to create a branch to terminate that block.
    llvmBuilder.restoreIP(IP: ip);

    // We must still clean up the construct after cancelling it, so we need to
    // branch to the block that finalizes the taskgroup.
    // That block has not been created yet so use this block as a dummy for now
    // and fix this after creating the operation.
    cancelTerminators.push_back(Elt: llvmBuilder.CreateBr(Dest: ip.getBlock()));
    return llvm::Error::success();
  };
  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
  // created in case the body contains omp.cancel (which will then expect to be
  // able to find this cleanup callback).
  ompBuilder.pushFinalizationCB(
      FI: {.FiniCB: finiCB, .DK: cancelDirective, .IsCancellable: constructIsCancellable(op)});
}
2069
2070/// If we cancelled the construct, we should branch to the finalization block of
2071/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2072/// is immediately before the continuation block. Now this finalization has
2073/// been created we can fix the branch.
2074static void
2075popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
2076 llvm::OpenMPIRBuilder &ompBuilder,
2077 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2078 ompBuilder.popFinalizationCB();
2079 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2080 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2081 assert(cancelBranch->getNumSuccessors() == 1 &&
2082 "cancel branch should have one target");
2083 cancelBranch->setSuccessor(idx: 0, NewSucc: constructFini);
2084 }
2085}
2086
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute. The structure
/// is heap allocated so it remains valid even if the task runs after the
/// creating stack frame has been popped.
class TaskContextStructManager {
public:
  /// \p privateDecls determines which members the context structure gets:
  /// one slot per privatizer that reads from its mold argument.
  TaskContextStructManager(llvm::IRBuilderBase &builder,
                           LLVM::ModuleTranslation &moduleTranslation,
                           MutableArrayRef<omp::PrivateClauseOp> privateDecls)
      : builder{builder}, moduleTranslation{moduleTranslation},
        privateDecls{privateDecls} {}

  /// Creates a heap allocated struct containing space for each private
  /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
  /// the structure should all have the same order (although privateDecls which
  /// do not read from the mold argument are skipped).
  void generateTaskContextStruct();

  /// Create GEPs to access each member of the structure representing a private
  /// variable, adding them to llvmPrivateVars. Null values are added where
  /// private decls were skipped so that the ordering continues to match the
  /// private decls.
  void createGEPsToPrivateVars();

  /// De-allocate the task context structure.
  void freeStructPtr();

  /// GEPs into the context structure, parallel to the privateDecls passed to
  /// the constructor (with nullptr entries for skipped decls).
  MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
    return llvmPrivateVarGEPs;
  }

  /// Pointer to the heap-allocated context structure, or nullptr if
  /// generateTaskContextStruct has not run (or had nothing to allocate).
  llvm::Value *getStructPtr() { return structPtr; }

private:
  llvm::IRBuilderBase &builder;
  LLVM::ModuleTranslation &moduleTranslation;
  MutableArrayRef<omp::PrivateClauseOp> privateDecls;

  /// The type of each member of the structure, in order.
  SmallVector<llvm::Type *> privateVarTypes;

  /// LLVM values for each private variable, or null if that private variable is
  /// not included in the task context structure
  SmallVector<llvm::Value *> llvmPrivateVarGEPs;

  /// A pointer to the structure containing context for this task.
  llvm::Value *structPtr = nullptr;
  /// The type of the structure
  llvm::Type *structTy = nullptr;
};
} // namespace
2137
2138void TaskContextStructManager::generateTaskContextStruct() {
2139 if (privateDecls.empty())
2140 return;
2141 privateVarTypes.reserve(N: privateDecls.size());
2142
2143 for (omp::PrivateClauseOp &privOp : privateDecls) {
2144 // Skip private variables which can safely be allocated and initialised
2145 // inside of the task
2146 if (!privOp.readsFromMold())
2147 continue;
2148 Type mlirType = privOp.getType();
2149 privateVarTypes.push_back(Elt: moduleTranslation.convertType(type: mlirType));
2150 }
2151
2152 structTy = llvm::StructType::get(Context&: moduleTranslation.getLLVMContext(),
2153 Elements: privateVarTypes);
2154
2155 llvm::DataLayout dataLayout =
2156 builder.GetInsertBlock()->getModule()->getDataLayout();
2157 llvm::Type *intPtrTy = builder.getIntPtrTy(DL: dataLayout);
2158 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(Ty: structTy);
2159
2160 // Heap allocate the structure
2161 structPtr = builder.CreateMalloc(IntPtrTy: intPtrTy, AllocTy: structTy, AllocSize: allocSize,
2162 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2163 Name: "omp.task.context_ptr");
2164}
2165
2166void TaskContextStructManager::createGEPsToPrivateVars() {
2167 if (!structPtr) {
2168 assert(privateVarTypes.empty());
2169 return;
2170 }
2171
2172 // Create GEPs for each struct member
2173 llvmPrivateVarGEPs.clear();
2174 llvmPrivateVarGEPs.reserve(N: privateDecls.size());
2175 llvm::Value *zero = builder.getInt32(C: 0);
2176 unsigned i = 0;
2177 for (auto privDecl : privateDecls) {
2178 if (!privDecl.readsFromMold()) {
2179 // Handle this inside of the task so we don't pass unnessecary vars in
2180 llvmPrivateVarGEPs.push_back(Elt: nullptr);
2181 continue;
2182 }
2183 llvm::Value *iVal = builder.getInt32(C: i);
2184 llvm::Value *gep = builder.CreateGEP(Ty: structTy, Ptr: structPtr, IdxList: {zero, iVal});
2185 llvmPrivateVarGEPs.push_back(Elt: gep);
2186 i += 1;
2187 }
2188}
2189
2190void TaskContextStructManager::freeStructPtr() {
2191 if (!structPtr)
2192 return;
2193
2194 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2195 // Ensure we don't put the call to free() after the terminator
2196 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2197 builder.CreateFree(Source: structPtr);
2198}
2199
2200/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
/// Converts an OpenMP task construct (omp.task) into LLVM IR using
/// OpenMPIRBuilder.
///
/// Private/firstprivate variables whose initialization reads host ("mold")
/// data are allocated and initialized *before* the task is created, inside a
/// heap-allocated task context structure, so the task body stays valid even
/// if it only runs after the creating stack frame has been destroyed.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Fail early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: *taskOp)))
    return failure();

  PrivateVarsInfo privateVarsInfo(taskOp);
  TaskContextStructManager taskStructMgr{builder, moduleTranslation,
                                         privateVarsInfo.privatizers};

  // Allocate and copy private variables before creating the task. This avoids
  // accessing invalid memory if (after this scope ends) the private variables
  // are initialized from host variables or if the variables are copied into
  // from host variables (firstprivate). The insertion point is just before
  // where the code for creating and scheduling the task will go. That puts this
  // code outside of the outlined task region, which is what we want because
  // this way the initialization and copy regions are executed immediately while
  // the host variable data are still live.

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Not using splitBB() because that requires the current block to have a
  // terminator.
  assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
  llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
      Context&: builder.getContext(), Name: "omp.task.start",
      /*Parent=*/builder.GetInsertBlock()->getParent());
  llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(Dest: taskStartBlock);
  builder.SetInsertPoint(branchToTaskStartBlock);

  // Now do this again to make the initialization and copy blocks
  llvm::BasicBlock *copyBlock =
      splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.copy");
  llvm::BasicBlock *initBlock =
      splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.init");

  // Now the control flow graph should look like
  // starter_block:
  //    <---- where we started when convertOmpTaskOp was called
  //    br %omp.private.init
  // omp.private.init:
  //    br %omp.private.copy
  // omp.private.copy:
  //    br %omp.task.start
  // omp.task.start:
  //    <---- where we want the insertion point to be when we call createTask()

  // Save the alloca insertion point on ModuleTranslation stack for use in
  // nested regions.
  LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
      moduleTranslation, allocaIP);

  // Allocate and initialize private variables
  builder.SetInsertPoint(initBlock->getTerminator());

  // Create task variable structure
  taskStructMgr.generateTaskContextStruct();
  // GEPs so that we can initialize the variables. Don't use these GEPs inside
  // of the body otherwise it will be the GEP not the struct which is fowarded
  // to the outlined function. GEPs forwarded in this way are passed in a
  // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
  // which may not be executed until after the current stack frame goes out of
  // scope.
  taskStructMgr.createGEPsToPrivateVars();

  // Eagerly initialize the mold-reading privates, writing straight into the
  // task context structure through the GEPs computed above.
  for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
       llvm::zip_equal(t&: privateVarsInfo.privatizers, u&: privateVarsInfo.mlirVars,
                       args&: privateVarsInfo.blockArgs,
                       args: taskStructMgr.getLLVMPrivateVarGEPs())) {
    // To be handled inside the task.
    if (!privDecl.readsFromMold())
      continue;
    assert(llvmPrivateVarAlloc &&
           "reads from mold so shouldn't have been skipped");

    llvm::Expected<llvm::Value *> privateVarOrErr =
        initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                       blockArg, llvmPrivateVar: llvmPrivateVarAlloc, privInitBlock: initBlock);
    if (!privateVarOrErr)
      return handleError(result&: privateVarOrErr, op&: *taskOp.getOperation());

    setInsertPointForPossiblyEmptyBlock(builder);

    // TODO: this is a bit of a hack for Fortran character boxes.
    // Character boxes are passed by value into the init region and then the
    // initialized character box is yielded by value. Here we need to store the
    // yielded value into the private allocation, and load the private
    // allocation to match the type expected by region block arguments.
    if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
        !mlir::isa<LLVM::LLVMPointerType>(Val: blockArg.getType())) {
      builder.CreateStore(Val: privateVarOrErr.get(), Ptr: llvmPrivateVarAlloc);
      // Load it so we have the value pointed to by the GEP
      llvmPrivateVarAlloc = builder.CreateLoad(Ty: privateVarOrErr.get()->getType(),
                                               Ptr: llvmPrivateVarAlloc);
    }
    assert(llvmPrivateVarAlloc->getType() ==
           moduleTranslation.convertType(blockArg.getType()));

    // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
    // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
    // stack allocated structure.
  }

  // firstprivate copy region
  setInsertPointForPossiblyEmptyBlock(builder, block: copyBlock);
  if (failed(Result: copyFirstPrivateVars(
          op: taskOp, builder, moduleTranslation, mlirPrivateVars&: privateVarsInfo.mlirVars,
          llvmPrivateVars: taskStructMgr.getLLVMPrivateVarGEPs(), privateDecls&: privateVarsInfo.privatizers,
          insertBarrier: taskOp.getPrivateNeedsBarrier())))
    return llvm::failure();

  // Set up for call to createTask()
  builder.SetInsertPoint(taskStartBlock);

  // Body callback: invoked by OpenMPIRBuilder while emitting the outlined
  // task function.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // translate the body of the task:
    builder.restoreIP(IP: codegenIP);

    // Privates that do not read from the mold are allocated and initialized
    // here, inside the task itself: no host data is needed for them.
    llvm::BasicBlock *privInitBlock = nullptr;
    privateVarsInfo.llvmVars.resize(N: privateVarsInfo.blockArgs.size());
    for (auto [i, zip] : llvm::enumerate(First: llvm::zip_equal(
             t&: privateVarsInfo.blockArgs, u&: privateVarsInfo.privatizers,
             args&: privateVarsInfo.mlirVars))) {
      auto [blockArg, privDecl, mlirPrivVar] = zip;
      // This is handled before the task executes
      if (privDecl.readsFromMold())
        continue;

      llvm::IRBuilderBase::InsertPointGuard guard(builder);
      llvm::Type *llvmAllocType =
          moduleTranslation.convertType(type: privDecl.getType());
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
      llvm::Value *llvmPrivateVar = builder.CreateAlloca(
          Ty: llvmAllocType, /*ArraySize=*/nullptr, Name: "omp.private.alloc");

      llvm::Expected<llvm::Value *> privateVarOrError =
          initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
                         blockArg, llvmPrivateVar, privInitBlock);
      if (!privateVarOrError)
        return privateVarOrError.takeError();
      moduleTranslation.mapValue(mlir: blockArg, llvm: privateVarOrError.get());
      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
    }

    // Mold-reading privates already live in the task context structure;
    // record their GEP addresses alongside the task-local allocas.
    taskStructMgr.createGEPsToPrivateVars();
    for (auto [i, llvmPrivVar] :
         llvm::enumerate(First: taskStructMgr.getLLVMPrivateVarGEPs())) {
      if (!llvmPrivVar) {
        assert(privateVarsInfo.llvmVars[i] &&
               "This is added in the loop above");
        continue;
      }
      privateVarsInfo.llvmVars[i] = llvmPrivVar;
    }

    // Find and map the addresses of each variable within the task context
    // structure
    for (auto [blockArg, llvmPrivateVar, privateDecl] :
         llvm::zip_equal(t&: privateVarsInfo.blockArgs, u&: privateVarsInfo.llvmVars,
                         args&: privateVarsInfo.privatizers)) {
      // This was handled above.
      if (!privateDecl.readsFromMold())
        continue;
      // Fix broken pass-by-value case for Fortran character boxes
      if (!mlir::isa<LLVM::LLVMPointerType>(Val: blockArg.getType())) {
        llvmPrivateVar = builder.CreateLoad(
            Ty: moduleTranslation.convertType(type: blockArg.getType()), Ptr: llvmPrivateVar);
      }
      assert(llvmPrivateVar->getType() ==
             moduleTranslation.convertType(blockArg.getType()));
      moduleTranslation.mapValue(mlir: blockArg, llvm: llvmPrivateVar);
    }

    auto continuationBlockOrError = convertOmpOpRegions(
        region&: taskOp.getRegion(), blockName: "omp.task.region", builder, moduleTranslation);
    if (failed(Result: handleError(result&: continuationBlockOrError, op&: *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    // Clean up the private variables at the end of the task region.
    if (failed(Result: cleanupPrivateVars(builder, moduleTranslation, loc: taskOp.getLoc(),
                                  llvmPrivateVars&: privateVarsInfo.llvmVars,
                                  privateDecls&: privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    // Free heap allocated task context structure at the end of the task.
    taskStructMgr.freeStructPtr();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
  SmallVector<llvm::BranchInst *> cancelTerminators;
  // The directive to match here is OMPD_taskgroup because it is the taskgroup
  // which is canceled. This is handled here because it is the task's cleanup
  // block which should be branched to.
  pushCancelFinalizationCB(cancelTerminators, llvmBuilder&: builder, ompBuilder, op: taskOp,
                           cancelDirective: llvm::omp::Directive::OMPD_taskgroup);

  // Gather the task's depend clauses, if any.
  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(dependKinds: taskOp.getDependKinds(), dependVars: taskOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTask(
          Loc: ompLoc, AllocaIP: allocaIP, BodyGenCB: bodyCB, Tied: !taskOp.getUntied(),
          Final: moduleTranslation.lookupValue(value: taskOp.getFinal()),
          IfCondition: moduleTranslation.lookupValue(value: taskOp.getIfExpr()), Dependencies: dds,
          Mergeable: taskOp.getMergeable(),
          EventHandle: moduleTranslation.lookupValue(value: taskOp.getEventHandle()),
          Priority: moduleTranslation.lookupValue(value: taskOp.getPriority()));

  if (failed(Result: handleError(result&: afterIP, op&: *taskOp)))
    return failure();

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP: afterIP.get());

  builder.restoreIP(IP: *afterIP);
  return success();
}
2431
2432/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2433static LogicalResult
2434convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2435 LLVM::ModuleTranslation &moduleTranslation) {
2436 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2437 if (failed(Result: checkImplementationStatus(op&: *tgOp)))
2438 return failure();
2439
2440 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2441 builder.restoreIP(IP: codegenIP);
2442 return convertOmpOpRegions(region&: tgOp.getRegion(), blockName: "omp.taskgroup.region",
2443 builder, moduleTranslation)
2444 .takeError();
2445 };
2446
2447 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2448 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2449 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2450 moduleTranslation.getOpenMPBuilder()->createTaskgroup(Loc: ompLoc, AllocaIP: allocaIP,
2451 BodyGenCB: bodyCB);
2452
2453 if (failed(Result: handleError(result&: afterIP, op&: *tgOp)))
2454 return failure();
2455
2456 builder.restoreIP(IP: *afterIP);
2457 return success();
2458}
2459
2460static LogicalResult
2461convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2462 LLVM::ModuleTranslation &moduleTranslation) {
2463 if (failed(Result: checkImplementationStatus(op&: *twOp)))
2464 return failure();
2465
2466 moduleTranslation.getOpenMPBuilder()->createTaskwait(Loc: builder.saveIP());
2467 return success();
2468}
2469
2470/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto wsloopOp = cast<omp::WsloopOp>(Val&: opInst);
  // Fail early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: opInst)))
    return failure();

  auto loopOp = cast<omp::LoopNestOp>(Val: wsloopOp.getWrappedLoop());
  llvm::ArrayRef<bool> isByRef = getIsByRef(attr: wsloopOp.getReductionByref());
  assert(isByRef.size() == wsloopOp.getNumReductionVars());

  // Static is the default.
  auto schedule =
      wsloopOp.getScheduleKind().value_or(u: omp::ClauseScheduleKind::Static);

  // Find the loop configuration.
  llvm::Value *step = moduleTranslation.lookupValue(value: loopOp.getLoopSteps()[0]);
  llvm::Type *ivType = step->getType();
  llvm::Value *chunk = nullptr;
  if (wsloopOp.getScheduleChunk()) {
    llvm::Value *chunkVar =
        moduleTranslation.lookupValue(value: wsloopOp.getScheduleChunk());
    // The chunk size must have the same width as the induction variable.
    chunk = builder.CreateSExtOrTrunc(V: chunkVar, DestTy: ivType);
  }

  PrivateVarsInfo privateVarsInfo(wsloopOp);

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(op: wsloopOp, reductions&: reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      wsloopOp.getNumReductionVars());

  // Allocate storage for the [first]private variables.
  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(result&: afterAllocas, op&: opInst).failed())
    return failure();

  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(Val&: opInst).getReductionBlockArgs();

  SmallVector<DeferredStore> deferredStores;

  if (failed(Result: allocReductionVars(loop: wsloopOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRefs: isByRef)))
    return failure();

  // Initialize the allocated private variables.
  if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  op&: opInst)
          .failed())
    return failure();

  // Copy the initial values of firstprivate variables from the originals.
  if (failed(Result: copyFirstPrivateVars(
          op: wsloopOp, builder, moduleTranslation, mlirPrivateVars&: privateVarsInfo.mlirVars,
          llvmPrivateVars: privateVarsInfo.llvmVars, privateDecls&: privateVarsInfo.privatizers,
          insertBarrier: wsloopOp.getPrivateNeedsBarrier())))
    return failure();

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(Result: initReductionVars(op: wsloopOp, reductionArgs, builder,
                               moduleTranslation,
                               latestAllocaBlock: afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // TODO: Handle doacross loops when the ordered clause has a parameter.
  bool isOrdered = wsloopOp.getOrdered().has_value();
  std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
  bool isSimd = wsloopOp.getScheduleSimd();
  bool loopNeedsBarrier = !wsloopOp.getNowait();

  // The only legal way for the direct parent to be omp.distribute is that this
  // represents 'distribute parallel do'. Otherwise, this is a regular
  // worksharing loop.
  llvm::omp::WorksharingLoopType workshareLoopType =
      llvm::isa_and_present<omp::DistributeOp>(Val: opInst.getParentOp())
          ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
          : llvm::omp::WorksharingLoopType::ForStaticLoop;

  // Register the cancellation finalization callback so cancellation points
  // inside the loop branch to the right cleanup block.
  SmallVector<llvm::BranchInst *> cancelTerminators;
  pushCancelFinalizationCB(cancelTerminators, llvmBuilder&: builder, ompBuilder&: *ompBuilder, op: wsloopOp,
                           cancelDirective: llvm::omp::Directive::OMPD_for);

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Initialize linear variables and linear step
  LinearClauseProcessor linearClauseProcessor;
  if (wsloopOp.getLinearVars().size()) {
    for (mlir::Value linearVar : wsloopOp.getLinearVars())
      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
                                            linearVar);
    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
  }

  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
      region&: wsloopOp.getRegion(), blockName: "omp.wsloop.region", builder, moduleTranslation);

  if (failed(Result: handleError(result&: regionBlock, op&: opInst)))
    return failure();

  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);

  // Emit Initialization and Update IR for linear variables
  if (wsloopOp.getLinearVars().size()) {
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                            loopPreHeader: loopInfo->getPreheader());
    if (failed(Result: handleError(result&: afterBarrierIP, op&: *loopOp)))
      return failure();
    builder.restoreIP(IP: *afterBarrierIP);
    linearClauseProcessor.updateLinearVar(builder, loopBody: loopInfo->getBody(),
                                          loopInductionVar: loopInfo->getIndVar());
    linearClauseProcessor.outlineLinearFinalizationBB(builder,
                                                      loopExit: loopInfo->getExit());
  }

  // Lower the canonical loop to a runtime-scheduled worksharing loop.
  builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
      ompBuilder->applyWorkshareLoop(
          DL: ompLoc.DL, CLI: loopInfo, AllocaIP: allocaIP, NeedsBarrier: loopNeedsBarrier,
          SchedKind: convertToScheduleKind(schedKind: schedule), ChunkSize: chunk, HasSimdModifier: isSimd,
          HasMonotonicModifier: scheduleMod == omp::ScheduleModifier::monotonic,
          HasNonmonotonicModifier: scheduleMod == omp::ScheduleModifier::nonmonotonic, HasOrderedClause: isOrdered,
          LoopType: workshareLoopType);

  if (failed(Result: handleError(result&: wsloopIP, op&: opInst)))
    return failure();

  // Emit finalization and in-place rewrites for linear vars.
  if (wsloopOp.getLinearVars().size()) {
    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
    assert(loopInfo->getLastIter() &&
           "`lastiter` in CanonicalLoopInfo is nullptr");
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
                                                lastIter: loopInfo->getLastIter());
    if (failed(Result: handleError(result&: afterBarrierIP, op&: *loopOp)))
      return failure();
    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
      linearClauseProcessor.rewriteInPlace(builder, BBName: "omp.loop_nest.region",
                                           varIndex: index);
    builder.restoreIP(IP: oldIP);
  }

  // Set the correct branch target for task cancellation
  popCancelFinalizationCB(cancelTerminators, ompBuilder&: *ompBuilder, afterIP: wsloopIP.get());

  // Process the reductions if required.
  if (failed(Result: createReductionsAndCleanup(
          op: wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
          privateReductionVariables, isByRef, isNowait: wsloopOp.getNowait(),
          /*isTeamsReduction=*/false)))
    return failure();

  // Finally, clean up the private variables.
  return cleanupPrivateVars(builder, moduleTranslation, loc: wsloopOp.getLoc(),
                            llvmPrivateVars&: privateVarsInfo.llvmVars,
                            privateDecls&: privateVarsInfo.privatizers);
}
2638
2639/// Converts the OpenMP parallel operation to LLVM IR.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  ArrayRef<bool> isByRef = getIsByRef(attr: opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Fail early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: *opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(opInst);

  // Collect reduction declarations
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(op: opInst, reductions&: reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Body callback: emits privatization, reduction setup, the parallel region
  // body, and the reduction combination code.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateVarsInfo, allocaIP);
    if (handleError(result&: afterAllocas, op&: *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(Val&: *opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point just before the block terminator
    // so later allocations land after the ones emitted above.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(Result: allocReductionVars(
            loop: opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRefs: isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(IP: codeGenIP);

    // Initialize private variables, then copy in firstprivate values.
    if (handleError(
            error: initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            op&: *opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(Result: copyFirstPrivateVars(
            op: opInst, builder, moduleTranslation, mlirPrivateVars&: privateVarsInfo.mlirVars,
            llvmPrivateVars: privateVarsInfo.llvmVars, privateDecls&: privateVarsInfo.privatizers,
            insertBarrier: opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            Result: initReductionVars(op: opInst, reductionArgs, builder, moduleTranslation,
                              latestAllocaBlock: afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        region&: opInst.getRegion(), blockName: "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(loop: opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              Loc: builder.saveIP(), AllocaIP: allocaIP, ReductionInfos: reductionInfos, IsByRef: isByRef,
              /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      // The placeholder terminator is no longer needed once the reduction
      // code has been emitted.
      tempTerminator->eraseFromParent();
      builder.restoreIP(IP: *contInsertPoint);
    }

    return llvm::Error::success();
  };

  // Privatization callback: privatization was already done in bodyGenCB, so
  // hand the original value straight back to OpenMPIRBuilder.
  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(IP: codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(Range&: reductionDecls, d_first: std::back_inserter(x&: reductionCleanupRegions),
                    F: [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(Result: inlineOmpRegionCleanup(
            cleanupRegions&: reductionCleanupRegions, privateVariables: privateReductionVariables,
            moduleTranslation, builder, regionName: "omp.reduction.cleanup")))
      return llvm::createStringError(
          Fmt: "failed to inline `cleanup` region of `omp.declare_reduction`");

    // Clean up the private variables as well.
    if (failed(Result: cleanupPrivateVars(builder, moduleTranslation, loc: opInst.getLoc(),
                                  llvmPrivateVars&: privateVarsInfo.llvmVars,
                                  privateDecls&: privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(IP: oldIP);
    return llvm::Error::success();
  };

  // Translate the `if`, `num_threads` and `proc_bind` clauses.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(value: ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(value: numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(kind: *bind);
  bool isCancellable = constructIsCancellable(op: opInst);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(Loc: ompLoc, AllocaIP: allocaIP, BodyGenCB: bodyGenCB, PrivCB: privCB, FiniCB: finiCB,
                                 IfCondition: ifCond, NumThreads: numThreads, ProcBind: pbKind, IsCancellable: isCancellable);

  if (failed(Result: handleError(result&: afterIP, op&: *opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
2811
2812/// Convert Order attribute to llvm::omp::OrderKind.
2813static llvm::omp::OrderKind
2814convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2815 if (!o)
2816 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2817 switch (*o) {
2818 case omp::ClauseOrderKind::Concurrent:
2819 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2820 }
2821 llvm_unreachable("Unknown ClauseOrderKind kind");
2822}
2823
2824/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto simdOp = cast<omp::SimdOp>(Val&: opInst);

  // Fail early on clauses this translation does not support yet.
  if (failed(Result: checkImplementationStatus(op&: opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(simdOp);

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(Val&: opInst).getReductionBlockArgs();
  DenseMap<Value, llvm::Value *> reductionVariableMap;
  SmallVector<llvm::Value *> privateReductionVariables(
      simdOp.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(op: simdOp, reductions&: reductionDecls);
  llvm::ArrayRef<bool> isByRef = getIsByRef(attr: simdOp.getReductionByref());
  assert(isByRef.size() == simdOp.getNumReductionVars());

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Allocate storage for the private variables.
  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateVarsInfo, allocaIP);
  if (handleError(result&: afterAllocas, op&: opInst).failed())
    return failure();

  if (failed(Result: allocReductionVars(loop: simdOp, reductionArgs, builder,
                                moduleTranslation, allocaIP, reductionDecls,
                                privateReductionVariables, reductionVariableMap,
                                deferredStores, isByRefs: isByRef)))
    return failure();

  // Initialize the allocated private variables.
  if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                  op&: opInst)
          .failed())
    return failure();

  // No call to copyFirstPrivateVars because FIRSTPRIVATE is not allowed for
  // SIMD.

  assert(afterAllocas.get()->getSinglePredecessor());
  if (failed(Result: initReductionVars(op: simdOp, reductionArgs, builder,
                               moduleTranslation,
                               latestAllocaBlock: afterAllocas.get()->getSinglePredecessor(),
                               reductionDecls, privateReductionVariables,
                               reductionVariableMap, isByRef, deferredStores)))
    return failure();

  // Translate the optional simdlen/safelen clauses into constants.
  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(C: simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(C: safelenVar.value());

  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(o: simdOp.getOrder());

  // Collect each `aligned` clause pointer together with its alignment
  // constant; the pointer loads are emitted back in the block that preceded
  // the loop region.
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(value: operands[i]);
    llvm::Type *ty = llvmVal->getType();

    // NOTE(review): dereferences `alignmentValues` without checking
    // has_value() — presumably the op verifier guarantees alignments are
    // present whenever there are aligned vars; confirm.
    auto intAttr = cast<IntegerAttr>(Val: (*alignmentValues)[i]);
    alignment = builder.getInt64(C: intAttr.getInt());
    assert(ty->isPointerTy() && "Invalid type for aligned variable");
    assert(alignment && "Invalid alignment value");
    auto curInsert = builder.saveIP();
    builder.SetInsertPoint(sourceBlock);
    llvmVal = builder.CreateLoad(Ty: ty, Ptr: llvmVal);
    builder.restoreIP(IP: curInsert);
    alignedVars[llvmVal] = alignment;
  }

  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
      region&: simdOp.getRegion(), blockName: "omp.simd.region", builder, moduleTranslation);

  if (failed(Result: handleError(result&: regionBlock, op&: opInst)))
    return failure();

  // Apply the simd transformation to the canonical loop produced while
  // translating the region.
  builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
  ompBuilder->applySimd(Loop: loopInfo, AlignedVars: alignedVars,
                        IfCond: simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(value: simdOp.getIfExpr())
                            : nullptr,
                        Order: order, Simdlen: simdlen, Safelen: safelen);

  // We now need to reduce the per-simd-lane reduction variable into the
  // original variable. This works a bit differently to other reductions (e.g.
  // wsloop) because we don't need to call into the OpenMP runtime to handle
  // threads: everything happened in this one thread.
  for (auto [i, tuple] : llvm::enumerate(
           First: llvm::zip(t&: reductionDecls, u&: isByRef, args: simdOp.getReductionVars(),
                     args&: privateReductionVariables))) {
    auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

    OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
    llvm::Value *originalVariable = moduleTranslation.lookupValue(value: reductionVar);
    llvm::Type *reductionType = moduleTranslation.convertType(type: decl.getType());

    // We have one less load for by-ref case because that load is now inside of
    // the reduction region.
    llvm::Value *redValue = originalVariable;
    if (!byRef)
      redValue =
          builder.CreateLoad(Ty: reductionType, Ptr: redValue, Name: "red.value." + Twine(i));
    llvm::Value *privateRedValue = builder.CreateLoad(
        Ty: reductionType, Ptr: privateReductionVar, Name: "red.private.value." + Twine(i));
    llvm::Value *reduced;

    // Combine the original value and the private value; `reduced` receives
    // the combined result.
    auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
    if (failed(Result: handleError(result&: res, op&: opInst)))
      return failure();
    builder.restoreIP(IP: res.get());

    // For by-ref case, the store is inside of the reduction region.
    if (!byRef)
      builder.CreateStore(Val: reduced, Ptr: originalVariable);
  }

  // After the construct, deallocate private reduction variables.
  SmallVector<Region *> reductionRegions;
  llvm::transform(Range&: reductionDecls, d_first: std::back_inserter(x&: reductionRegions),
                  F: [](omp::DeclareReductionOp reductionDecl) {
                    return &reductionDecl.getCleanupRegion();
                  });
  if (failed(Result: inlineOmpRegionCleanup(cleanupRegions&: reductionRegions, privateVariables: privateReductionVariables,
                                     moduleTranslation, builder,
                                     regionName: "omp.reduction.cleanup")))
    return failure();

  // Finally, clean up the private variables.
  return cleanupPrivateVars(builder, moduleTranslation, loc: simdOp.getLoc(),
                            llvmPrivateVars&: privateVarsInfo.llvmVars,
                            privateDecls&: privateVarsInfo.privatizers);
}
2969
/// Converts an OpenMP loop nest (omp.loop_nest) into LLVM IR using
/// OpenMPIRBuilder. Each rectangular loop in the nest becomes an IRBuilder
/// canonical loop; the whole nest is then collapsed into a single loop so
/// that an enclosing worksharing/simd construct can be applied to it.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(Val&: opInst);

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable. The
    // region's block arguments are the IVs, one per loop, outermost first,
    // so loopInfos.size() is the index of the loop currently being built.
    moduleTranslation.mapValue(
        mlir: loopOp.getRegion().front().getArgument(i: loopInfos.size()), llvm: iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(Elt: ip);

    // Only the innermost invocation converts the MLIR body; outer
    // invocations just record their insertion points above.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(IP: ip);
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        region&: loopOp.getRegion(), blockName: "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(value: loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(value: loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(value: loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      // Inner loops are emitted at the current body insertion point, but
      // their trip counts are computed in the outermost preheader.
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            Loc: loc, BodyGenCB: bodyGen, Start: lowerBound, Stop: upperBound, Step: step,
            /*IsSigned=*/true, InclusiveStop: loopOp.getLoopInclusive(), ComputeIP: computeIP);

    if (failed(Result: handleError(result&: loopResult, op&: *loopOp)))
      return failure();

    loopInfos.push_back(Elt: *loopResult);
  }

  // Collapse loops. Store the insertion point because LoopInfos may get
  // invalidated.
  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Update the stack frame created for this loop to point to the resulting loop
  // after applying transformations.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      callback: [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = ompBuilder->collapseLoops(DL: ompLoc.DL, Loops: loopInfos, ComputeIP: {});
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(IP: afterIP);
  return success();
}
3062
/// Convert an omp.canonical_loop to LLVM-IR using the OpenMPIRBuilder.
/// Creates a canonical loop with the operation's trip count, converts the
/// MLIR region as the loop body, and (if present) records the resulting
/// CanonicalLoopInfo under the op's CLI handle for later loop transformations.
static LogicalResult
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
                          LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
  Value loopIV = op.getInductionVar();
  Value loopTC = op.getTripCount();

  llvm::Value *llvmTC = moduleTranslation.lookupValue(value: loopTC);

  llvm::Expected<llvm::CanonicalLoopInfo *> llvmOrError =
      ompBuilder->createCanonicalLoop(
          Loc: loopLoc,
          BodyGenCB: [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
            // Register the mapping of MLIR induction variable to LLVM-IR
            // induction variable
            moduleTranslation.mapValue(mlir: loopIV, llvm: llvmIV);

            builder.restoreIP(IP: ip);
            llvm::Expected<llvm::BasicBlock *> bodyGenStatus =
                convertOmpOpRegions(region&: op.getRegion(), blockName: "omp.loop.region", builder,
                                    moduleTranslation);

            // takeError() yields Error::success() when the Expected holds a
            // value, so this propagates failure only in the error case.
            return bodyGenStatus.takeError();
          },
          TripCount: llvmTC, Name: "omp.loop");
  if (!llvmOrError)
    return op.emitError(message: llvm::toString(E: llvmOrError.takeError()));

  llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
  // Resume emitting IR after the freshly built loop construct.
  llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
  builder.restoreIP(IP: afterIP);

  // Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
  if (Value cli = op.getCli())
    moduleTranslation.mapOmpLoop(mlir: cli, llvm: llvmCLI);

  return success();
}
3104
3105/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3106/// OpenMPIRBuilder.
3107static LogicalResult
3108applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3109 LLVM::ModuleTranslation &moduleTranslation) {
3110 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3111
3112 Value applyee = op.getApplyee();
3113 assert(applyee && "Loop to apply unrolling on required");
3114
3115 llvm::CanonicalLoopInfo *consBuilderCLI =
3116 moduleTranslation.lookupOMPLoop(mlir: applyee);
3117 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3118 ompBuilder->unrollLoopHeuristic(DL: loc.DL, Loop: consBuilderCLI);
3119
3120 moduleTranslation.invalidateOmpLoop(mlir: applyee);
3121 return success();
3122}
3123
3124/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3125static llvm::AtomicOrdering
3126convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3127 if (!ao)
3128 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3129
3130 switch (*ao) {
3131 case omp::ClauseMemoryOrderKind::Seq_cst:
3132 return llvm::AtomicOrdering::SequentiallyConsistent;
3133 case omp::ClauseMemoryOrderKind::Acq_rel:
3134 return llvm::AtomicOrdering::AcquireRelease;
3135 case omp::ClauseMemoryOrderKind::Acquire:
3136 return llvm::AtomicOrdering::Acquire;
3137 case omp::ClauseMemoryOrderKind::Release:
3138 return llvm::AtomicOrdering::Release;
3139 case omp::ClauseMemoryOrderKind::Relaxed:
3140 return llvm::AtomicOrdering::Monotonic;
3141 }
3142 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3143}
3144
3145/// Convert omp.atomic.read operation to LLVM IR.
3146static LogicalResult
3147convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3148 LLVM::ModuleTranslation &moduleTranslation) {
3149 auto readOp = cast<omp::AtomicReadOp>(Val&: opInst);
3150 if (failed(Result: checkImplementationStatus(op&: opInst)))
3151 return failure();
3152
3153 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3154 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3155 findAllocaInsertPoint(builder, moduleTranslation);
3156
3157 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3158
3159 llvm::AtomicOrdering AO = convertAtomicOrdering(ao: readOp.getMemoryOrder());
3160 llvm::Value *x = moduleTranslation.lookupValue(value: readOp.getX());
3161 llvm::Value *v = moduleTranslation.lookupValue(value: readOp.getV());
3162
3163 llvm::Type *elementType =
3164 moduleTranslation.convertType(type: readOp.getElementType());
3165
3166 llvm::OpenMPIRBuilder::AtomicOpValue V = {.Var: v, .ElemTy: elementType, .IsSigned: false, .IsVolatile: false};
3167 llvm::OpenMPIRBuilder::AtomicOpValue X = {.Var: x, .ElemTy: elementType, .IsSigned: false, .IsVolatile: false};
3168 builder.restoreIP(IP: ompBuilder->createAtomicRead(Loc: ompLoc, X, V, AO, AllocaIP: allocaIP));
3169 return success();
3170}
3171
3172/// Converts an omp.atomic.write operation to LLVM IR.
3173static LogicalResult
3174convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3175 LLVM::ModuleTranslation &moduleTranslation) {
3176 auto writeOp = cast<omp::AtomicWriteOp>(Val&: opInst);
3177 if (failed(Result: checkImplementationStatus(op&: opInst)))
3178 return failure();
3179
3180 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3181 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3182 findAllocaInsertPoint(builder, moduleTranslation);
3183
3184 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3185 llvm::AtomicOrdering ao = convertAtomicOrdering(ao: writeOp.getMemoryOrder());
3186 llvm::Value *expr = moduleTranslation.lookupValue(value: writeOp.getExpr());
3187 llvm::Value *dest = moduleTranslation.lookupValue(value: writeOp.getX());
3188 llvm::Type *ty = moduleTranslation.convertType(type: writeOp.getExpr().getType());
3189 llvm::OpenMPIRBuilder::AtomicOpValue x = {.Var: dest, .ElemTy: ty, /*isSigned=*/.IsSigned: false,
3190 /*isVolatile=*/.IsVolatile: false};
3191 builder.restoreIP(
3192 IP: ompBuilder->createAtomicWrite(Loc: ompLoc, X&: x, Expr: expr, AO: ao, AllocaIP: allocaIP));
3193 return success();
3194}
3195
3196/// Converts an LLVM dialect binary operation to the corresponding enum value
3197/// for `atomicrmw` supported binary operation.
3198llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3199 return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
3200 .Case(caseFn: [&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3201 .Case(caseFn: [&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3202 .Case(caseFn: [&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3203 .Case(caseFn: [&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3204 .Case(caseFn: [&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3205 .Case(caseFn: [&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3206 .Case(caseFn: [&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3207 .Case(caseFn: [&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3208 .Case(caseFn: [&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3209 .Default(defaultResult: llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3210}
3211
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// If the update region consists of exactly one binary operation on the
/// region argument (plus the terminator), the binop is extracted so the
/// IRBuilder can emit an `atomicrmw`; otherwise BAD_BINOP is passed and a
/// compare-exchange loop is generated via the `updateFn` callback.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(Result: checkImplementationStatus(op&: *opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(Range: innerOp.getOperands(),
                            Element: opInst.getRegion().getArgument(i: 0))) {
      return opInst.emitError(message: "no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(op&: innerOp);
    // Record whether `x` is the LHS of the binop (x = x op expr) or the RHS
    // (x = expr op x); the IRBuilder needs this for non-commutative ops.
    isXBinopExpr = innerOp.getOperand(idx: 0) == opInst.getRegion().getArgument(i: 0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(idx: 1) : innerOp.getOperand(idx: 0));
    llvmExpr = moduleTranslation.lookupValue(value: mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }
  llvmX = moduleTranslation.lookupValue(value: opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      type: opInst.getRegion().getArgument(i: 0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {.Var: llvmX, .ElemTy: llvmXElementType,
                                                      /*isSigned=*/.IsSigned: false,
                                                      /*isVolatile=*/.IsVolatile: false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(ao: opInst.getMemoryOrder());

  // Generate update code. Used by the IRBuilder when it cannot emit a plain
  // atomicrmw: it converts the update region inline, mapping the region
  // argument to the current value of `x`, and returns the updated value
  // yielded by the region's omp.yield terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(mlir: *opInst.getRegion().args_begin(), llvm: atomicx);
    moduleTranslation.mapBlock(mlir: &bb, llvm: builder.GetInsertBlock());
    if (failed(Result: moduleTranslation.convertBlock(bb, ignoreArguments: true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(Val: bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(value: yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(Loc: ompLoc, AllocaIP: allocaIP, X&: llvmAtomicX, Expr: llvmExpr,
                                     AO: atomicOrdering, RMWOp: binop, UpdateOp: updateFn,
                                     IsXBinopExpr: isXBinopExpr);

  if (failed(Result: handleError(result&: afterIP, op&: *opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
3290
/// Converts an omp.atomic.capture construct to LLVM IR.
///
/// The capture region pairs an atomic.read with either an atomic.update or an
/// atomic.write. The relative order of the read and the update determines
/// whether the captured value is the old one (postfix update) or the new one.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(Result: checkImplementationStatus(op&: *atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // `x` is overwritten, so the captured value is necessarily the old one.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // The update comes second iff the old value of `x` is captured.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(Range: innerOp.getOperands(),
                              Element: atomicUpdateOp.getRegion().getArgument(i: 0))) {
        return atomicUpdateOp.emitError(
            message: "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(op&: innerOp);
      // Whether `x` is the LHS (x = x op expr) or RHS (x = expr op x) of the
      // update — relevant for non-commutative operations.
      isXBinopExpr =
          innerOp.getOperand(idx: 0) == atomicUpdateOp.getRegion().getArgument(i: 0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(idx: 1) : innerOp.getOperand(idx: 0));
    } else {
      // More than one update operation: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(value: mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(value: atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(value: atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      type: atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {.Var: llvmX, .ElemTy: llvmXElementType,
                                                      /*isSigned=*/.IsSigned: false,
                                                      /*isVolatile=*/.IsVolatile: false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {.Var: llvmV, .ElemTy: llvmXElementType,
                                                      /*isSigned=*/.IsSigned: false,
                                                      /*isVolatile=*/.IsVolatile: false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(ao: atomicCaptureOp.getMemoryOrder());

  // Callback used by the IRBuilder for the non-atomicrmw path: converts the
  // update region inline (or simply returns the write expression) and yields
  // the new value of `x`.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(value: atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(mlir: *atomicUpdateOp.getRegion().args_begin(),
                               llvm: atomicx);
    moduleTranslation.mapBlock(mlir: &bb, llvm: builder.GetInsertBlock());
    if (failed(Result: moduleTranslation.convertBlock(bb, ignoreArguments: true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(Val: bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(value: yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          Loc: ompLoc, AllocaIP: allocaIP, X&: llvmAtomicX, V&: llvmAtomicV, Expr: llvmExpr, AO: atomicOrdering,
          RMWOp: binop, UpdateOp: updateFn, UpdateExpr: atomicUpdateOp, IsPostfixUpdate: isPostfixUpdate, IsXBinopExpr: isXBinopExpr);

  if (failed(Result: handleError(result&: afterIP, op&: *atomicCaptureOp)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
3385
3386static llvm::omp::Directive convertCancellationConstructType(
3387 omp::ClauseCancellationConstructType directive) {
3388 switch (directive) {
3389 case omp::ClauseCancellationConstructType::Loop:
3390 return llvm::omp::Directive::OMPD_for;
3391 case omp::ClauseCancellationConstructType::Parallel:
3392 return llvm::omp::Directive::OMPD_parallel;
3393 case omp::ClauseCancellationConstructType::Sections:
3394 return llvm::omp::Directive::OMPD_sections;
3395 case omp::ClauseCancellationConstructType::Taskgroup:
3396 return llvm::omp::Directive::OMPD_taskgroup;
3397 }
3398 llvm_unreachable("Unhandled cancellation construct type");
3399}
3400
3401static LogicalResult
3402convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3403 LLVM::ModuleTranslation &moduleTranslation) {
3404 if (failed(Result: checkImplementationStatus(op&: *op.getOperation())))
3405 return failure();
3406
3407 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3408 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3409
3410 llvm::Value *ifCond = nullptr;
3411 if (Value ifVar = op.getIfExpr())
3412 ifCond = moduleTranslation.lookupValue(value: ifVar);
3413
3414 llvm::omp::Directive cancelledDirective =
3415 convertCancellationConstructType(directive: op.getCancelDirective());
3416
3417 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3418 ompBuilder->createCancel(Loc: ompLoc, IfCondition: ifCond, CanceledDirective: cancelledDirective);
3419
3420 if (failed(Result: handleError(result&: afterIP, op&: *op.getOperation())))
3421 return failure();
3422
3423 builder.restoreIP(IP: afterIP.get());
3424
3425 return success();
3426}
3427
3428static LogicalResult
3429convertOmpCancellationPoint(omp::CancellationPointOp op,
3430 llvm::IRBuilderBase &builder,
3431 LLVM::ModuleTranslation &moduleTranslation) {
3432 if (failed(Result: checkImplementationStatus(op&: *op.getOperation())))
3433 return failure();
3434
3435 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3436 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3437
3438 llvm::omp::Directive cancelledDirective =
3439 convertCancellationConstructType(directive: op.getCancelDirective());
3440
3441 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3442 ompBuilder->createCancellationPoint(Loc: ompLoc, CanceledDirective: cancelledDirective);
3443
3444 if (failed(Result: handleError(result&: afterIP, op&: *op.getOperation())))
3445 return failure();
3446
3447 builder.restoreIP(IP: afterIP.get());
3448
3449 return success();
3450}
3451
3452/// Converts an OpenMP Threadprivate operation into LLVM IR using
3453/// OpenMPIRBuilder.
3454static LogicalResult
3455convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
3456 LLVM::ModuleTranslation &moduleTranslation) {
3457 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3458 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3459 auto threadprivateOp = cast<omp::ThreadprivateOp>(Val&: opInst);
3460
3461 if (failed(Result: checkImplementationStatus(op&: opInst)))
3462 return failure();
3463
3464 Value symAddr = threadprivateOp.getSymAddr();
3465 auto *symOp = symAddr.getDefiningOp();
3466
3467 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(Val: symOp))
3468 symOp = asCast.getOperand().getDefiningOp();
3469
3470 if (!isa<LLVM::AddressOfOp>(Val: symOp))
3471 return opInst.emitError(message: "Addressing symbol not found");
3472 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(Val: symOp);
3473
3474 LLVM::GlobalOp global =
3475 addressOfOp.getGlobal(symbolTable&: moduleTranslation.symbolTable());
3476 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(op: global);
3477
3478 if (!ompBuilder->Config.isTargetDevice()) {
3479 llvm::Type *type = globalValue->getValueType();
3480 llvm::TypeSize typeSize =
3481 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3482 Ty: type);
3483 llvm::ConstantInt *size = builder.getInt64(C: typeSize.getFixedValue());
3484 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3485 Loc: ompLoc, Pointer: globalValue, Size: size, Name: global.getSymName() + ".cache");
3486 moduleTranslation.mapValue(mlir: opInst.getResult(idx: 0), llvm: callInst);
3487 } else {
3488 moduleTranslation.mapValue(mlir: opInst.getResult(idx: 0), llvm: globalValue);
3489 }
3490
3491 return success();
3492}
3493
3494static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3495convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
3496 switch (deviceClause) {
3497 case mlir::omp::DeclareTargetDeviceType::host:
3498 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3499 break;
3500 case mlir::omp::DeclareTargetDeviceType::nohost:
3501 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3502 break;
3503 case mlir::omp::DeclareTargetDeviceType::any:
3504 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3505 break;
3506 }
3507 llvm_unreachable("unhandled device clause");
3508}
3509
/// Maps an MLIR declare-target capture clause (to/link/enter) to the
/// corresponding offload-entry global-variable kind used by the
/// OpenMPIRBuilder's offload entry manager.
static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
convertToCaptureClauseKind(
    mlir::omp::DeclareTargetCaptureClause captureClause) {
  switch (captureClause) {
  case mlir::omp::DeclareTargetCaptureClause::to:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
  case mlir::omp::DeclareTargetCaptureClause::link:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
  case mlir::omp::DeclareTargetCaptureClause::enter:
    return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
  }
  llvm_unreachable("unhandled capture clause");
}
3523
/// Builds the suffix appended to a declare-target global's name when forming
/// its reference pointer ("..._decl_tgt_ref_ptr"). Private globals get an
/// extra file-unique hex ID in the suffix so equally-named private globals
/// from different files do not collide.
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                             llvm::OpenMPIRBuilder &ompBuilder) {
  llvm::SmallString<64> suffix;
  llvm::raw_svector_ostream os(suffix);
  if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
    // NOTE(review): findInstanceOf may return a null FileLineColLoc when the
    // global carries no file-based location; confirm callers guarantee one
    // before loc is dereferenced in the callback below.
    auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
    auto fileInfoCallBack = [&loc]() {
      return std::pair<std::string, uint64_t>(
          llvm::StringRef(loc.getFilename()), loc.getLine());
    };

    os << llvm::format(
        Fmt: "_%x", Vals: ompBuilder.getTargetEntryUniqueInfo(CallBack: fileInfoCallBack).FileID);
  }
  os << "_decl_tgt_ref_ptr";

  return suffix;
}
3543
3544static bool isDeclareTargetLink(mlir::Value value) {
3545 if (auto addressOfOp =
3546 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(Val: value.getDefiningOp())) {
3547 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3548 Operation *gOp = modOp.lookupSymbol(name: addressOfOp.getGlobalName());
3549 if (auto declareTargetGlobal =
3550 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(Val: gOp))
3551 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3552 mlir::omp::DeclareTargetCaptureClause::link)
3553 return true;
3554 }
3555 return false;
3556}
3557
// Returns the reference pointer generated by the lowering of the declare target
// operation in cases where the link clause is used or the to clause is used in
// USM mode. Returns nullptr when the value is not the address of such a
// declare-target global.
static llvm::Value *
getRefPtrIfDeclareTarget(mlir::Value value,
                         LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  Operation *op = value.getDefiningOp();
  // Look through an address-space cast to the underlying address-of op.
  if (auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(Val: op))
    op = addrCast->getOperand(idx: 0).getDefiningOp();

  // An easier way to do this may just be to keep track of any pointer
  // references and their mapping to their respective operation
  if (auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(Val: op)) {
    if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
            Val: addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
                name: addressOfOp.getGlobalName()))) {

      if (auto declareTargetGlobal =
              llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
                  Val: gOp.getOperation())) {

        // In this case, we must utilise the reference pointer generated by the
        // declare target operation, similar to Clang
        if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
             mlir::omp::DeclareTargetCaptureClause::link) ||
            (declareTargetGlobal.getDeclareTargetCaptureClause() ==
                 mlir::omp::DeclareTargetCaptureClause::to &&
             ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
          llvm::SmallString<64> suffix =
              getDeclareTargetRefPtrSuffix(globalOp: gOp, ompBuilder&: *ompBuilder);

          // If the global already carries the suffix, it *is* the ref
          // pointer; otherwise look up the suffixed companion symbol.
          if (gOp.getSymName().contains(Other: suffix))
            return moduleTranslation.getLLVMModule()->getNamedValue(
                Name: gOp.getSymName());

          return moduleTranslation.getLLVMModule()->getNamedValue(
              Name: (gOp.getSymName().str() + suffix.str()).str());
        }
      }
    }
  }

  return nullptr;
}
3603
namespace {
// Append customMappers information to existing MapInfosTy
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
  // Per-entry custom mapper operation, kept parallel to the base class's
  // per-entry arrays (nullptr-style absence is handled by callers).
  SmallVector<Operation *, 4> Mappers;

  /// Append arrays in \a CurInfo.
  void append(MapInfosTy &curInfo) {
    Mappers.append(in_start: curInfo.Mappers.begin(), in_end: curInfo.Mappers.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo&: curInfo);
  }
};
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
struct MapInfoData : MapInfosTy {
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  /// NOTE(review): IsAMember and IsAMapping are not appended here — confirm
  /// whether that is intentional (callers may not rely on them after an
  /// append) or an oversight.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(in_start: CurInfo.IsDeclareTarget.begin(),
                           in_end: CurInfo.IsDeclareTarget.end());
    MapClause.append(in_start: CurInfo.MapClause.begin(), in_end: CurInfo.MapClause.end());
    OriginalValue.append(in_start: CurInfo.OriginalValue.begin(),
                         in_end: CurInfo.OriginalValue.end());
    BaseType.append(in_start: CurInfo.BaseType.begin(), in_end: CurInfo.BaseType.end());
    MapInfosTy::append(curInfo&: CurInfo);
  }
};
} // namespace
3644
3645uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
3646 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3647 Val: arrTy.getElementType()))
3648 return getArrayElementSizeInBits(arrTy: nestedArrTy, dl);
3649 return dl.getTypeSizeInBits(t: arrTy.getElementType());
3650}
3651
3652// This function calculates the size to be offloaded for a specified type, given
3653// its associated map clause (which can contain bounds information which affects
3654// the total size), this size is calculated based on the underlying element type
3655// e.g. given a 1-D array of ints, we will calculate the size from the integer
3656// type * number of elements in the array. This size can be used in other
3657// calculations but is ultimately used as an argument to the OpenMP runtimes
3658// kernel argument structure which is generated through the combinedInfo data
3659// structures.
3660// This function is somewhat equivalent to Clang's getExprTypeSize inside of
3661// CGOpenMPRuntime.cpp.
3662llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
3663 Operation *clauseOp, llvm::Value *basePointer,
3664 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3665 LLVM::ModuleTranslation &moduleTranslation) {
3666 if (auto memberClause =
3667 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(Val: clauseOp)) {
3668 // This calculates the size to transfer based on bounds and the underlying
3669 // element type, provided bounds have been specified (Fortran
3670 // pointers/allocatables/target and arrays that have sections specified fall
3671 // into this as well).
3672 if (!memberClause.getBounds().empty()) {
3673 llvm::Value *elementCount = builder.getInt64(C: 1);
3674 for (auto bounds : memberClause.getBounds()) {
3675 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3676 Val: bounds.getDefiningOp())) {
3677 // The below calculation for the size to be mapped calculated from the
3678 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
3679 // multiply by the underlying element types byte size to get the full
3680 // size to be offloaded based on the bounds
3681 elementCount = builder.CreateMul(
3682 LHS: elementCount,
3683 RHS: builder.CreateAdd(
3684 LHS: builder.CreateSub(
3685 LHS: moduleTranslation.lookupValue(value: boundOp.getUpperBound()),
3686 RHS: moduleTranslation.lookupValue(value: boundOp.getLowerBound())),
3687 RHS: builder.getInt64(C: 1)));
3688 }
3689 }
3690
3691 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
3692 // the size in inconsistent byte or bit format.
3693 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(t: type);
3694 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(Val: type))
3695 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
3696
3697 // The size in bytes x number of elements, the sizeInBytes stored is
3698 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
3699 // size, so we do some on the fly runtime math to get the size in
3700 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
3701 // some adjustment for members with more complex types.
3702 return builder.CreateMul(LHS: elementCount,
3703 RHS: builder.getInt64(C: underlyingTypeSzInBits / 8));
3704 }
3705 }
3706
3707 return builder.getInt64(C: dl.getTypeSizeInBits(t: type) / 8);
3708}
3709
/// Gathers, for every map-like operand attached to a target-style construct
/// (map, use_device_ptr, use_device_addr, has_device_addr), the information
/// needed later when emitting the kernel argument structures: looked-up LLVM
/// values, base pointers, converted types, runtime sizes, map-type flags,
/// mapper symbols and member/mapping classification. All results are appended
/// to the parallel arrays in \p mapData, so every array grows by exactly one
/// entry per processed operand.
static void collectMapDataFromMapOperands(
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
    ArrayRef<Value> useDevAddrOperands = {},
    ArrayRef<Value> hasDevAddrOperands = {}) {
  // Returns true when \p mapOp appears in the member list of any other map
  // clause in \p mapVars, i.e. it maps part of a larger object.
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(Val: mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(Val: mapValue.getDefiningOp());
    // Prefer the pointer-to-pointer operand when present (pointer-like data);
    // otherwise map the variable's own address.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(Elt: moduleTranslation.lookupValue(value: offloadPtr));
    mapData.Pointers.push_back(Elt: mapData.OriginalValue.back());

    // Declare target variables are reached through their generated reference
    // pointer rather than the variable itself.
    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(value: offloadPtr,
                                     moduleTranslation)) { // declare target
      mapData.IsDeclareTarget.push_back(Elt: true);
      mapData.BasePointers.push_back(Elt: refPtr);
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(Elt: false);
      mapData.BasePointers.push_back(Elt: mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
    // Size honours any bounds on the map clause (see getSizeInBytes).
    mapData.Sizes.push_back(
        Elt: getSizeInBytes(dl, type: mapOp.getVarType(), clauseOp: mapOp, basePointer: mapData.Pointers.back(),
                       baseType: mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(Elt: mapOp.getOperation());
    mapData.Types.push_back(
        Elt: llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
    mapData.Names.push_back(Elt: LLVM::createMappingInformation(
        loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(Elt: llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    // Resolve a user-defined mapper symbol if one is referenced.
    if (mapOp.getMapperId())
      mapData.Mappers.push_back(
          Elt: SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
              from: mapOp, symbol: mapOp.getMapperIdAttr()));
    else
      mapData.Mappers.push_back(Elt: nullptr);
    mapData.IsAMapping.push_back(Elt: true);
    mapData.IsAMember.push_back(Elt: checkIsAMember(mapVars, mapOp));
  }

  // If \p val was already collected above as a map clause, tag that existing
  // entry as returning a device pointer/address instead of adding a new one.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(Val: mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(value: offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        // Not previously mapped: add a fresh zero-sized RETURN_PARAM entry.
        mapData.OriginalValue.push_back(Elt: origValue);
        mapData.Pointers.push_back(Elt: mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(Elt: false);
        mapData.BasePointers.push_back(Elt: mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
        mapData.Sizes.push_back(Elt: builder.getInt64(C: 0));
        mapData.MapClause.push_back(Elt: mapOp.getOperation());
        mapData.Types.push_back(
            Elt: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(Elt: LLVM::createMappingInformation(
            loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(Elt: devInfoTy);
        mapData.Mappers.push_back(Elt: nullptr);
        mapData.IsAMapping.push_back(Elt: false);
        mapData.IsAMember.push_back(Elt: checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);

  // Process has_device_addr operands: the address is already valid on the
  // device, so these are mapped as literals unless ALWAYS remapping applies.
  for (Value mapValue : hasDevAddrOperands) {
    auto mapOp = cast<omp::MapInfoOp>(Val: mapValue.getDefiningOp());
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    llvm::Value *origValue = moduleTranslation.lookupValue(value: offloadPtr);
    auto mapType =
        static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
    auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;

    mapData.OriginalValue.push_back(Elt: origValue);
    mapData.BasePointers.push_back(Elt: origValue);
    mapData.Pointers.push_back(Elt: origValue);
    mapData.IsDeclareTarget.push_back(Elt: false);
    mapData.BaseType.push_back(
        Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
    mapData.Sizes.push_back(
        Elt: builder.getInt64(C: dl.getTypeSize(t: mapOp.getVarType())));
    mapData.MapClause.push_back(Elt: mapOp.getOperation());
    if (llvm::to_underlying(E: mapType & mapTypeAlways)) {
      // Descriptors are mapped with the ALWAYS flag, since they can get
      // rematerialized, so the address of the descriptor for a given object
      // may change from one place to another.
      mapData.Types.push_back(Elt: mapType);
      // Technically it's possible for a non-descriptor mapping to have
      // both has-device-addr and ALWAYS, so lookup the mapper in case it
      // exists.
      if (mapOp.getMapperId()) {
        mapData.Mappers.push_back(
            Elt: SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
                from: mapOp, symbol: mapOp.getMapperIdAttr()));
      } else {
        mapData.Mappers.push_back(Elt: nullptr);
      }
    } else {
      mapData.Types.push_back(
          Elt: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
      mapData.Mappers.push_back(Elt: nullptr);
    }
    mapData.Names.push_back(Elt: LLVM::createMappingInformation(
        loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(
        Elt: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
    mapData.IsAMapping.push_back(Elt: false);
    mapData.IsAMember.push_back(Elt: checkIsAMember(hasDevAddrOperands, mapOp));
  }
}
3868
3869static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
3870 auto *res = llvm::find(Range&: mapData.MapClause, Val: memberOp);
3871 assert(res != mapData.MapClause.end() &&
3872 "MapInfoOp for member not found in MapData, cannot return index");
3873 return std::distance(first: mapData.MapClause.begin(), last: res);
3874}
3875
/// Returns the member map of \p mapInfo whose member-index path sorts first
/// (when \p first is true) or last (when \p first is false) in lexicographic
/// order. Used to find the lowest/highest addressed member when computing the
/// runtime extent of a partially mapped parent object.
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
                                                    bool first) {
  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
  // Only 1 member has been mapped, we can return it.
  if (indexAttr.size() == 1)
    return cast<omp::MapInfoOp>(Val: mapInfo.getMembers()[0].getDefiningOp());

  // Sort positions into `indices` instead of the attribute itself, so the
  // winning member can be retrieved from getMembers() afterwards.
  llvm::SmallVector<size_t> indices(indexAttr.size());
  std::iota(first: indices.begin(), last: indices.end(), value: 0);

  llvm::sort(Start: indices.begin(), End: indices.end(),
             Comp: [&](const size_t a, const size_t b) {
               // Compare the two members' index paths element-wise; the
               // direction of the comparison flips with `first` so that
               // indices.front() is the first or last member respectively.
               auto memberIndicesA = cast<ArrayAttr>(Val: indexAttr[a]);
               auto memberIndicesB = cast<ArrayAttr>(Val: indexAttr[b]);
               for (const auto it : llvm::zip(t&: memberIndicesA, u&: memberIndicesB)) {
                 int64_t aIndex = cast<IntegerAttr>(Val: std::get<0>(t: it)).getInt();
                 int64_t bIndex = cast<IntegerAttr>(Val: std::get<1>(t: it)).getInt();

                 if (aIndex == bIndex)
                   continue;

                 if (aIndex < bIndex)
                   return first;

                 if (aIndex > bIndex)
                   return !first;
               }

               // Iterated up to the end of the shortest index path and the
               // paths were equal to that point, so select the member with
               // the lowest index count: the "parent".
               return memberIndicesA.size() < memberIndicesB.size();
             });

  return llvm::cast<omp::MapInfoOp>(
      Val: mapInfo.getMembers()[indices.front()].getDefiningOp());
}
3913
3914/// This function calculates the array/pointer offset for map data provided
3915/// with bounds operations, e.g. when provided something like the following:
3916///
3917/// Fortran
3918/// map(tofrom: array(2:5, 3:2))
3919/// or
3920/// C++
3921/// map(tofrom: array[1:4][2:3])
3922/// We must calculate the initial pointer offset to pass across, this function
3923/// performs this using bounds.
3924///
/// NOTE: While specified in row-major order, it currently needs to be
/// flipped for Fortran's column-major array allocation and access (as
/// opposed to C++'s row-major; hence the backwards processing where order is
/// important). This is likely important to keep in mind for the future when
/// we incorporate a C++ frontend: both frontends will need to agree on the
/// ordering of generated bounds operations (one may have to flip them) to
/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
                      llvm::IRBuilderBase &builder, bool isArrayTy,
                      OperandRange bounds) {
  std::vector<llvm::Value *> idx;
  // There's no bounds to calculate an offset from, we can safely
  // ignore and return no indices.
  if (bounds.empty())
    return idx;

  // If we have an array type, then we have its type so can treat it as a
  // normal GEP instruction where the bounds operations are simply indexes
  // into the array. We currently do reverse order of the bounds, which
  // I believe leans more towards Fortran's column-major in memory.
  if (isArrayTy) {
    // Leading 0 steps through the pointer to the array itself; each lower
    // bound then indexes one array dimension.
    idx.push_back(x: builder.getInt64(C: 0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              Val: bounds[i].getDefiningOp())) {
        idx.push_back(x: moduleTranslation.lookupValue(value: boundOp.getLowerBound()));
      }
    }
  } else {
    // If we do not have an array type, but we have bounds, then we're dealing
    // with a pointer that's being treated like an array and we have the
    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
    // offset using a single index which the following two loops attempt to
    // compute.

    // Calculates the size offset we need to make per row e.g. first row or
    // column only needs to be offset by one, but the next would have to be
    // the previous row/column offset multiplied by the extent of current row.
    //
    // For example ([1][10][100]):
    //
    // - First row/column we move by 1 for each index increment
    // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
    // current) for 10 for each index increment
    // - Third row/column we would move by 10 (second row/column) *
    // (extent/size of current) 100 for 1000 for each index increment
    std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(C: 1)};
    for (size_t i = 1; i < bounds.size(); ++i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              Val: bounds[i].getDefiningOp())) {
        dimensionIndexSizeOffset.push_back(x: builder.CreateMul(
            LHS: moduleTranslation.lookupValue(value: boundOp.getExtent()),
            RHS: dimensionIndexSizeOffset[i - 1]));
      }
    }

    // Now that we have calculated how much we move by per index, we must
    // multiply each lower bound offset in indexes by the size offset we
    // have calculated in the previous and accumulate the results to get
    // our final resulting offset.
    // i.e. a single flattened index: sum(lowerBound[i] * stride[i]).
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              Val: bounds[i].getDefiningOp())) {
        if (idx.empty())
          idx.emplace_back(args: builder.CreateMul(
              LHS: moduleTranslation.lookupValue(value: boundOp.getLowerBound()),
              RHS: dimensionIndexSizeOffset[i]));
        else
          idx.back() = builder.CreateAdd(
              LHS: idx.back(), RHS: builder.CreateMul(LHS: moduleTranslation.lookupValue(
                                                value: boundOp.getLowerBound()),
                                            RHS: dimensionIndexSizeOffset[i]));
      }
    }
  }

  return idx;
}
4006
4007// This creates two insertions into the MapInfosTy data structure for the
4008// "parent" of a set of members, (usually a container e.g.
4009// class/structure/derived type) when subsequent members have also been
4010// explicitly mapped on the same map clause. Certain types, such as Fortran
4011// descriptors are mapped like this as well, however, the members are
4012// implicit as far as a user is concerned, but we must explicitly map them
4013// internally.
4014//
4015// This function also returns the memberOfFlag for this particular parent,
// which is utilised in subsequent member mappings (by modifying their map type
4017// with it) to indicate that a member is part of this parent and should be
4018// treated by the runtime as such. Important to achieve the correct mapping.
4019//
4020// This function borrows a lot from Clang's emitCombinedEntry function
4021// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
    LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
    llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
    MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
  assert(!ompBuilder.Config.isTargetDevice() &&
         "function only supported for host device codegen");

  // Map the first segment of our structure
  combinedInfo.Types.emplace_back(
      Args: isTargetParams
               ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
               : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
  combinedInfo.DevicePointers.emplace_back(
      Args&: mapData.DevicePointers[mapDataIndex]);
  combinedInfo.Mappers.emplace_back(Args&: mapData.Mappers[mapDataIndex]);
  combinedInfo.Names.emplace_back(Args: LLVM::createMappingInformation(
      loc: mapData.MapClause[mapDataIndex]->getLoc(), builder&: ompBuilder));
  combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIndex]);

  // Calculate size of the parent object being mapped based on the
  // addresses at runtime, highAddr - lowAddr = size. This of course
  // doesn't factor in allocated data like pointers, hence the further
  // processing of members specified by users, or in the case of
  // Fortran pointers and allocatables, the mapping of the pointed to
  // data by the descriptor (which itself, is a structure containing
  // runtime information on the dynamically allocated data).
  auto parentClause =
      llvm::cast<omp::MapInfoOp>(Val: mapData.MapClause[mapDataIndex]);

  llvm::Value *lowAddr, *highAddr;
  if (!parentClause.getPartialMap()) {
    // Full map: span is the whole parent object, [parent, parent + 1).
    lowAddr = builder.CreatePointerCast(V: mapData.Pointers[mapDataIndex],
                                        DestTy: builder.getPtrTy());
    highAddr = builder.CreatePointerCast(
        V: builder.CreateConstGEP1_32(Ty: mapData.BaseType[mapDataIndex],
                                    Ptr: mapData.Pointers[mapDataIndex], Idx0: 1),
        DestTy: builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIndex]);
  } else {
    // Partial map: span covers only the mapped members, from the
    // lowest-indexed member to one past the highest-indexed member.
    auto mapOp = dyn_cast<omp::MapInfoOp>(Val: mapData.MapClause[mapDataIndex]);
    int firstMemberIdx = getMapDataMemberIdx(
        mapData, memberOp: getFirstOrLastMappedMemberPtr(mapInfo: mapOp, first: true));
    lowAddr = builder.CreatePointerCast(V: mapData.Pointers[firstMemberIdx],
                                        DestTy: builder.getPtrTy());
    int lastMemberIdx = getMapDataMemberIdx(
        mapData, memberOp: getFirstOrLastMappedMemberPtr(mapInfo: mapOp, first: false));
    highAddr = builder.CreatePointerCast(
        V: builder.CreateGEP(Ty: mapData.BaseType[lastMemberIdx],
                           Ptr: mapData.Pointers[lastMemberIdx], IdxList: builder.getInt64(C: 1)),
        DestTy: builder.getPtrTy());
    combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[firstMemberIdx]);
  }

  llvm::Value *size = builder.CreateIntCast(
      V: builder.CreatePtrDiff(ElemTy: builder.getInt8Ty(), LHS: highAddr, RHS: lowAddr),
      DestTy: builder.getInt64Ty(),
      /*isSigned=*/false);
  combinedInfo.Sizes.push_back(Elt: size);

  // Position of the parent entry just emitted, encoded as a MEMBER_OF flag
  // for the member entries that follow.
  llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
      ompBuilder.getMemberOfFlag(Position: combinedInfo.BasePointers.size() - 1);

  // This creates the initial MEMBER_OF mapping that consists of
  // the parent/top level container (same as above effectively, except
  // with a fixed initial compile time size and separate maptype which
  // indicates the true map type (tofrom etc.). This parent mapping is
  // only relevant if the structure in its totality is being mapped,
  // otherwise the above suffices.
  if (!parentClause.getPartialMap()) {
    // TODO: This will need to be expanded to include the whole host of logic
    // for the map flags that Clang currently supports (e.g. it should do some
    // further case specific flag modifications). For the moment, it handles
    // what we support as expected.
    llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
    ompBuilder.setCorrectMemberOfFlag(Flags&: mapFlag, MemberOfFlag: memberOfFlag);
    combinedInfo.Types.emplace_back(Args&: mapFlag);
    combinedInfo.DevicePointers.emplace_back(
        Args: llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    combinedInfo.Mappers.emplace_back(Args: nullptr);
    combinedInfo.Names.emplace_back(Args: LLVM::createMappingInformation(
        loc: mapData.MapClause[mapDataIndex]->getLoc(), builder&: ompBuilder));
    combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIndex]);
    combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIndex]);
    combinedInfo.Sizes.emplace_back(Args&: mapData.Sizes[mapDataIndex]);
  }
  return memberOfFlag;
}
4109
4110// The intent is to verify if the mapped data being passed is a
4111// pointer -> pointee that requires special handling in certain cases,
4112// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4113//
4114// There may be a better way to verify this, but unfortunately with
4115// opaque pointers we lose the ability to easily check if something is
4116// a pointer whilst maintaining access to the underlying type.
4117static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4118 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4119 if (mapOp.getVarPtrPtr())
4120 return true;
4121
4122 // If the map data is declare target with a link clause, then it's represented
4123 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4124 // no relation to pointers.
4125 if (isDeclareTargetLink(value: mapOp.getVarPtr()))
4126 return true;
4127
4128 return false;
4129}
4130
4131// This function is intended to add explicit mappings of members
4132static void processMapMembersWithParent(
4133 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4134 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4135 MapInfoData &mapData, uint64_t mapDataIndex,
4136 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4137 assert(!ompBuilder.Config.isTargetDevice() &&
4138 "function only supported for host device codegen");
4139
4140 auto parentClause =
4141 llvm::cast<omp::MapInfoOp>(Val: mapData.MapClause[mapDataIndex]);
4142
4143 for (auto mappedMembers : parentClause.getMembers()) {
4144 auto memberClause =
4145 llvm::cast<omp::MapInfoOp>(Val: mappedMembers.getDefiningOp());
4146 int memberDataIdx = getMapDataMemberIdx(mapData, memberOp: memberClause);
4147
4148 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
4149
4150 // If we're currently mapping a pointer to a block of data, we must
4151 // initially map the pointer, and then attatch/bind the data with a
4152 // subsequent map to the pointer. This segment of code generates the
4153 // pointer mapping, which can in certain cases be optimised out as Clang
4154 // currently does in its lowering. However, for the moment we do not do so,
4155 // in part as we currently have substantially less information on the data
4156 // being mapped at this stage.
4157 if (checkIfPointerMap(mapOp: memberClause)) {
4158 auto mapFlag =
4159 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4160 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4161 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4162 ompBuilder.setCorrectMemberOfFlag(Flags&: mapFlag, MemberOfFlag: memberOfFlag);
4163 combinedInfo.Types.emplace_back(Args&: mapFlag);
4164 combinedInfo.DevicePointers.emplace_back(
4165 Args: llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4166 combinedInfo.Mappers.emplace_back(Args: nullptr);
4167 combinedInfo.Names.emplace_back(
4168 Args: LLVM::createMappingInformation(loc: memberClause.getLoc(), builder&: ompBuilder));
4169 combinedInfo.BasePointers.emplace_back(
4170 Args&: mapData.BasePointers[mapDataIndex]);
4171 combinedInfo.Pointers.emplace_back(Args&: mapData.BasePointers[memberDataIdx]);
4172 combinedInfo.Sizes.emplace_back(Args: builder.getInt64(
4173 C: moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
4174 }
4175
4176 // Same MemberOfFlag to indicate its link with parent and other members
4177 // of.
4178 auto mapFlag =
4179 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4180 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4181 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4182 ompBuilder.setCorrectMemberOfFlag(Flags&: mapFlag, MemberOfFlag: memberOfFlag);
4183 if (checkIfPointerMap(mapOp: memberClause))
4184 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4185
4186 combinedInfo.Types.emplace_back(Args&: mapFlag);
4187 combinedInfo.DevicePointers.emplace_back(
4188 Args&: mapData.DevicePointers[memberDataIdx]);
4189 combinedInfo.Mappers.emplace_back(Args&: mapData.Mappers[memberDataIdx]);
4190 combinedInfo.Names.emplace_back(
4191 Args: LLVM::createMappingInformation(loc: memberClause.getLoc(), builder&: ompBuilder));
4192 uint64_t basePointerIndex =
4193 checkIfPointerMap(mapOp: memberClause) ? memberDataIdx : mapDataIndex;
4194 combinedInfo.BasePointers.emplace_back(
4195 Args&: mapData.BasePointers[basePointerIndex]);
4196 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[memberDataIdx]);
4197
4198 llvm::Value *size = mapData.Sizes[memberDataIdx];
4199 if (checkIfPointerMap(mapOp: memberClause)) {
4200 size = builder.CreateSelect(
4201 C: builder.CreateIsNull(Arg: mapData.Pointers[memberDataIdx]),
4202 True: builder.getInt64(C: 0), False: size);
4203 }
4204
4205 combinedInfo.Sizes.emplace_back(Args&: size);
4206 }
4207}
4208
4209static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
4210 MapInfosTy &combinedInfo, bool isTargetParams,
4211 int mapDataParentIdx = -1) {
4212 // Declare Target Mappings are excluded from being marked as
4213 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
4214 // marked with OMP_MAP_PTR_AND_OBJ instead.
4215 auto mapFlag = mapData.Types[mapDataIdx];
4216 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(Val: mapData.MapClause[mapDataIdx]);
4217
4218 bool isPtrTy = checkIfPointerMap(mapOp: mapInfoOp);
4219 if (isPtrTy)
4220 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4221
4222 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4223 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4224
4225 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4226 !isPtrTy)
4227 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4228
4229 // if we're provided a mapDataParentIdx, then the data being mapped is
4230 // part of a larger object (in a parent <-> member mapping) and in this
4231 // case our BasePointer should be the parent.
4232 if (mapDataParentIdx >= 0)
4233 combinedInfo.BasePointers.emplace_back(
4234 Args&: mapData.BasePointers[mapDataParentIdx]);
4235 else
4236 combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIdx]);
4237
4238 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIdx]);
4239 combinedInfo.DevicePointers.emplace_back(Args&: mapData.DevicePointers[mapDataIdx]);
4240 combinedInfo.Mappers.emplace_back(Args&: mapData.Mappers[mapDataIdx]);
4241 combinedInfo.Names.emplace_back(Args&: mapData.Names[mapDataIdx]);
4242 combinedInfo.Types.emplace_back(Args&: mapFlag);
4243 combinedInfo.Sizes.emplace_back(Args&: mapData.Sizes[mapDataIdx]);
4244}
4245
4246static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation,
4247 llvm::IRBuilderBase &builder,
4248 llvm::OpenMPIRBuilder &ompBuilder,
4249 DataLayout &dl, MapInfosTy &combinedInfo,
4250 MapInfoData &mapData, uint64_t mapDataIndex,
4251 bool isTargetParams) {
4252 assert(!ompBuilder.Config.isTargetDevice() &&
4253 "function only supported for host device codegen");
4254
4255 auto parentClause =
4256 llvm::cast<omp::MapInfoOp>(Val: mapData.MapClause[mapDataIndex]);
4257
4258 // If we have a partial map (no parent referenced in the map clauses of the
4259 // directive, only members) and only a single member, we do not need to bind
4260 // the map of the member to the parent, we can pass the member separately.
4261 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4262 auto memberClause = llvm::cast<omp::MapInfoOp>(
4263 Val: parentClause.getMembers()[0].getDefiningOp());
4264 int memberDataIdx = getMapDataMemberIdx(mapData, memberOp: memberClause);
4265 // Note: Clang treats arrays with explicit bounds that fall into this
4266 // category as a parent with map case, however, it seems this isn't a
4267 // requirement, and processing them as an individual map is fine. So,
4268 // we will handle them as individual maps for the moment, as it's
4269 // difficult for us to check this as we always require bounds to be
4270 // specified currently and it's also marginally more optimal (single
4271 // map rather than two). The difference may come from the fact that
4272 // Clang maps array without bounds as pointers (which we do not
4273 // currently do), whereas we treat them as arrays in all cases
4274 // currently.
4275 processIndividualMap(mapData, mapDataIdx: memberDataIdx, combinedInfo, isTargetParams,
4276 mapDataParentIdx: mapDataIndex);
4277 return;
4278 }
4279
4280 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4281 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
4282 combinedInfo, mapData, mapDataIndex, isTargetParams);
4283 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
4284 combinedInfo, mapData, mapDataIndex,
4285 memberOfFlag: memberOfParentFlag);
4286}
4287
4288// This is a variation on Clang's GenerateOpenMPCapturedVars, which
4289// generates different operation (e.g. load/store) combinations for
4290// arguments to the kernel, based on map capture kinds which are then
4291// utilised in the combinedInfo in place of the original Map value.
static void
createAlteredByCaptureMap(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // if it's declare target, skip it, it's handled separately.
    if (!mapData.IsDeclareTarget[i]) {
      auto mapOp = cast<omp::MapInfoOp>(Val: mapData.MapClause[i]);
      omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
      bool isPtrTy = checkIfPointerMap(mapOp);

      // Currently handles array sectioning lowerbound case, but more
      // logic may be required in the future. Clang invokes EmitLValue,
      // which has specialised logic for special Clang types such as user
      // defines, so it is possible we will have to extend this for
      // structures or other complex types. As the general idea is that this
      // function mimics some of the logic from Clang that we require for
      // kernel argument passing from host -> device.
      switch (captureKind) {
      case omp::VariableCaptureKind::ByRef: {
        llvm::Value *newV = mapData.Pointers[i];
        // Apply any bounds-derived offset (e.g. an array section's lower
        // bound) so the kernel receives the address of the first mapped
        // element rather than the start of the object.
        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
            moduleTranslation, builder, isArrayTy: mapData.BaseType[i]->isArrayTy(),
            bounds: mapOp.getBounds());
        if (isPtrTy)
          newV = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: newV);

        if (!offsetIdx.empty())
          newV = builder.CreateInBoundsGEP(Ty: mapData.BaseType[i], Ptr: newV, IdxList: offsetIdx,
                                           Name: "array_offset");
        mapData.Pointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::ByCopy: {
        llvm::Type *type = mapData.BaseType[i];
        llvm::Value *newV;
        if (mapData.Pointers[i]->getType()->isPointerTy())
          newV = builder.CreateLoad(Ty: type, Ptr: mapData.Pointers[i]);
        else
          newV = mapData.Pointers[i];

        // Non-pointer by-copy values are round-tripped through a stack slot
        // (allocated at the function's alloca insertion point, then restoring
        // the current insertion point) so they can be passed as a pointer-
        // sized kernel argument.
        if (!isPtrTy) {
          auto curInsert = builder.saveIP();
          builder.restoreIP(IP: findAllocaInsertPoint(builder, moduleTranslation));
          auto *memTempAlloc =
              builder.CreateAlloca(Ty: builder.getPtrTy(), ArraySize: nullptr, Name: ".casted");
          builder.restoreIP(IP: curInsert);

          builder.CreateStore(Val: newV, Ptr: memTempAlloc);
          newV = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: memTempAlloc);
        }

        mapData.Pointers[i] = newV;
        mapData.BasePointers[i] = newV;
      } break;
      case omp::VariableCaptureKind::This:
      case omp::VariableCaptureKind::VLAType:
        mapData.MapClause[i]->emitOpError(message: "Unhandled capture kind");
        break;
      }
    }
  }
}
4356
4357// Generate all map related information and fill the combinedInfo.
4358static void genMapInfos(llvm::IRBuilderBase &builder,
4359 LLVM::ModuleTranslation &moduleTranslation,
4360 DataLayout &dl, MapInfosTy &combinedInfo,
4361 MapInfoData &mapData, bool isTargetParams = false) {
4362 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4363 "function only supported for host device codegen");
4364
4365 // We wish to modify some of the methods in which arguments are
4366 // passed based on their capture type by the target region, this can
4367 // involve generating new loads and stores, which changes the
4368 // MLIR value to LLVM value mapping, however, we only wish to do this
4369 // locally for the current function/target and also avoid altering
4370 // ModuleTranslation, so we remap the base pointer or pointer stored
4371 // in the map infos corresponding MapInfoData, which is later accessed
4372 // by genMapInfos and createTarget to help generate the kernel and
4373 // kernel arg structure. It primarily becomes relevant in cases like
4374 // bycopy, or byref range'd arrays. In the default case, we simply
4375 // pass thee pointer byref as both basePointer and pointer.
4376 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
4377
4378 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4379
4380 // We operate under the assumption that all vectors that are
4381 // required in MapInfoData are of equal lengths (either filled with
4382 // default constructed data or appropiate information) so we can
4383 // utilise the size from any component of MapInfoData, if we can't
4384 // something is missing from the initial MapInfoData construction.
4385 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4386 // NOTE/TODO: We currently do not support arbitrary depth record
4387 // type mapping.
4388 if (mapData.IsAMember[i])
4389 continue;
4390
4391 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(Val: mapData.MapClause[i]);
4392 if (!mapInfoOp.getMembers().empty()) {
4393 processMapWithMembersOf(moduleTranslation, builder, ompBuilder&: *ompBuilder, dl,
4394 combinedInfo, mapData, mapDataIndex: i, isTargetParams);
4395 continue;
4396 }
4397
4398 processIndividualMap(mapData, mapDataIdx: i, combinedInfo, isTargetParams);
4399 }
4400}
4401
// Forward declaration: emitUserDefinedMapper and
// getOrCreateUserDefinedMapperFunc are mutually recursive, since a declare
// mapper's own map entries may reference further user-defined mappers.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName);
4406
4407static llvm::Expected<llvm::Function *>
4408getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
4409 LLVM::ModuleTranslation &moduleTranslation) {
4410 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4411 "function only supported for host device codegen");
4412 auto declMapperOp = cast<omp::DeclareMapperOp>(Val: op);
4413 std::string mapperFuncName =
4414 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
4415 Parts: {"omp_mapper", declMapperOp.getSymName()});
4416
4417 if (auto *lookupFunc = moduleTranslation.lookupFunction(name: mapperFuncName))
4418 return lookupFunc;
4419
4420 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
4421 mapperFuncName);
4422}
4423
// Emits the LLVM IR function implementing a user-defined `declare mapper`.
// The OMPIRBuilder drives emission; we provide a callback that inlines the
// mapper's MLIR region to compute the map information for each mapped
// element, plus a callback resolving nested user-defined mappers.
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      llvm::StringRef mapperFuncName) {
  assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for host device codegen");
  auto declMapperOp = cast<omp::DeclareMapperOp>(Val: op);
  auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
  DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::Type *varType = moduleTranslation.convertType(type: declMapperOp.getType());
  SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB =
      [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
          llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
    builder.restoreIP(IP: codeGenIP);
    // Bind the mapper's symbolic variable to the pointer value supplied by
    // the OMPIRBuilder, then translate the mapper region's body in place.
    moduleTranslation.mapValue(mlir: declMapperOp.getSymVal(), llvm: ptrPHI);
    moduleTranslation.mapBlock(mlir: &declMapperOp.getRegion().front(),
                               llvm: builder.GetInsertBlock());
    if (failed(Result: moduleTranslation.convertBlock(bb&: declMapperOp.getRegion().front(),
                                            /*ignoreArguments=*/true,
                                            builder)))
      return llvm::make_error<PreviouslyReportedError>();
    MapInfoData mapData;
    collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                  builder);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);

    // Drop the mapping that is no longer necessary so that the same region can
    // be processed multiple times.
    moduleTranslation.forgetMapping(region&: declMapperOp.getRegion());
    return combinedInfo;
  };

  // Resolve a nested user-defined mapper for entry `i`, or return null when
  // the entry has no custom mapper. Note this recurses through
  // getOrCreateUserDefinedMapperFunc back into this function.
  auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    return getOrCreateUserDefinedMapperFunc(op: combinedInfo.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
      PrivAndGenMapInfoCB: genMapInfoCB, ElemTy: varType, FuncName: mapperFuncName, CustomMapperCB: customMapperCB);
  if (!newFn)
    return newFn.takeError();
  // Register the emitted function so later lookups reuse it.
  moduleTranslation.mapFunction(name: mapperFuncName, func: *newFn);
  return *newFn;
}
4477
// Translates omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update operations to LLVM IR via the OMPIRBuilder. Only
// omp.target_data carries a body region; the other three lower to a single
// begin/end/update runtime call selected via RTLFn.
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::Value *ifCond = nullptr;
  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
  SmallVector<Value> mapVars;
  SmallVector<Value> useDevicePtrVars;
  SmallVector<Value> useDeviceAddrVars;
  llvm::omp::RuntimeFunction RTLFn;
  DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
                                             /*SeparateBeginEndCalls=*/true);
  // Offloading is performed when compiling for the device, or on the host
  // when at least one target triple is registered.
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // Gather ifCond / deviceID / map operands (and, where applicable, RTLFn and
  // nowait) from whichever of the four target-data ops this is. Note the
  // device clause is only honoured when its value is a compile-time constant;
  // otherwise deviceID stays OMP_DEVICEID_UNDEF.
  LogicalResult result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case(caseFn: [&](omp::TargetDataOp dataOp) {
            if (failed(Result: checkImplementationStatus(op&: *dataOp)))
              return failure();

            if (auto ifVar = dataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(value: ifVar);

            if (auto devId = dataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(Val: devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(Val: constOp.getValue()))
                  deviceID = intAttr.getInt();

            mapVars = dataOp.getMapVars();
            useDevicePtrVars = dataOp.getUseDevicePtrVars();
            useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
            return success();
          })
          .Case(caseFn: [&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
            if (failed(Result: checkImplementationStatus(op&: *enterDataOp)))
              return failure();

            if (auto ifVar = enterDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(value: ifVar);

            if (auto devId = enterDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(Val: devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(Val: constOp.getValue()))
                  deviceID = intAttr.getInt();
            RTLFn =
                enterDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
            mapVars = enterDataOp.getMapVars();
            info.HasNoWait = enterDataOp.getNowait();
            return success();
          })
          .Case(caseFn: [&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
            if (failed(Result: checkImplementationStatus(op&: *exitDataOp)))
              return failure();

            if (auto ifVar = exitDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(value: ifVar);

            if (auto devId = exitDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(Val: devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(Val: constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn = exitDataOp.getNowait()
                        ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
                        : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
            mapVars = exitDataOp.getMapVars();
            info.HasNoWait = exitDataOp.getNowait();
            return success();
          })
          .Case(caseFn: [&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
            if (failed(Result: checkImplementationStatus(op&: *updateDataOp)))
              return failure();

            if (auto ifVar = updateDataOp.getIfExpr())
              ifCond = moduleTranslation.lookupValue(value: ifVar);

            if (auto devId = updateDataOp.getDevice())
              if (auto constOp =
                      dyn_cast<LLVM::ConstantOp>(Val: devId.getDefiningOp()))
                if (auto intAttr = dyn_cast<IntegerAttr>(Val: constOp.getValue()))
                  deviceID = intAttr.getInt();

            RTLFn =
                updateDataOp.getNowait()
                    ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
                    : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
            mapVars = updateDataOp.getMapVars();
            info.HasNoWait = updateDataOp.getNowait();
            return success();
          })
          .Default(defaultFn: [&](Operation *op) {
            llvm_unreachable("unexpected operation");
            return failure();
          });

  if (failed(Result: result))
    return failure();
  // Pretend we have IF(false) if we're not doing offload.
  if (!isOffloadEntry)
    ifCond = builder.getFalse();

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl&: DL,
                                builder, useDevPtrOperands: useDevicePtrVars, useDevAddrOperands: useDeviceAddrVars);

  // Fill up the arrays with all the mapped variables.
  MapInfosTy combinedInfo;
  auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(IP: codeGenIP);
    genMapInfos(builder, moduleTranslation, dl&: DL, combinedInfo, mapData);
    return combinedInfo;
  };

  // Define a lambda to apply mappings between use_device_addr and
  // use_device_ptr base pointers, and their associated block arguments.
  // An optional `mapper` transforms the matched base pointer into the value
  // actually bound to the block argument; a null return skips the binding.
  auto mapUseDevice =
      [&moduleTranslation](
          llvm::OpenMPIRBuilder::DeviceInfoTy type,
          llvm::ArrayRef<BlockArgument> blockArgs,
          llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
          llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
        for (auto [arg, useDevVar] :
             llvm::zip_equal(t&: blockArgs, u&: useDeviceVars)) {

          auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
            return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
                                            : mapInfoOp.getVarPtr();
          };

          // Find the map clause whose base pointer matches this
          // use_device_* variable and whose device-pointer kind matches.
          auto useDevMap = cast<omp::MapInfoOp>(Val: useDevVar.getDefiningOp());
          for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
                   t&: mapInfoData.MapClause, u&: mapInfoData.DevicePointers,
                   args&: mapInfoData.BasePointers)) {
            auto mapOp = cast<omp::MapInfoOp>(Val: mapClause);
            if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
                devicePointer != type)
              continue;

            if (llvm::Value *devPtrInfoMap =
                    mapper ? mapper(basePointer) : basePointer) {
              moduleTranslation.mapValue(mlir: arg, llvm: devPtrInfoMap);
              break;
            }
          }
        }
      };

  // Body generation callback for omp.target_data. The OMPIRBuilder invokes it
  // up to three times (Priv / DupNoPriv / NoPriv phases); the region is
  // inlined in exactly one of them depending on whether device pointer info
  // is available.
  using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
  auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We must always restoreIP regardless of doing anything the caller
    // does not restore it, leading to incorrect (no) branch generation.
    builder.restoreIP(IP: codeGenIP);
    assert(isa<omp::TargetDataOp>(op) &&
           "BodyGen requested for non TargetDataOp");
    auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(Val: op);
    Region &region = cast<omp::TargetDataOp>(Val: op).getRegion();
    switch (bodyGenType) {
    case BodyGenTy::Priv:
      // Check if any device ptr/addr info is available
      if (!info.DevicePtrInfoMap.empty()) {
        // use_device_addr arguments require an extra load from the
        // runtime-provided slot; skip binding if the slot is absent.
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                     blockArgIface.getUseDeviceAddrBlockArgs(),
                     useDeviceAddrVars, mapData,
                     [&](llvm::Value *basePointer) -> llvm::Value * {
                       if (!info.DevicePtrInfoMap[basePointer].second)
                         return nullptr;
                       return builder.CreateLoad(
                           Ty: builder.getPtrTy(),
                           Ptr: info.DevicePtrInfoMap[basePointer].second);
                     });
        mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                     blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
                     mapData, [&](llvm::Value *basePointer) {
                       return info.DevicePtrInfoMap[basePointer].second;
                     });

        if (failed(Result: inlineConvertOmpRegions(region, blockName: "omp.data.region", builder,
                                          moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    case BodyGenTy::DupNoPriv:
      if (info.DevicePtrInfoMap.empty()) {
        // For host device we still need to do the mapping for codegen,
        // otherwise it may try to lookup a missing value.
        if (!ompBuilder->Config.IsTargetDevice.value_or(u: false)) {
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                       blockArgIface.getUseDeviceAddrBlockArgs(),
                       useDeviceAddrVars, mapData);
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                       blockArgIface.getUseDevicePtrBlockArgs(),
                       useDevicePtrVars, mapData);
        }
      }
      break;
    case BodyGenTy::NoPriv:
      // If device info is available then region has already been generated
      if (info.DevicePtrInfoMap.empty()) {
        // For device pass, if use_device_ptr(addr) mappings were present,
        // we need to link them here before codegen.
        if (ompBuilder->Config.IsTargetDevice.value_or(u: false)) {
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
                       blockArgIface.getUseDeviceAddrBlockArgs(),
                       useDeviceAddrVars, mapData);
          mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
                       blockArgIface.getUseDevicePtrBlockArgs(),
                       useDevicePtrVars, mapData);
        }

        if (failed(Result: inlineConvertOmpRegions(region, blockName: "omp.data.region", builder,
                                          moduleTranslation)))
          return llvm::make_error<PreviouslyReportedError>();
      }
      break;
    }
    return builder.saveIP();
  };

  // Resolve any user-defined mapper attached to entry `i`, tracking whether
  // one was seen so the runtime call passes mapper arrays.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfo.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(op: combinedInfo.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  // omp.target_data uses the body-generation path (no RTLFn); the standalone
  // enter/exit/update ops pass the runtime function selected above.
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
    if (isa<omp::TargetDataOp>(Val: op))
      return ompBuilder->createTargetData(Loc: ompLoc, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(),
                                          DeviceID: builder.getInt64(C: deviceID), IfCond: ifCond,
                                          Info&: info, GenMapInfoCB: genMapInfoCB, CustomMapperCB: customMapperCB,
                                          /*MapperFunc=*/nullptr, BodyGenCB: bodyGenCB,
                                          /*DeviceAddrCB=*/nullptr);
    return ompBuilder->createTargetData(
        Loc: ompLoc, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(), DeviceID: builder.getInt64(C: deviceID), IfCond: ifCond,
        Info&: info, GenMapInfoCB: genMapInfoCB, CustomMapperCB: customMapperCB, MapperFunc: &RTLFn);
  }();

  if (failed(Result: handleError(result&: afterIP, op&: *op)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
4737
// Translates omp.distribute to LLVM IR. Also performs a teams reduction here
// when the parent omp.teams' reduction is contained in this distribute op,
// and applies a distribute workshare loop unless a nested omp.wsloop has
// already handled it ('distribute parallel do').
static LogicalResult
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto distributeOp = cast<omp::DistributeOp>(Val&: opInst);
  if (failed(Result: checkImplementationStatus(op&: opInst)))
    return failure();

  /// Process teams op reduction in distribute if the reduction is contained in
  /// the distribute op.
  omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
  bool doDistributeReduction =
      teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

  DenseMap<Value, llvm::Value *> reductionVariableMap;
  unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
  llvm::ArrayRef<bool> isByRef;

  // Allocate and initialize the teams reduction variables up front; the
  // reduction itself is finalized after the distribute region (below).
  if (doDistributeReduction) {
    isByRef = getIsByRef(attr: teamsOp.getReductionByref());
    assert(isByRef.size() == teamsOp.getNumReductionVars());

    collectReductionDecls(op: teamsOp, reductions&: reductionDecls);
    llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
        findAllocaInsertPoint(builder, moduleTranslation);

    MutableArrayRef<BlockArgument> reductionArgs =
        llvm::cast<omp::BlockArgOpenMPOpInterface>(Val&: *teamsOp)
            .getReductionBlockArgs();

    if (failed(Result: allocAndInitializeReductionVars(
            op: teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            isByRef)))
      return failure();
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // DistributeOp has only one region associated with it.
    builder.restoreIP(IP: codeGenIP);
    PrivateVarsInfo privVarsInfo(distributeOp);

    // Privatization: allocate, initialize, then firstprivate-copy, in that
    // order, before translating the region body.
    llvm::Expected<llvm::BasicBlock *> afterAllocas =
        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo&: privVarsInfo, allocaIP);
    if (handleError(result&: afterAllocas, op&: opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo&: privVarsInfo),
                    op&: opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(Result: copyFirstPrivateVars(
            op: distributeOp, builder, moduleTranslation, mlirPrivateVars&: privVarsInfo.mlirVars,
            llvmPrivateVars: privVarsInfo.llvmVars, privateDecls&: privVarsInfo.privatizers,
            insertBarrier: distributeOp.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
    llvm::Expected<llvm::BasicBlock *> regionBlock =
        convertOmpOpRegions(region&: distributeOp.getRegion(), blockName: "omp.distribute.region",
                            builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();
    builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());

    // Skip applying a workshare loop below when translating 'distribute
    // parallel do' (it's been already handled by this point while translating
    // the nested omp.wsloop).
    if (!isa_and_present<omp::WsloopOp>(Val: distributeOp.getNestedWrapper())) {
      // TODO: Add support for clauses which are valid for DISTRIBUTE
      // constructs. Static schedule is the default.
      auto schedule = omp::ClauseScheduleKind::Static;
      bool isOrdered = false;
      std::optional<omp::ScheduleModifier> scheduleMod;
      bool isSimd = false;
      llvm::omp::WorksharingLoopType workshareLoopType =
          llvm::omp::WorksharingLoopType::DistributeStaticLoop;
      bool loopNeedsBarrier = false;
      llvm::Value *chunk = nullptr;

      llvm::CanonicalLoopInfo *loopInfo =
          findCurrentLoopInfo(moduleTranslation);
      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
          ompBuilder->applyWorkshareLoop(
              DL: ompLoc.DL, CLI: loopInfo, AllocaIP: allocaIP, NeedsBarrier: loopNeedsBarrier,
              SchedKind: convertToScheduleKind(schedKind: schedule), ChunkSize: chunk, HasSimdModifier: isSimd,
              HasMonotonicModifier: scheduleMod == omp::ScheduleModifier::monotonic,
              HasNonmonotonicModifier: scheduleMod == omp::ScheduleModifier::nonmonotonic, HasOrderedClause: isOrdered,
              LoopType: workshareLoopType);

      if (!wsloopIP)
        return wsloopIP.takeError();
    }

    // Run the private-variable cleanup (dealloc) regions before leaving.
    if (failed(Result: cleanupPrivateVars(builder, moduleTranslation,
                                 loc: distributeOp.getLoc(), llvmPrivateVars&: privVarsInfo.llvmVars,
                                 privateDecls&: privVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createDistribute(Loc: ompLoc, AllocaIP: allocaIP, BodyGenCB: bodyGenCB);

  if (failed(Result: handleError(result&: afterIP, op&: opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);

  if (doDistributeReduction) {
    // Process the reductions if required.
    return createReductionsAndCleanup(
        op: teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
        privateReductionVariables, isByRef,
        /*isNoWait*/ isNowait: false, /*isTeamsReduction*/ true);
  }
  return success();
}
4871
4872/// Lowers the FlagsAttr which is applied to the module on the device
4873/// pass when offloading, this attribute contains OpenMP RTL globals that can
4874/// be passed as flags to the frontend, otherwise they are set to default
4875LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
4876 LLVM::ModuleTranslation &moduleTranslation) {
4877 if (!cast<mlir::ModuleOp>(Val: op))
4878 return failure();
4879
4880 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4881
4882 ompBuilder->M.addModuleFlag(Behavior: llvm::Module::Max, Key: "openmp-device",
4883 Val: attribute.getOpenmpDeviceVersion());
4884
4885 if (attribute.getNoGpuLib())
4886 return success();
4887
4888 ompBuilder->createGlobalFlag(
4889 Value: attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
4890 Name: "__omp_rtl_debug_kind");
4891 ompBuilder->createGlobalFlag(
4892 Value: attribute
4893 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
4894 ,
4895 Name: "__omp_rtl_assume_teams_oversubscription");
4896 ompBuilder->createGlobalFlag(
4897 Value: attribute
4898 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
4899 ,
4900 Name: "__omp_rtl_assume_threads_oversubscription");
4901 ompBuilder->createGlobalFlag(
4902 Value: attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
4903 Name: "__omp_rtl_assume_no_thread_state");
4904 ompBuilder->createGlobalFlag(
4905 Value: attribute
4906 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
4907 ,
4908 Name: "__omp_rtl_assume_no_nested_parallelism");
4909 return success();
4910}
4911
4912static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
4913 omp::TargetOp targetOp,
4914 llvm::StringRef parentName = "") {
4915 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
4916
4917 assert(fileLoc && "No file found from location");
4918 StringRef fileName = fileLoc.getFilename().getValue();
4919
4920 llvm::sys::fs::UniqueID id;
4921 uint64_t line = fileLoc.getLine();
4922 if (auto ec = llvm::sys::fs::getUniqueID(Path: fileName, Result&: id)) {
4923 size_t fileHash = llvm::hash_value(arg: fileName.str());
4924 size_t deviceId = 0xdeadf17e;
4925 targetInfo =
4926 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
4927 } else {
4928 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
4929 id.getFile(), line);
4930 }
4931}
4932
// Device-side fixup for declare-target mapped variables: inside the outlined
// kernel `func`, replace every use of the original global (OriginalValue)
// with a load from the corresponding reference pointer (BasePointers), which
// is where the mapped data actually lives on the device.
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
                          LLVM::ModuleTranslation &moduleTranslation,
                          llvm::IRBuilderBase &builder, llvm::Function *func) {
  assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
         "function only supported for target device codegen");
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
    // In the case of declare target mapped variables, the basePointer is
    // the reference pointer generated by the convertDeclareTargetAttr
    // method. Whereas the kernelValue is the original variable, so for
    // the device we must replace all uses of this original global variable
    // (stored in kernelValue) with the reference pointer (stored in
    // basePointer for declare target mapped variables), as for device the
    // data is mapped into this reference pointer and should be loaded
    // from it, the original variable is discarded. On host both exist and
    // metadata is generated (elsewhere in the convertDeclareTargetAttr)
    // function to link the two variables in the runtime and then both the
    // reference pointer and the pointer are assigned in the kernel argument
    // structure for the host.
    if (mapData.IsDeclareTarget[i]) {
      // If the original map value is a constant, then we have to make sure all
      // of it's uses within the current kernel/function that we are going to
      // rewrite are converted to instructions, as we will be altering the old
      // use (OriginalValue) from a constant to an instruction, which will be
      // illegal and ICE the compiler if the user is a constant expression of
      // some kind e.g. a constant GEP.
      if (auto *constant = dyn_cast<llvm::Constant>(Val: mapData.OriginalValue[i]))
        convertUsersOfConstantsToInstructions(Consts: constant, RestrictToFunc: func, RemoveDeadConstants: false);

      // The users iterator will get invalidated if we modify an element,
      // so we populate this vector of uses to alter each user on an
      // individual basis to emit its own load (rather than one load for
      // all).
      llvm::SmallVector<llvm::User *> userVec;
      for (llvm::User *user : mapData.OriginalValue[i]->users())
        userVec.push_back(Elt: user);

      // Only rewrite users that are instructions inside `func`; uses in other
      // functions or non-instruction users are left alone.
      for (llvm::User *user : userVec) {
        if (auto *insn = dyn_cast<llvm::Instruction>(Val: user)) {
          if (insn->getFunction() == func) {
            // Each user gets its own load, placed immediately before it and
            // carrying that user's debug location.
            builder.SetCurrentDebugLocation(insn->getDebugLoc());
            auto *load = builder.CreateLoad(Ty: mapData.BasePointers[i]->getType(),
                                            Ptr: mapData.BasePointers[i]);
            load->moveBefore(InsertPos: insn->getIterator());
            user->replaceUsesOfWith(From: mapData.OriginalValue[i], To: load);
          }
        }
      }
    }
  }
}
4985
4986// The createDeviceArgumentAccessor function generates
// instructions for retrieving (accessing) kernel
4988// arguments inside of the device kernel for use by
4989// the kernel. This enables different semantics such as
4990// the creation of temporary copies of data allowing
4991// semantics like read-only/no host write back kernel
4992// arguments.
4993//
4994// This currently implements a very light version of Clang's
4995// EmitParmDecl's handling of direct argument handling as well
4996// as a portion of the argument access generation based on
4997// capture types found at the end of emitOutlinedFunctionPrologue
4998// in Clang. The indirect path handling of EmitParmDecl's may be
4999// required for future work, but a direct 1-to-1 copy doesn't seem
5000// possible as the logic is rather scattered throughout Clang's
5001// lowering and perhaps we wish to deviate slightly.
5002//
5003// \param mapData - A container containing vectors of information
5004// corresponding to the input argument, which should have a
5005// corresponding entry in the MapInfoData containers
//   OriginalValue's.
5007// \param arg - This is the generated kernel function argument that
5008// corresponds to the passed in input argument. We generated different
5009// accesses of this Argument, based on capture type and other Input
5010// related information.
5011// \param input - This is the host side value that will be passed to
5012// the kernel i.e. the kernel input, we rewrite all uses of this within
5013// the kernel (as we generate the kernel body based on the target's region
//   which maintains references to the original input) to the retVal argument
//   upon exit of this function inside of the OMPIRBuilder. This interlinks
5016// the kernel argument to future uses of it in the function providing
5017// appropriate "glue" instructions inbetween.
5018// \param retVal - This is the value that all uses of input inside of the
5019// kernel will be re-written to, the goal of this function is to generate
5020// an appropriate location for the kernel argument to be accessed from,
5021// e.g. ByRef will result in a temporary allocation location and then
5022// a store of the kernel argument into this allocated memory which
5023// will then be loaded from, ByCopy will use the allocated memory
5024// directly.
static llvm::IRBuilderBase::InsertPoint
createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
                             llvm::Value *input, llvm::Value *&retVal,
                             llvm::IRBuilderBase &builder,
                             llvm::OpenMPIRBuilder &ompBuilder,
                             LLVM::ModuleTranslation &moduleTranslation,
                             llvm::IRBuilderBase::InsertPoint allocaIP,
                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
  assert(ompBuilder.Config.isTargetDevice() &&
         "function only supported for target device codegen");
  // Allocas must be emitted at the function's dedicated alloca insertion
  // point, not wherever codegen currently is.
  builder.restoreIP(IP: allocaIP);

  // Default to ByRef if no matching map entry is found below.
  omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
  LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
      ompBuilder.M.getContext());
  unsigned alignmentValue = 0;
  // Find the associated MapInfoData entry for the current input to learn its
  // capture kind and the preferred alignment of the mapped variable's type.
  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
    if (mapData.OriginalValue[i] == input) {
      auto mapOp = cast<omp::MapInfoOp>(Val: mapData.MapClause[i]);
      capture = mapOp.getMapCaptureType();
      // Get information of alignment of mapped object.
      alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
          type: mapOp.getVarType(), layout: ompBuilder.M.getDataLayout());
      break;
    }

  unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
  unsigned int defaultAS =
      ompBuilder.M.getDataLayout().getProgramAddressSpace();

  // Create the alloca for the argument at the current (alloca) insertion
  // point.
  llvm::Value *v = builder.CreateAlloca(Ty: arg.getType(), AddrSpace: allocaAS);

  // On targets where allocas live in a non-default address space (e.g. GPUs),
  // cast the slot back to the program address space before use.
  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
    v = builder.CreateAddrSpaceCast(V: v, DestTy: builder.getPtrTy(AddrSpace: defaultAS));

  // Spill the incoming kernel argument into the stack slot.
  builder.CreateStore(Val: &arg, Ptr: v);

  // Switch back to the body insertion point to emit the access code.
  builder.restoreIP(IP: codeGenIP);

  switch (capture) {
  case omp::VariableCaptureKind::ByCopy: {
    // ByCopy: the stack slot itself is the variable's storage.
    retVal = v;
    break;
  }
  case omp::VariableCaptureKind::ByRef: {
    // ByRef: the argument is a pointer to the data; reload it from the slot.
    llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
        Ty: v->getType(), Ptr: v,
        Align: ompBuilder.M.getDataLayout().getPrefTypeAlign(Ty: v->getType()));
    // CreateAlignedLoad function creates similar LLVM IR:
    // %res = load ptr, ptr %input, align 8
    // This LLVM IR does not contain information about alignment
    // of the loaded value. We need to add !align metadata to unblock
    // optimizer. The existence of the !align metadata on the instruction
    // tells the optimizer that the value loaded is known to be aligned to
    // a boundary specified by the integer value in the metadata node.
    // Example:
    // %res = load ptr, ptr %input, align 8, !align !align_md_node
    //                 ^         ^
    //                 |         |
    //    alignment of %input    |
    //                           |
    //               alignment of %res object
    if (v->getType()->isPointerTy() && alignmentValue) {
      llvm::MDBuilder MDB(builder.getContext());
      loadInst->setMetadata(
          KindID: llvm::LLVMContext::MD_align,
          Node: llvm::MDNode::get(Context&: builder.getContext(),
                           MDs: MDB.createConstant(C: llvm::ConstantInt::get(
                               Ty: llvm::Type::getInt64Ty(C&: builder.getContext()),
                               V: alignmentValue))));
    }
    retVal = loadInst;

    break;
  }
  case omp::VariableCaptureKind::This:
  case omp::VariableCaptureKind::VLAType:
    // TODO: Consider returning error to use standard reporting for
    // unimplemented features.
    assert(false && "Currently unsupported capture kind");
    break;
  }

  return builder.saveIP();
}
5112
/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided. Unsupported users of a `host_eval` block argument are a
/// verifier/codegen invariant violation and trip `llvm_unreachable`.
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                       Value &numTeamsLower, Value &numTeamsUpper,
                       Value &threadLimit,
                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(Val&: *targetOp);
  // Pair each host_eval operand with its corresponding entry block argument.
  for (auto item : llvm::zip_equal(t: targetOp.getHostEvalVars(),
                                   u: blockArgIface.getHostEvalBlockArgs())) {
    Value hostEvalVar = std::get<0>(t&: item), blockArg = std::get<1>(t&: item);

    // Classify every use of the block argument by the consuming operation to
    // decide which output it maps to.
    for (Operation *user : blockArg.getUsers()) {
      llvm::TypeSwitch<Operation *>(user)
          .Case(caseFn: [&](omp::TeamsOp teamsOp) {
            if (teamsOp.getNumTeamsLower() == blockArg)
              numTeamsLower = hostEvalVar;
            else if (teamsOp.getNumTeamsUpper() == blockArg)
              numTeamsUpper = hostEvalVar;
            else if (teamsOp.getThreadLimit() == blockArg)
              threadLimit = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case(caseFn: [&](omp::ParallelOp parallelOp) {
            if (parallelOp.getNumThreads() == blockArg)
              numThreads = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case(caseFn: [&](omp::LoopNestOp loopOp) {
            // One block argument may feed several bounds/steps across the
            // collapsed loops; record it at every matching loop index when an
            // output vector was provided.
            auto processBounds =
                [&](OperandRange opBounds,
                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
              bool found = false;
              for (auto [i, lb] : llvm::enumerate(First&: opBounds)) {
                if (lb == blockArg) {
                  found = true;
                  if (outBounds)
                    (*outBounds)[i] = hostEvalVar;
                }
              }
              return found;
            };
            bool found =
                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
                    found;
            found = processBounds(loopOp.getLoopSteps(), steps) || found;
            (void)found;
            assert(found && "unsupported host_eval use");
          })
          .Default(defaultFn: [](Operation *) {
            llvm_unreachable("unsupported host_eval use");
          });
    }
  }
}
5178
5179/// If \p op is of the given type parameter, return it casted to that type.
5180/// Otherwise, if its immediate parent operation (or some other higher-level
5181/// parent, if \p immediateParent is false) is of that type, return that parent
5182/// casted to the given type.
5183///
5184/// If \p op is \c null or neither it or its parent(s) are of the specified
5185/// type, return a \c null operation.
5186template <typename OpTy>
5187static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
5188 if (!op)
5189 return OpTy();
5190
5191 if (OpTy casted = dyn_cast<OpTy>(op))
5192 return casted;
5193
5194 if (immediateParent)
5195 return dyn_cast_if_present<OpTy>(op->getParentOp());
5196
5197 return op->getParentOfType<OpTy>();
5198}
5199
5200/// If the given \p value is defined by an \c llvm.mlir.constant operation and
5201/// it is of an integer type, return its value.
5202static std::optional<int64_t> extractConstInteger(Value value) {
5203 if (!value)
5204 return std::nullopt;
5205
5206 if (auto constOp =
5207 dyn_cast_if_present<LLVM::ConstantOp>(Val: value.getDefiningOp()))
5208 if (auto constAttr = dyn_cast<IntegerAttr>(Val: constOp.getValue()))
5209 return constAttr.getInt();
5210
5211 return std::nullopt;
5212}
5213
5214static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
5215 uint64_t sizeInBits = dl.getTypeSizeInBits(t: type);
5216 uint64_t sizeInBytes = sizeInBits / 8;
5217 return sizeInBytes;
5218}
5219
5220template <typename OpTy>
5221static uint64_t getReductionDataSize(OpTy &op) {
5222 if (op.getNumReductionVars() > 0) {
5223 SmallVector<omp::DeclareReductionOp> reductions;
5224 collectReductionDecls(op, reductions);
5225
5226 llvm::SmallVector<mlir::Type> members;
5227 members.reserve(N: reductions.size());
5228 for (omp::DeclareReductionOp &red : reductions)
5229 members.push_back(Elt: red.getType());
5230 Operation *opp = op.getOperation();
5231 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5232 context: opp->getContext(), types: members, /*isPacked=*/false);
5233 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
5234 return getTypeByteSize(type: structType, dl);
5235 }
5236 return 0;
5237}
5238
/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
/// values as stated by the corresponding clauses, if constant.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
///
/// \param targetOp     The `omp.target` operation being translated.
/// \param capturedOp   Innermost captured OpenMP operation (may be null), used
///                     to decide which nested clauses apply.
/// \param attrs        Output kernel default attributes to populate.
/// \param isTargetDevice True when compiling for the device side.
/// \param isGPU        True when the target device is a GPU (enables
///                     reduction-buffer sizing).
static void
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice, bool isGPU) {
  // TODO: Handle constant 'if' clauses.

  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (!isTargetDevice) {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  } else {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(op: capturedOp)) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(op: capturedOp))
      numThreads = parallelOp.getNumThreads();
  }

  // Handle clauses impacting the number of teams.
  // Sentinel convention below: -1 = unset, 0 = set but not a compile-time
  // constant, > 0 = known constant value.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (castOrGetParentOfType<omp::TeamsOp>(op: capturedOp)) {
    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
    // clang and set min and max to the same value.
    if (numTeamsUpper) {
      if (auto val = extractConstInteger(value: numTeamsUpper))
        minTeamsVal = maxTeamsVal = *val;
    } else {
      minTeamsVal = maxTeamsVal = 0;
    }
  } else if (castOrGetParentOfType<omp::ParallelOp>(op: capturedOp,
                                                    /*immediateParent=*/true) ||
             castOrGetParentOfType<omp::SimdOp>(op: capturedOp,
                                                /*immediateParent=*/true)) {
    // A lone parallel/simd region implies exactly one team.
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Update `result` from `clauseValue` if the clause is present: constant
  // values are recorded directly; non-constant ones demote `result` to 0
  // ("set but unknown").
  auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
    if (!clauseValue)
      return;

    if (auto val = extractConstInteger(value: clauseValue))
      result = *val;

    // Found an applicable clause, so it's not undefined. Mark as unknown
    // because it's not constant.
    if (result < 0)
      result = 0;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);

  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (castOrGetParentOfType<omp::ParallelOp>(op: capturedOp))
    setMaxValueFromClause(numThreads, maxThreadsVal);
  else if (castOrGetParentOfType<omp::SimdOp>(op: capturedOp,
                                              /*immediateParent=*/true))
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Select the
  // minimum value between 'max_threads' and 'thread_limit' clauses that were
  // set.
  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
  if (combinedMaxThreadsVal < 0 ||
      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = teamsThreadLimitVal;

  if (combinedMaxThreadsVal < 0 ||
      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
    combinedMaxThreadsVal = maxThreadsVal;

  // Reduction buffers are only needed for GPU teams reductions.
  int32_t reductionDataSize = 0;
  if (isGPU && capturedOp) {
    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(op: capturedOp))
      reductionDataSize = getReductionDataSize(op&: teamsOp);
  }

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
  assert(
      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                               omp::TargetRegionFlags::spmd) &&
      "invalid kernel flags");
  // generic + spmd together means GENERIC_SPMD; otherwise whichever single
  // flag is present selects the execution mode.
  attrs.ExecFlags =
      omp::bitEnumContainsAny(bits: kernelFlags, bit: omp::TargetRegionFlags::generic)
          ? omp::bitEnumContainsAny(bits: kernelFlags, bit: omp::TargetRegionFlags::spmd)
                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
  attrs.ReductionDataSize = reductionDataSize;
  // TODO: Allow modified buffer length similar to
  // fopenmp-cuda-teams-reduction-recs-num flag in clang.
  if (attrs.ReductionDataSize != 0)
    attrs.ReductionBufferLength = 1024;
}
5358
/// Gather LLVM runtime values for all clauses evaluated in the host that are
/// passed to the kernel invocation.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
///
/// \param builder           IR builder positioned where launch-time values may
///                          be computed (e.g. the trip count multiply).
/// \param moduleTranslation MLIR-to-LLVM value mapping.
/// \param targetOp          The `omp.target` operation being translated.
/// \param capturedOp        Innermost captured OpenMP operation (may be null).
/// \param attrs             Output runtime kernel attributes to populate.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp, Operation *capturedOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(op: capturedOp);
  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
      steps(numLoops);
  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                         threadLimit&: teamsThreadLimit, lowerBounds: &lowerBounds, upperBounds: &upperBounds, steps: &steps);

  // TODO: Handle constant 'if' clauses.
  if (Value targetThreadLimit = targetOp.getThreadLimit())
    attrs.TargetThreadLimit.front() =
        moduleTranslation.lookupValue(value: targetThreadLimit);

  if (numTeamsLower)
    attrs.MinTeams = moduleTranslation.lookupValue(value: numTeamsLower);

  if (numTeamsUpper)
    attrs.MaxTeams.front() = moduleTranslation.lookupValue(value: numTeamsUpper);

  if (teamsThreadLimit)
    attrs.TeamsThreadLimit.front() =
        moduleTranslation.lookupValue(value: teamsThreadLimit);

  if (numThreads)
    attrs.MaxThreads = moduleTranslation.lookupValue(value: numThreads);

  if (omp::bitEnumContainsAny(bits: targetOp.getKernelExecFlags(capturedOp),
                              bit: omp::TargetRegionFlags::trip_count)) {
    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
    attrs.LoopTripCount = nullptr;

    // To calculate the trip count, we multiply together the trip counts of
    // every collapsed canonical loop. We don't need to create the loop nests
    // here, since we're only interested in the trip count.
    for (auto [loopLower, loopUpper, loopStep] :
         llvm::zip_equal(t&: lowerBounds, u&: upperBounds, args&: steps)) {
      llvm::Value *lowerBound = moduleTranslation.lookupValue(value: loopLower);
      llvm::Value *upperBound = moduleTranslation.lookupValue(value: loopUpper);
      llvm::Value *step = moduleTranslation.lookupValue(value: loopStep);

      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
          Loc: loc, Start: lowerBound, Stop: upperBound, Step: step, /*IsSigned=*/true,
          InclusiveStop: loopOp.getLoopInclusive());

      // First loop: its trip count seeds the product.
      if (!attrs.LoopTripCount) {
        attrs.LoopTripCount = tripCount;
        continue;
      }

      // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
      attrs.LoopTripCount = builder.CreateMul(LHS: attrs.LoopTripCount, RHS: tripCount,
                                              Name: {}, /*HasNUW=*/true);
    }
  }
}
5427
/// Translate an `omp.target` operation to LLVM IR: outline the region into a
/// kernel function via the OpenMPIRBuilder, wire up map/has_device_addr/
/// host_eval/private clauses, and emit the host-side launch (or the device
/// kernel, depending on the compilation side).
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(Val&: opInst);
  // The current debug location already has the DISubprogram for the outlined
  // function that will be created for the target op. We save it here so that
  // we can set it on the outlined function.
  llvm::DebugLoc outlinedFnLoc = builder.getCurrentDebugLocation();
  if (failed(Result: checkImplementationStatus(op&: opInst)))
    return failure();

  // During the handling of target op, we will generate instructions in the
  // parent function like call to the outlined function or branch to a new
  // BasicBlock. We set the debug location here to parent function so that those
  // get the correct debug locations. For outlined functions, the normal MLIR op
  // conversion will automatically pick the correct location.
  llvm::BasicBlock *parentBB = builder.GetInsertBlock();
  assert(parentBB && "No insert block is set for the builder");
  llvm::Function *parentLLVMFn = parentBB->getParent();
  assert(parentLLVMFn && "Parent Function must be valid");
  if (llvm::DISubprogram *SP = parentLLVMFn->getSubprogram())
    builder.SetCurrentDebugLocation(llvm::DILocation::get(
        Context&: parentLLVMFn->getContext(), Line: outlinedFnLoc.getLine(),
        Column: outlinedFnLoc.getCol(), Scope: SP, InlinedAt: outlinedFnLoc.getInlinedAt()));

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();
  bool isGPU = ompBuilder->Config.isGPU();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(Val&: opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block argument
  // that corresponds to the MapInfoOp corresponding to the private var in
  // question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during code generation we are able
  // to quickly look up the corresponding map variable, if any for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(First: llvm::zip_equal(t&: privateVars, u&: *privateSyms))) {
      auto privVar = std::get<0>(t&: privVarSymPair);
      auto privSym = std::get<1>(t&: privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(Val&: privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(from: targetOp, symbolName: privatizerName);

      // Only privatizers that require a map entry participate in the lookup
      // table.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(Val: privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          KV: {privVar,
           targetRegion.getArgument(i: argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback generating the body of the outlined kernel function: map clause
  // block arguments to host values, privatize variables, translate the target
  // region, and run private-var cleanup on exit.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(name: parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (outlinedFnLoc && llvmParentFn->getSubprogram())
      llvmOutlinedFn->setSubprogram(outlinedFnLoc->getScope()->getSubprogram());

    if (auto attr = llvmParentFn->getFnAttribute(Kind: "target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(Attr: attr);

    if (auto attr = llvmParentFn->getFnAttribute(Kind: "target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(Attr: attr);

    // Map the entry block arguments for map/has_device_addr clauses to the
    // LLVM values of the variables they refer to.
    for (auto [arg, mapOp] : llvm::zip_equal(t&: mapBlockArgs, u&: mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(Val: mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(value: mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(mlir: arg, llvm: mapOpValue);
    }
    for (auto [arg, mapOp] : llvm::zip_equal(t&: hdaBlockArgs, u&: hdaVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(Val: mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(value: mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(mlir: arg, llvm: mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    PrivateVarsInfo privateVarsInfo(targetOp);

    llvm::Expected<llvm::BasicBlock *> afterAllocas =
        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
                            allocaIP, mappedPrivateVars: &mappedPrivateVars);

    if (failed(Result: handleError(result&: afterAllocas, op&: *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(IP: codeGenIP);
    if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                     mappedPrivateVars: &mappedPrivateVars),
                    op&: *targetOp)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(Result: copyFirstPrivateVars(
            op: targetOp, builder, moduleTranslation, mlirPrivateVars&: privateVarsInfo.mlirVars,
            llvmPrivateVars: privateVarsInfo.llvmVars, privateDecls&: privateVarsInfo.privatizers,
            insertBarrier: targetOp.getPrivateNeedsBarrier(), mappedPrivateVars: &mappedPrivateVars)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect the `dealloc` regions of all privatizers so they can be inlined
    // at region exit.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(Range&: privateVarsInfo.privatizers,
                    d_first: std::back_inserter(x&: privateCleanupRegions),
                    F: [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        region&: targetRegion, blockName: "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(Result: inlineOmpRegionCleanup(
              cleanupRegions&: privateCleanupRegions, privateVariables: privateVarsInfo.llvmVars,
              moduleTranslation, builder, regionName: "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            Fmt: "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
      return builder.saveIP();
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  getTargetEntryUniqueInfo(targetInfo&: entryInfo, targetOp, parentName);

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder, /*useDevPtrOperands=*/{},
                                /*useDevAddrOperands=*/{}, hasDevAddrOperands: hdaVars);

  MapInfosTy combinedInfos;
  // Callback producing the combined map infos at the requested insertion
  // point; result is cached in `combinedInfos`.
  auto genMapInfoCB =
      [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
    builder.restoreIP(IP: codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfo&: combinedInfos, mapData, isTargetParams: true);
    return combinedInfos;
  };

  // Callback that rewrites kernel-argument accesses. On the host fallback the
  // argument is used as-is; on the device an accessor (alloca/store/load
  // sequence) is generated per capture kind.
  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    builder.SetCurrentDebugLocation(llvm::DebugLoc());
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(Val: &arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        ompBuilder&: *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
  initTargetDefaultAttrs(targetOp, capturedOp: targetCapturedOp, attrs&: defaultAttrs,
                         isTargetDevice, isGPU);

  // Collect host-evaluated values needed to properly launch the kernel from the
  // host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
                           capturedOp: targetCapturedOp, attrs&: runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  llvm::SmallVector<llvm::Value *, 4> kernelInput;
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(t&: hostEvalBlockArgs, u&: hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(value: var);
    moduleTranslation.mapValue(mlir: arg, llvm: value);

    if (!llvm::isa<llvm::Constant>(Val: value))
      kernelInput.push_back(Elt: value);
  }

  for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likely need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(Elt: mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(dependKinds: targetOp.getDependKinds(), dependVars: targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::TargetDataInfo info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Callback materializing user-defined mapper functions on demand; returns
  // nullptr when map entry `i` has no custom mapper.
  auto customMapperCB =
      [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
    if (!combinedInfos.Mappers[i])
      return nullptr;
    info.HasMapper = true;
    return getOrCreateUserDefinedMapperFunc(op: combinedInfos.Mappers[i], builder,
                                            moduleTranslation);
  };

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(value: targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          Loc: ompLoc, IsOffloadEntry: isOffloadEntry, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(), Info&: info, EntryInfo&: entryInfo,
          DefaultAttrs: defaultAttrs, RuntimeAttrs: runtimeAttrs, IfCond: ifCond, Inputs&: kernelInput, GenMapInfoCB: genMapInfoCB, BodyGenCB: bodyCB,
          ArgAccessorFuncCB: argAccessorCB, CustomMapperCB: customMapperCB, Dependencies: dds, HasNowait: targetOp.getNowait());

  if (failed(Result: handleError(result&: afterIP, op&: opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              func: llvmOutlinedFn);

  return success();
}
5744
/// Lowers an `omp.declare_target` attribute attached to \p op.
///
/// For functions: during device compilation, host-only functions carrying this
/// attribute are the wrapper functions left over from target-region outlining,
/// and their LLVM IR is erased. For globals: the variable is registered with
/// the OpenMPIRBuilder's offload entry machinery, and a device reference
/// pointer is generated when required.
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(Val: op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            Val: op->getParentOfType<ModuleOp>().getOperation())) {
      // Nothing to amend when compiling for the host.
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      // Device pass over a host-only function: per the reasoning above this
      // must be a target-region wrapper, so erase its LLVM IR entirely.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(name: funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(Val: op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(Name: gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(captureClause: attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(deviceClause: attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      // Collect the single target triple (if any) from the enclosing module's
      // LLVM target-triple attribute.
      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          Val: op->getParentOfType<mlir::ModuleOp>()->getAttr(
              name: LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(args: targetTripleAttr.data());

      // Supplies the (filename, line) pair used to build the unique offload
      // entry name; falls back to "" / 0 when the op carries no
      // FileLineColLoc.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          CaptureClause: captureClause, DeviceClause: deviceClause, IsDeclaration: isDeclaration, IsExternallyVisible: isExternallyVisible,
          EntryInfo: ompBuilder->getTargetEntryUniqueInfo(CallBack: fileInfoCallBack), MangledName: mangledName,
          GeneratedRefs&: generatedRefs, /*OpenMPSimd*/ OpenMPSIMD: false, TargetTriple: targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          LlvmPtrTy: gVal->getType(), Addr: gVal);

      // On the device, additionally generate the declare-target reference
      // pointer for variables not captured by `to`, or whenever unified
      // shared memory is required.
      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            CaptureClause: captureClause, DeviceClause: deviceClause, IsDeclaration: isDeclaration, IsExternallyVisible: isExternallyVisible,
            EntryInfo: ompBuilder->getTargetEntryUniqueInfo(CallBack: fileInfoCallBack), MangledName: mangledName,
            GeneratedRefs&: generatedRefs, /*OpenMPSimd*/ OpenMPSIMD: false, TargetTriple: targetTriple, LlvmPtrTy: gVal->getType(),
            /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
5834
5835// Returns true if the operation is inside a TargetOp or
5836// is part of a declare target function.
5837static bool isTargetDeviceOp(Operation *op) {
5838 // Assumes no reverse offloading
5839 if (op->getParentOfType<omp::TargetOp>())
5840 return true;
5841
5842 // Certain operations return results, and whether utilised in host or
5843 // target there is a chance an LLVM Dialect operation depends on it
5844 // by taking it in as an operand, so we must always lower these in
5845 // some manner or result in an ICE (whether they end up in a no-op
5846 // or otherwise).
5847 if (mlir::isa<omp::ThreadprivateOp>(Val: op))
5848 return true;
5849
5850 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
5851 if (auto declareTargetIface =
5852 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
5853 Val: parentFn.getOperation()))
5854 if (declareTargetIface.isDeclareTarget() &&
5855 declareTargetIface.getDeclareTargetDeviceType() !=
5856 mlir::omp::DeclareTargetDeviceType::host)
5857 return true;
5858
5859 return false;
5860}
5861
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls). Dispatches on the concrete operation type; ops with
/// no direct lowering (yield, terminator, declarations, map infrastructure)
/// succeed as no-ops, and unrecognized ops produce a
/// "not yet implemented" error.
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // For each loop, introduce one stack frame to hold loop information. Ensure
  // this is only done for the outermost loop wrapper to prevent introducing
  // multiple stack frames for a single loop. Initially set to null, the loop
  // information structure is initialized during translation of the nested
  // omp.loop_nest operation, making it available to translation of all loop
  // wrappers after their body has been successfully translated.
  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(Val: op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(Val: op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case(caseFn: [&](omp::BarrierOp op) -> LogicalResult {
            if (failed(Result: checkImplementationStatus(op&: *op)))
              return failure();

            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(Loc: builder.saveIP(),
                                          Kind: llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(result&: afterIP, op&: *op);
            if (res.succeeded()) {
              // If the barrier generated a cancellation check, the insertion
              // point might now need to be changed to a new continuation block
              builder.restoreIP(IP: *afterIP);
            }
            return res;
          })
          .Case(caseFn: [&](omp::TaskyieldOp op) {
            if (failed(Result: checkImplementationStatus(op&: *op)))
              return failure();

            ompBuilder->createTaskyield(Loc: builder.saveIP());
            return success();
          })
          .Case(caseFn: [&](omp::FlushOp op) {
            if (failed(Result: checkImplementationStatus(op&: *op)))
              return failure();

            // No support in Openmp runtime function (__kmpc_flush) to accept
            // the argument list.
            // OpenMP standard states the following:
            // "An implementation may implement a flush with a list by ignoring
            // the list, and treating it the same as a flush without a list."
            //
            // The argument list is discarded so that, flush with a list is
            // treated same as a flush without a list.
            ompBuilder->createFlush(Loc: builder.saveIP());
            return success();
          })
          .Case(caseFn: [&](omp::ParallelOp op) {
            return convertOmpParallel(opInst: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::MaskedOp) {
            return convertOmpMasked(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::MasterOp) {
            return convertOmpMaster(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::CriticalOp) {
            return convertOmpCritical(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::OrderedOp) {
            return convertOmpOrdered(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::WsloopOp) {
            return convertOmpWsloop(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::SimdOp) {
            return convertOmpSimd(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(opInst&: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(atomicCaptureOp: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::SectionsOp) {
            return convertOmpSections(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::SingleOp op) {
            return convertOmpSingle(singleOp&: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::TaskOp op) {
            return convertOmpTaskOp(taskOp: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(tgOp: op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(twOp: op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>(caseFn: [](auto op) {
            // `yield` and `terminator` can be just omitted. The block structure
            // was created in the region that handles their parent operation.
            // `declare_reduction` will be used by reductions and is not
            // converted directly, skip it.
            // `declare_mapper` and `declare_mapper.info` are handled whenever
            // they are referred to through a `map` clause.
            // `critical.declare` is only used to declare names of critical
            // sections which will be used by `critical` ops and hence can be
            // ignored for lowering. The OpenMP IRBuilder will create unique
            // name for critical section names.
            return success();
          })
          .Case(caseFn: [&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(opInst&: *op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>(caseFn: [&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::TargetOp) {
            return convertOmpTarget(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::DistributeOp) {
            return convertOmpDistribute(opInst&: *op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::LoopNestOp) {
            return convertOmpLoopNest(opInst&: *op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
              caseFn: [&](auto op) {
                // No-op, should be handled by relevant owning operations e.g.
                // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
                // etc. and then discarded
                return success();
              })
          .Case(caseFn: [&](omp::NewCliOp op) {
            // Meta-operation: Doesn't do anything by itself, but used to
            // identify a loop.
            return success();
          })
          .Case(caseFn: [&](omp::CanonicalLoopOp op) {
            return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
          })
          .Case(caseFn: [&](omp::UnrollHeuristicOp op) {
            // FIXME: Handling omp.unroll_heuristic as an executable requires
            // that the generator (e.g. omp.canonical_loop) has been seen first.
            // For construct that require all codegen to occur inside a callback
            // (e.g. OpenMPIRBilder::createParallel), all codegen of that
            // contained region including their transformations must occur at
            // the omp.canonical_loop.
            return applyUnrollHeuristic(op, builder, moduleTranslation);
          })
          .Default(defaultFn: [&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  // Pop the frame pushed above, regardless of translation success/failure.
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
6046
/// Lowers an OpenMP operation encountered while compiling for the target
/// device. Currently this simply forwards to the shared host/target lowering
/// path.
static LogicalResult
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
  return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
6052
/// During device compilation, translates only the target-related operations
/// nested under \p op. omp.target / omp.target_data are lowered directly;
/// other ops are walked pre-order so that any target-related ops nested inside
/// them are found, and their regions are translated as "fake" scopes purely to
/// make the LLVM values needed by nested target ops available.
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(Val: op))
    return convertOmpTarget(opInst&: *op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(Val: op))
    return convertOmpTargetData(op, builder, moduleTranslation);
  bool interrupted =
      op->walk<WalkOrder::PreOrder>(callback: [&](Operation *oper) {
          // Target ops are translated fully and their interior skipped; the
          // walk is interrupted on the first failure.
          if (isa<omp::TargetOp>(Val: oper)) {
            if (failed(Result: convertOmpTarget(opInst&: *oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }
          if (isa<omp::TargetDataOp>(Val: oper)) {
            if (failed(Result: convertOmpTargetData(op: oper, builder, moduleTranslation)))
              return WalkResult::interrupt();
            return WalkResult::skip();
          }

          // Non-target ops might nest target-related ops, therefore, we
          // translate them as non-OpenMP scopes. Translating them is needed by
          // nested target-related ops since they might need LLVM values defined
          // in their parent non-target ops.
          if (isa<omp::OpenMPDialect>(Val: oper->getDialect()) &&
              oper->getParentOfType<LLVM::LLVMFuncOp>() &&
              !oper->getRegions().empty()) {
            if (auto blockArgsIface =
                    dyn_cast<omp::BlockArgOpenMPOpInterface>(Val: oper))
              forwardArgs(moduleTranslation, blockArgIface: blockArgsIface);
            else {
              // Here we map entry block arguments of
              // non-BlockArgOpenMPOpInterface ops if they can be encountered
              // inside of a function and they define any of these arguments.
              if (isa<mlir::omp::AtomicUpdateOp>(Val: oper))
                for (auto [operand, arg] :
                     llvm::zip_equal(t: oper->getOperands(),
                                     u: oper->getRegion(index: 0).getArguments())) {
                  moduleTranslation.mapValue(
                      mlir: arg, llvm: builder.CreateLoad(
                               Ty: moduleTranslation.convertType(type: arg.getType()),
                               Ptr: moduleTranslation.lookupValue(value: operand)));
                }
            }

            if (auto loopNest = dyn_cast<omp::LoopNestOp>(Val: oper)) {
              assert(builder.GetInsertBlock() &&
                     "No insert block is set for the builder");
              for (auto iv : loopNest.getIVs()) {
                // Map iv to an undefined value just to keep the IR validity.
                moduleTranslation.mapValue(
                    mlir: iv, llvm: llvm::PoisonValue::get(
                             T: moduleTranslation.convertType(type: iv.getType())));
              }
            }

            for (Region &region : oper->getRegions()) {
              // Regions are fake in the sense that they are not a truthful
              // translation of the OpenMP construct being converted (e.g. no
              // OpenMP runtime calls will be generated). We just need this to
              // prepare the kernel invocation args.
              SmallVector<llvm::PHINode *> phis;
              auto result = convertOmpOpRegions(
                  region, blockName: oper->getName().getStringRef().str() + ".fake.region",
                  builder, moduleTranslation, continuationBlockPHIs: &phis);
              if (failed(Result: handleError(result, op&: *oper)))
                return WalkResult::interrupt();

              // Continue emitting at the end of the fake region's
              // continuation block.
              builder.SetInsertPoint(TheBB: result.get(), IP: result.get()->end());
            }

            return WalkResult::skip();
          }

          return WalkResult::advance();
        }).wasInterrupted();
  return failure(IsFailure: interrupted);
}
6131
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;

  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
  /// runtime calls, or operation amendments. Attributes that require no
  /// lowering are accepted as no-ops.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
6156
6157LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
6158 Operation *op, ArrayRef<llvm::Instruction *> instructions,
6159 NamedAttribute attribute,
6160 LLVM::ModuleTranslation &moduleTranslation) const {
6161 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
6162 attribute.getName())
6163 .Case(S: "omp.is_target_device",
6164 Value: [&](Attribute attr) {
6165 if (auto deviceAttr = dyn_cast<BoolAttr>(Val&: attr)) {
6166 llvm::OpenMPIRBuilderConfig &config =
6167 moduleTranslation.getOpenMPBuilder()->Config;
6168 config.setIsTargetDevice(deviceAttr.getValue());
6169 return success();
6170 }
6171 return failure();
6172 })
6173 .Case(S: "omp.is_gpu",
6174 Value: [&](Attribute attr) {
6175 if (auto gpuAttr = dyn_cast<BoolAttr>(Val&: attr)) {
6176 llvm::OpenMPIRBuilderConfig &config =
6177 moduleTranslation.getOpenMPBuilder()->Config;
6178 config.setIsGPU(gpuAttr.getValue());
6179 return success();
6180 }
6181 return failure();
6182 })
6183 .Case(S: "omp.host_ir_filepath",
6184 Value: [&](Attribute attr) {
6185 if (auto filepathAttr = dyn_cast<StringAttr>(Val&: attr)) {
6186 llvm::OpenMPIRBuilder *ompBuilder =
6187 moduleTranslation.getOpenMPBuilder();
6188 ompBuilder->loadOffloadInfoMetadata(HostFilePath: filepathAttr.getValue());
6189 return success();
6190 }
6191 return failure();
6192 })
6193 .Case(S: "omp.flags",
6194 Value: [&](Attribute attr) {
6195 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(Val&: attr))
6196 return convertFlagsAttr(op, attribute: rtlAttr, moduleTranslation);
6197 return failure();
6198 })
6199 .Case(S: "omp.version",
6200 Value: [&](Attribute attr) {
6201 if (auto versionAttr = dyn_cast<omp::VersionAttr>(Val&: attr)) {
6202 llvm::OpenMPIRBuilder *ompBuilder =
6203 moduleTranslation.getOpenMPBuilder();
6204 ompBuilder->M.addModuleFlag(Behavior: llvm::Module::Max, Key: "openmp",
6205 Val: versionAttr.getVersion());
6206 return success();
6207 }
6208 return failure();
6209 })
6210 .Case(S: "omp.declare_target",
6211 Value: [&](Attribute attr) {
6212 if (auto declareTargetAttr =
6213 dyn_cast<omp::DeclareTargetAttr>(Val&: attr))
6214 return convertDeclareTargetAttr(op, attribute: declareTargetAttr,
6215 moduleTranslation);
6216 return failure();
6217 })
6218 .Case(S: "omp.requires",
6219 Value: [&](Attribute attr) {
6220 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(Val&: attr)) {
6221 using Requires = omp::ClauseRequires;
6222 Requires flags = requiresAttr.getValue();
6223 llvm::OpenMPIRBuilderConfig &config =
6224 moduleTranslation.getOpenMPBuilder()->Config;
6225 config.setHasRequiresReverseOffload(
6226 bitEnumContainsAll(bits: flags, bit: Requires::reverse_offload));
6227 config.setHasRequiresUnifiedAddress(
6228 bitEnumContainsAll(bits: flags, bit: Requires::unified_address));
6229 config.setHasRequiresUnifiedSharedMemory(
6230 bitEnumContainsAll(bits: flags, bit: Requires::unified_shared_memory));
6231 config.setHasRequiresDynamicAllocators(
6232 bitEnumContainsAll(bits: flags, bit: Requires::dynamic_allocators));
6233 return success();
6234 }
6235 return failure();
6236 })
6237 .Case(S: "omp.target_triples",
6238 Value: [&](Attribute attr) {
6239 if (auto triplesAttr = dyn_cast<ArrayAttr>(Val&: attr)) {
6240 llvm::OpenMPIRBuilderConfig &config =
6241 moduleTranslation.getOpenMPBuilder()->Config;
6242 config.TargetTriples.clear();
6243 config.TargetTriples.reserve(N: triplesAttr.size());
6244 for (Attribute tripleAttr : triplesAttr) {
6245 if (auto tripleStrAttr = dyn_cast<StringAttr>(Val&: tripleAttr))
6246 config.TargetTriples.emplace_back(Args: tripleStrAttr.getValue());
6247 else
6248 return failure();
6249 }
6250 return success();
6251 }
6252 return failure();
6253 })
6254 .Default(Value: [](Attribute) {
6255 // Fall through for omp attributes that do not require lowering.
6256 return success();
6257 })(attribute.getValue());
6258
6259 return failure();
6260}
6261
6262/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
6263/// (including OpenMP runtime calls).
6264LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6265 Operation *op, llvm::IRBuilderBase &builder,
6266 LLVM::ModuleTranslation &moduleTranslation) const {
6267
6268 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6269 if (ompBuilder->Config.isTargetDevice()) {
6270 if (isTargetDeviceOp(op)) {
6271 return convertTargetDeviceOp(op, builder, moduleTranslation);
6272 } else {
6273 return convertTargetOpsInNest(op, builder, moduleTranslation);
6274 }
6275 }
6276 return convertHostOrTargetOperation(op, builder, moduleTranslation);
6277}
6278
6279void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
6280 registry.insert<omp::OpenMPDialect>();
6281 registry.addExtension(extensionFn: +[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
6282 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
6283 });
6284}
6285
6286void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
6287 DialectRegistry registry;
6288 registerOpenMPDialectTranslation(registry);
6289 context.appendDialectRegistry(registry);
6290}
6291

// source code of mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp